diff --git a/google/downscoped/downscoping.go b/google/downscoped/downscoping.go index 5d2ce6e..80677d4 100644 --- a/google/downscoped/downscoping.go +++ b/google/downscoped/downscoping.go @@ -1,28 +1,25 @@ package downscoped import ( + "context" "encoding/json" "fmt" "golang.org/x/oauth2" - "io/ioutil" "net/http" "net/url" "time" ) const ( - IDENTITY_BINDING_ENDPOINT = "https://sts.googleapis.com/v1beta/token" + identityBindingEndpoint = "https://sts.googleapis.com/v1beta/token" ) // Defines an upper bound of permissions available for a GCP credential for one or more resources type AccessBoundary struct { + // One or more AccessBoundaryRules are required to define permissions for the new downscoped token AccessBoundaryRules []AccessBoundaryRule `json:"accessBoundaryRules"` } -func NewAccessBoundary() AccessBoundary { - return AccessBoundary{make([]AccessBoundaryRule, 0)} -} - type AvailabilityCondition struct { Title string `json:"title,omitempty"` Expression string `json:"expression"` @@ -35,7 +32,7 @@ type AccessBoundaryRule struct { Condition *AvailabilityCondition `json:"availabilityCondition,omitempty"` } -type DownScopedTokenResponse struct { +type downscopedTokenResponse struct { AccessToken string `json:"access_token"` IssuedTokenType string `json:"issued_token_type"` TokenType string `json:"token_type"` @@ -47,12 +44,12 @@ type DownscopingConfig struct { CredentialAccessBoundary AccessBoundary } -func DownscopedTokenWithEndpoint(config DownscopingConfig, endpoint string) (oauth2.TokenSource, error) { +func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig, endpoint string) (oauth2.TokenSource, error) { if config.RootSource == nil { - return nil, fmt.Errorf("oauth2/google: rootTokenSource cannot be nil") + return nil, fmt.Errorf("oauth2/google/downscoped: rootTokenSource cannot be nil") } if len(config.CredentialAccessBoundary.AccessBoundaryRules) == 0 { - return nil, fmt.Errorf("oauth2/google: length of AccessBoundaryRules must be at least 1") + return nil, fmt.Errorf("oauth2/google/downscoped: length of AccessBoundaryRules must be at least 1") } downscopedOptions := struct { @@ -63,12 +60,12 @@ func DownscopedTokenWithEndpoint(config DownscopingConfig, endpoint string) (oau tok, err := config.RootSource.Token() if err != nil { - return nil, fmt.Errorf("oauth2/google: unable to refresh root token %v", err) + return nil, fmt.Errorf("oauth2/google/downscoped: unable to refresh root token %v", err) } - b, err := json.Marshal(downscopedOptions) // TODO: make sure that this marshals properly! + b, err := json.Marshal(downscopedOptions) if err != nil { - return nil, fmt.Errorf("oauth2/google: Unable to marshall AccessBoundary payload %v", err) + return nil, fmt.Errorf("oauth2/google/downscoped: Unable to marshall AccessBoundary payload %v", err) } form := url.Values{} @@ -78,19 +75,21 @@ func DownscopedTokenWithEndpoint(config DownscopingConfig, endpoint string) (oau form.Add("subject_token", tok.AccessToken) form.Add("options", url.QueryEscape(string(b))) - resp, err := http.PostForm(endpoint, form) - defer resp.Body.Close() + myClient := oauth2.NewClient(ctx, nil) + resp, err := myClient.PostForm(endpoint, form) if err != nil { return nil, fmt.Errorf("unable to generate POST Request %v", err) } + defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := ioutil.ReadAll(resp.Body) - return nil, fmt.Errorf("unable to exchange token %v", string(bodyBytes)) + var tresp downscopedTokenResponse + err = json.NewDecoder(resp.Body).Decode(&tresp) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unable to exchange token %v", tresp) } - - tresp := DownScopedTokenResponse{} - json.NewDecoder(resp.Body).Decode(&tresp) // an exchanged token that is derived from a service account (2LO) has an expired_in value // a token derived from a users token (3LO) does not. @@ -111,6 +110,6 @@ func DownscopedTokenWithEndpoint(config DownscopingConfig, endpoint string) (oau return oauth2.StaticTokenSource(newToken), nil } -func NewDownscopedTokenSource(config DownscopingConfig) (oauth2.TokenSource, error) { - return DownscopedTokenWithEndpoint(config, IDENTITY_BINDING_ENDPOINT) +func NewDownscopedTokenSource(ctx context.Context, config DownscopingConfig) (oauth2.TokenSource, error) { + return downscopedTokenWithEndpoint(ctx, config, identityBindingEndpoint) } diff --git a/google/downscoped/downscoping_test.go b/google/downscoped/downscoping_test.go index fca1601..d25d2e0 100644 --- a/google/downscoped/downscoping_test.go +++ b/google/downscoped/downscoping_test.go @@ -1,6 +1,7 @@ package downscoped import ( + "context" "golang.org/x/oauth2" "io/ioutil" "net/http" @@ -14,7 +15,7 @@ var ( ) func Test_NewAccessBoundary(t *testing.T) { - got := NewAccessBoundary() + got := AccessBoundary{make([]AccessBoundaryRule, 0)} want := AccessBoundary{nil} if got.AccessBoundaryRules == nil || len(got.AccessBoundaryRules) != 0 { t.Errorf("NewAccessBoundary() = %v; want %v", got, want) @@ -26,7 +27,7 @@ func Test_DownscopedTokenSource(t *testing.T) { if r.Method != "POST" { t.Errorf("Unexpected request method, %v is found", r.Method) } - if r.URL.String() != "/" { //TODO: Will this work, or do I need to redirect this to this test server instead? + if r.URL.String() != "/" { t.Errorf("Unexpected request URL, %v is found", r.URL) } body, err := ioutil.ReadAll(r.Body) @@ -40,11 +41,11 @@ func Test_DownscopedTokenSource(t *testing.T) { w.Write([]byte(standardRespBody)) })) - new := NewAccessBoundary() + new := AccessBoundary{make([]AccessBoundaryRule, 0)} new.AccessBoundaryRules = append(new.AccessBoundaryRules, AccessBoundaryRule{"test1", []string{"Perm1, perm2"}, nil}) myTok := oauth2.Token{AccessToken: "Mellon"} tmpSrc := oauth2.StaticTokenSource(&myTok) - out, err := DownscopedTokenWithEndpoint(DownscopingConfig{tmpSrc, new}, ts.URL) + out, err := downscopedTokenWithEndpoint(context.Background(), DownscopingConfig{tmpSrc, new}, ts.URL) if err != nil { t.Fatalf("NewDownscopedTokenSource failed with error: %v", err) }