diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go index 3b7f32a..7cc37c8 100644 --- a/endpoints/endpoints.go +++ b/endpoints/endpoints.go @@ -17,6 +17,12 @@ var Amazon = oauth2.Endpoint{ TokenURL: "https://api.amazon.com/auth/o2/token", } +// Battlenet is the endpoint for Battlenet. +var Battlenet = oauth2.Endpoint{ + AuthURL: "https://battle.net/oauth/authorize", + TokenURL: "https://battle.net/oauth/token", +} + // Bitbucket is the endpoint for Bitbucket. var Bitbucket = oauth2.Endpoint{ AuthURL: "https://bitbucket.org/site/oauth2/authorize", @@ -167,6 +173,12 @@ var StackOverflow = oauth2.Endpoint{ TokenURL: "https://stackoverflow.com/oauth/access_token", } +// Strava is the endpoint for Strava. +var Strava = oauth2.Endpoint{ + AuthURL: "https://www.strava.com/oauth/authorize", + TokenURL: "https://www.strava.com/oauth/token", +} + // Twitch is the endpoint for Twitch. var Twitch = oauth2.Endpoint{ AuthURL: "https://id.twitch.tv/oauth2/authorize", diff --git a/go.mod b/go.mod index 2b13f0b..468b626 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.11 require ( cloud.google.com/go v0.65.0 - golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd google.golang.org/appengine v1.6.6 ) diff --git a/go.sum b/go.sum index eab5833..bdceef9 100644 --- a/go.sum +++ b/go.sum @@ -177,8 +177,9 @@ golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -217,11 +218,15 @@ golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/google/default.go b/google/default.go index 880dd7b..dd00420 100644 --- a/google/default.go +++ b/google/default.go @@ -94,20 +94,20 @@ func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSourc // It looks for credentials in the following places, // preferring the first location found: // -// 1. A JSON file whose path is specified by the -// GOOGLE_APPLICATION_CREDENTIALS environment variable. -// For workload identity federation, refer to -// https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation on -// how to generate the JSON configuration file for on-prem/non-Google cloud -// platforms. -// 2. A JSON file in a location known to the gcloud command-line tool. -// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json. -// On other systems, $HOME/.config/gcloud/application_default_credentials.json. -// 3. On Google App Engine standard first generation runtimes (<= Go 1.9) it uses -// the appengine.AccessToken function. -// 4. On Google Compute Engine, Google App Engine standard second generation runtimes -// (>= Go 1.11), and Google App Engine flexible environment, it fetches -// credentials from the metadata server. +// 1. A JSON file whose path is specified by the +// GOOGLE_APPLICATION_CREDENTIALS environment variable. +// For workload identity federation, refer to +// https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation on +// how to generate the JSON configuration file for on-prem/non-Google cloud +// platforms. +// 2. A JSON file in a location known to the gcloud command-line tool. +// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json. +// On other systems, $HOME/.config/gcloud/application_default_credentials.json. +// 3. On Google App Engine standard first generation runtimes (<= Go 1.9) it uses +// the appengine.AccessToken function. +// 4. On Google Compute Engine, Google App Engine standard second generation runtimes +// (>= Go 1.11), and Google App Engine flexible environment, it fetches +// credentials from the metadata server. func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsParams) (*Credentials, error) { // Make defensive copy of the slices in params. params = params.deepCopy() diff --git a/google/doc.go b/google/doc.go index b241c72..dddf651 100644 --- a/google/doc.go +++ b/google/doc.go @@ -4,9 +4,9 @@ // Package google provides support for making OAuth2 authorized and authenticated // HTTP requests to Google APIs. It supports the Web server flow, client-side -// credentials, service accounts, Google Compute Engine service accounts, Google -// App Engine service accounts and workload identity federation from non-Google -// cloud platforms. +// credentials, service accounts, Google Compute Engine service accounts, +// Google App Engine service accounts and workload identity federation +// from non-Google cloud platforms. // // A brief overview of the package follows. For more information, please read // https://developers.google.com/accounts/docs/OAuth2 @@ -15,14 +15,14 @@ // For more information on using workload identity federation, refer to // https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation. // -// OAuth2 Configs +// # OAuth2 Configs // // Two functions in this package return golang.org/x/oauth2.Config values from Google credential // data. Google supports two JSON formats for OAuth2 credentials: one is handled by ConfigFromJSON, // the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or // create an http.Client. // -// Workload Identity Federation +// # Workload Identity Federation // // Using workload identity federation, your application can access Google Cloud // resources from Amazon Web Services (AWS), Microsoft Azure or any identity @@ -36,9 +36,9 @@ // Follow the detailed instructions on how to configure Workload Identity Federation // in various platforms: // -// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws -// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure -// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc +// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws +// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure +// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc // // For OIDC providers, the library can retrieve OIDC tokens either from a // local file location (file-sourced credentials) or from a local server @@ -51,8 +51,7 @@ // return the OIDC token. The response can be in plain text or JSON. // Additional required request headers can also be specified. // -// -// Credentials +// # Credentials // // The Credentials type represents Google credentials, including Application Default // Credentials. diff --git a/google/downscope/downscoping.go b/google/downscope/downscoping.go new file mode 100644 index 0000000..3d4b553 --- /dev/null +++ b/google/downscope/downscoping.go @@ -0,0 +1,211 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package downscope implements the ability to downscope, or restrict, the +Identity and Access Management permissions that a short-lived Token +can use. Please note that only Google Cloud Storage supports this feature. +For complete documentation, see https://cloud.google.com/iam/docs/downscoping-short-lived-credentials + +To downscope permissions of a source credential, you need to define +a Credential Access Boundary. Said Boundary specifies which resources +the newly created credential can access, an upper bound on the permissions +it has over those resources, and optionally attribute-based conditional +access to the aforementioned resources. For more information on IAM +Conditions, see https://cloud.google.com/iam/docs/conditions-overview. + +This functionality can be used to provide a third party with +limited access to and permissions on resources held by the owner of the root +credential or internally in conjunction with the principle of least privilege +to ensure that internal services only hold the minimum necessary privileges +for their function. + +For example, a token broker can be set up on a server in a private network. +Various workloads (token consumers) in the same network will send authenticated +requests to that broker for downscoped tokens to access or modify specific google +cloud storage buckets. See the NewTokenSource example for an example of how a +token broker would use this package. + +The broker will use the functionality in this package to generate a downscoped +token with the requested configuration, and then pass it back to the token +consumer. These downscoped access tokens can then be used to access Google +Storage resources. For instance, you can create a NewClient from the +"cloud.google.com/go/storage" package and pass in option.WithTokenSource(yourTokenSource)) +*/ +package downscope + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "golang.org/x/oauth2" +) + +var ( + identityBindingEndpoint = "https://sts.googleapis.com/v1/token" +) + +type accessBoundary struct { + AccessBoundaryRules []AccessBoundaryRule `json:"accessBoundaryRules"` +} + +// An AvailabilityCondition restricts access to a given Resource. +type AvailabilityCondition struct { + // An Expression specifies the Cloud Storage objects where + // permissions are available. For further documentation, see + // https://cloud.google.com/iam/docs/conditions-overview + Expression string `json:"expression"` + // Title is short string that identifies the purpose of the condition. Optional. + Title string `json:"title,omitempty"` + // Description details about the purpose of the condition. Optional. + Description string `json:"description,omitempty"` +} + +// An AccessBoundaryRule Sets the permissions (and optionally conditions) +// that the new token has on given resource. +type AccessBoundaryRule struct { + // AvailableResource is the full resource name of the Cloud Storage bucket that the rule applies to. + // Use the format //storage.googleapis.com/projects/_/buckets/bucket-name. + AvailableResource string `json:"availableResource"` + // AvailablePermissions is a list that defines the upper bound on the available permissions + // for the resource. Each value is the identifier for an IAM predefined role or custom role, + // with the prefix inRole:. For example: inRole:roles/storage.objectViewer. + // Only the permissions in these roles will be available. + AvailablePermissions []string `json:"availablePermissions"` + // An Condition restricts the availability of permissions + // to specific Cloud Storage objects. Optional. + // + // A Condition can be used to make permissions available for specific objects, + // rather than all objects in a Cloud Storage bucket. + Condition *AvailabilityCondition `json:"availabilityCondition,omitempty"` +} + +type downscopedTokenResponse struct { + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +// DownscopingConfig specifies the information necessary to request a downscoped token. +type DownscopingConfig struct { + // RootSource is the TokenSource used to create the downscoped token. + // The downscoped token therefore has some subset of the accesses of + // the original RootSource. + RootSource oauth2.TokenSource + // Rules defines the accesses held by the new + // downscoped Token. One or more AccessBoundaryRules are required to + // define permissions for the new downscoped token. Each one defines an + // access (or set of accesses) that the new token has to a given resource. + // There can be a maximum of 10 AccessBoundaryRules. + Rules []AccessBoundaryRule +} + +// A downscopingTokenSource is used to retrieve a downscoped token with restricted +// permissions compared to the root Token that is used to generate it. +type downscopingTokenSource struct { + // ctx is the context used to query the API to retrieve a downscoped Token. + ctx context.Context + // config holds the information necessary to generate a downscoped Token. + config DownscopingConfig +} + +// NewTokenSource returns a configured downscopingTokenSource. +func NewTokenSource(ctx context.Context, conf DownscopingConfig) (oauth2.TokenSource, error) { + if conf.RootSource == nil { + return nil, fmt.Errorf("downscope: rootSource cannot be nil") + } + if len(conf.Rules) == 0 { + return nil, fmt.Errorf("downscope: length of AccessBoundaryRules must be at least 1") + } + if len(conf.Rules) > 10 { + return nil, fmt.Errorf("downscope: length of AccessBoundaryRules may not be greater than 10") + } + for _, val := range conf.Rules { + if val.AvailableResource == "" { + return nil, fmt.Errorf("downscope: all rules must have a nonempty AvailableResource: %+v", val) + } + if len(val.AvailablePermissions) == 0 { + return nil, fmt.Errorf("downscope: all rules must provide at least one permission: %+v", val) + } + } + return downscopingTokenSource{ctx: ctx, config: conf}, nil +} + +// Token() uses a downscopingTokenSource to generate an oauth2 Token. +// Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish +// to refresh this token automatically, then initialize a locally defined +// TokenSource struct with the Token held by the StaticTokenSource and wrap +// that TokenSource in an oauth2.ReuseTokenSource. +func (dts downscopingTokenSource) Token() (*oauth2.Token, error) { + + downscopedOptions := struct { + Boundary accessBoundary `json:"accessBoundary"` + }{ + Boundary: accessBoundary{ + AccessBoundaryRules: dts.config.Rules, + }, + } + + tok, err := dts.config.RootSource.Token() + if err != nil { + return nil, fmt.Errorf("downscope: unable to obtain root token: %v", err) + } + + b, err := json.Marshal(downscopedOptions) + if err != nil { + return nil, fmt.Errorf("downscope: unable to marshal AccessBoundary payload %v", err) + } + + form := url.Values{} + form.Add("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + form.Add("subject_token_type", "urn:ietf:params:oauth:token-type:access_token") + form.Add("requested_token_type", "urn:ietf:params:oauth:token-type:access_token") + form.Add("subject_token", tok.AccessToken) + form.Add("options", string(b)) + + myClient := oauth2.NewClient(dts.ctx, nil) + resp, err := myClient.PostForm(identityBindingEndpoint, form) + if err != nil { + return nil, fmt.Errorf("unable to generate POST Request %v", err) + } + defer resp.Body.Close() + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("downscope: unable to read response body: %v", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("downscope: unable to exchange token; %v. Server responded: %s", resp.StatusCode, respBody) + } + + var tresp downscopedTokenResponse + + err = json.Unmarshal(respBody, &tresp) + if err != nil { + return nil, fmt.Errorf("downscope: unable to unmarshal response body: %v", err) + } + + // an exchanged token that is derived from a service account (2LO) has an expired_in value + // a token derived from a users token (3LO) does not. + // The following code uses the time remaining on rootToken for a user as the value for the + // derived token's lifetime + var expiryTime time.Time + if tresp.ExpiresIn > 0 { + expiryTime = time.Now().Add(time.Duration(tresp.ExpiresIn) * time.Second) + } else { + expiryTime = tok.Expiry + } + + newToken := &oauth2.Token{ + AccessToken: tresp.AccessToken, + TokenType: tresp.TokenType, + Expiry: expiryTime, + } + return newToken, nil +} diff --git a/google/downscope/downscoping_test.go b/google/downscope/downscoping_test.go new file mode 100644 index 0000000..d5adda1 --- /dev/null +++ b/google/downscope/downscoping_test.go @@ -0,0 +1,55 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package downscope + +import ( + "context" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/oauth2" +) + +var ( + standardReqBody = "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22accessBoundary%22%3A%7B%22accessBoundaryRules%22%3A%5B%7B%22availableResource%22%3A%22test1%22%2C%22availablePermissions%22%3A%5B%22Perm1%22%2C%22Perm2%22%5D%7D%5D%7D%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&subject_token=Mellon&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token" + standardRespBody = `{"access_token":"Open Sesame","expires_in":432,"issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer"}` +) + +func Test_DownscopedTokenSource(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Unexpected request method, %v is found", r.Method) + } + if r.URL.String() != "/" { + t.Errorf("Unexpected request URL, %v is found", r.URL) + } + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("Failed to read request body: %v", err) + } + if got, want := string(body), standardReqBody; got != want { + t.Errorf("Unexpected exchange payload: got %v but want %v,", got, want) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(standardRespBody)) + + })) + new := []AccessBoundaryRule{ + { + AvailableResource: "test1", + AvailablePermissions: []string{"Perm1", "Perm2"}, + }, + } + myTok := oauth2.Token{AccessToken: "Mellon"} + tmpSrc := oauth2.StaticTokenSource(&myTok) + dts := downscopingTokenSource{context.Background(), DownscopingConfig{tmpSrc, new}} + identityBindingEndpoint = ts.URL + _, err := dts.Token() + if err != nil { + t.Fatalf("NewDownscopedTokenSource failed with error: %v", err) + } +} diff --git a/google/downscope/tokenbroker_test.go b/google/downscope/tokenbroker_test.go new file mode 100644 index 0000000..cb16878 --- /dev/null +++ b/google/downscope/tokenbroker_test.go @@ -0,0 +1,57 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package downscope_test + +import ( + "context" + "fmt" + + "golang.org/x/oauth2/google" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google/downscope" +) + +func ExampleNewTokenSource() { + // This shows how to generate a downscoped token. This code would be run on the + // token broker, which holds the root token used to generate the downscoped token. + ctx := context.Background() + // Initializes an accessBoundary with one Rule which restricts the downscoped + // token to only be able to access the bucket "foo" and only grants it the + // permission "storage.objectViewer". + accessBoundary := []downscope.AccessBoundaryRule{ + { + AvailableResource: "//storage.googleapis.com/projects/_/buckets/foo", + AvailablePermissions: []string{"inRole:roles/storage.objectViewer"}, + }, + } + + var rootSource oauth2.TokenSource + // This Source can be initialized in multiple ways; the following example uses + // Application Default Credentials. + + rootSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/cloud-platform") + + dts, err := downscope.NewTokenSource(ctx, downscope.DownscopingConfig{RootSource: rootSource, Rules: accessBoundary}) + if err != nil { + fmt.Printf("failed to generate downscoped token source: %v", err) + return + } + + tok, err := dts.Token() + if err != nil { + fmt.Printf("failed to generate token: %v", err) + return + } + _ = tok + // You can now pass tok to a token consumer however you wish, such as exposing + // a REST API and sending it over HTTP. + + // You can instead use the token held in dts to make + // Google Cloud Storage calls, as follows: + + // storageClient, err := storage.NewClient(ctx, option.WithTokenSource(dts)) + +} diff --git a/google/google.go b/google/google.go index 2b631f5..ceddd5d 100644 --- a/google/google.go +++ b/google/google.go @@ -92,9 +92,10 @@ func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) { // JSON key file types. const ( - serviceAccountKey = "service_account" - userCredentialsKey = "authorized_user" - externalAccountKey = "external_account" + serviceAccountKey = "service_account" + userCredentialsKey = "authorized_user" + externalAccountKey = "external_account" + impersonatedServiceAccount = "impersonated_service_account" ) // credentialsFile is the unmarshalled representation of a credentials file. @@ -121,8 +122,13 @@ type credentialsFile struct { TokenURLExternal string `json:"token_url"` TokenInfoURL string `json:"token_info_url"` ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` + Delegates []string `json:"delegates"` CredentialSource externalaccount.CredentialSource `json:"credential_source"` QuotaProjectID string `json:"quota_project_id"` + WorkforcePoolUserProject string `json:"workforce_pool_user_project"` + + // Service account impersonation + SourceCredentials *credentialsFile `json:"source_credentials"` } func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config { @@ -133,6 +139,7 @@ func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config Scopes: scopes, TokenURL: f.TokenURL, Subject: subject, // This is the user email to impersonate + Audience: f.Audience, } if cfg.TokenURL == "" { cfg.TokenURL = JWTTokenURL @@ -176,8 +183,26 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar CredentialSource: f.CredentialSource, QuotaProjectID: f.QuotaProjectID, Scopes: params.Scopes, + WorkforcePoolUserProject: f.WorkforcePoolUserProject, } - return cfg.TokenSource(ctx), nil + return cfg.TokenSource(ctx) + case impersonatedServiceAccount: + if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil { + return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials") + } + + ts, err := f.SourceCredentials.tokenSource(ctx, params) + if err != nil { + return nil, err + } + imp := externalaccount.ImpersonateTokenSource{ + Ctx: ctx, + URL: f.ServiceAccountImpersonationURL, + Scopes: params.Scopes, + Ts: ts, + Delegates: f.Delegates, + } + return oauth2.ReuseTokenSource(nil, imp), nil case "": return nil, errors.New("missing 'type' field in credentials") default: diff --git a/google/google_test.go b/google/google_test.go index be30b08..ea01049 100644 --- a/google/google_test.go +++ b/google/google_test.go @@ -37,7 +37,8 @@ var jwtJSONKey = []byte(`{ "client_email": "gopher@developer.gserviceaccount.com", "client_id": "gopher.apps.googleusercontent.com", "token_uri": "https://accounts.google.com/o/gophers/token", - "type": "service_account" + "type": "service_account", + "audience": "https://testservice.googleapis.com/" }`) var jwtJSONKeyNoTokenURL = []byte(`{ @@ -48,6 +49,15 @@ var jwtJSONKeyNoTokenURL = []byte(`{ "type": "service_account" }`) +var jwtJSONKeyNoAudience = []byte(`{ + "private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b", + "private_key": "super secret key", + "client_email": "gopher@developer.gserviceaccount.com", + "client_id": "gopher.apps.googleusercontent.com", + "token_uri": "https://accounts.google.com/o/gophers/token", + "type": "service_account" +}`) + func TestConfigFromJSON(t *testing.T) { conf, err := ConfigFromJSON(webJSONKey, "scope1", "scope2") if err != nil { @@ -103,6 +113,9 @@ func TestJWTConfigFromJSON(t *testing.T) { if got, want := conf.TokenURL, "https://accounts.google.com/o/gophers/token"; got != want { t.Errorf("TokenURL = %q; want %q", got, want) } + if got, want := conf.Audience, "https://testservice.googleapis.com/"; got != want { + t.Errorf("Audience = %q; want %q", got, want) + } } func TestJWTConfigFromJSONNoTokenURL(t *testing.T) { @@ -114,3 +127,13 @@ func TestJWTConfigFromJSONNoTokenURL(t *testing.T) { t.Errorf("TokenURL = %q; want %q", got, want) } } + +func TestJWTConfigFromJSONNoAudience(t *testing.T) { + conf, err := JWTConfigFromJSON(jwtJSONKeyNoAudience, "scope1", "scope2") + if err != nil { + t.Fatal(err) + } + if got, want := conf.Audience, ""; got != want { + t.Errorf("Audience = %q; want %q", got, want) + } +} diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index fbcefb4..e917195 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -13,7 +13,6 @@ import ( "encoding/json" "errors" "fmt" - "golang.org/x/oauth2" "io" "io/ioutil" "net/http" @@ -23,6 +22,8 @@ import ( "sort" "strings" "time" + + "golang.org/x/oauth2" ) type awsSecurityCredentials struct { @@ -51,6 +52,13 @@ const ( // The AWS authorization header name for the security session token if available. awsSecurityTokenHeader = "x-amz-security-token" + // The name of the header containing the session token for metadata endpoint calls + awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token" + + awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" + + awsIMDSv2SessionTtl = "300" + // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" @@ -240,6 +248,7 @@ type awsCredentialSource struct { RegionURL string RegionalCredVerificationURL string CredVerificationURL string + IMDSv2SessionTokenURL string TargetResource string requestSigner *awsRequestSigner region string @@ -267,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSecurityCredentials, err := cs.getSecurityCredentials() + awsSessionToken, err := cs.getAWSSessionToken() if err != nil { return "", err } - if cs.region, err = cs.getRegion(); err != nil { + headers := make(map[string]string) + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } + + awsSecurityCredentials, err := cs.getSecurityCredentials(headers) + if err != nil { + return "", err + } + + if cs.region, err = cs.getRegion(headers); err != nil { return "", err } @@ -339,10 +358,43 @@ func (cs awsCredentialSource) subjectToken() (string, error) { return url.QueryEscape(string(result)), nil } -func (cs *awsCredentialSource) getRegion() (string, error) { +func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { + if cs.IMDSv2SessionTokenURL == "" { + return "", nil + } + + req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil) + if err != nil { + return "", err + } + + req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl) + + resp, err := cs.doRequest(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody)) + } + + return string(respBody), nil +} + +func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { return envAwsRegion, nil } + if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { + return envAwsRegion, nil + } if cs.RegionURL == "" { return "", errors.New("oauth2/google: unable to determine AWS region") @@ -353,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) { return "", err } + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return "", err @@ -377,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) { return string(respBody[:respBodyEnd]), nil } -func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) { +func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { return awsSecurityCredentials{ @@ -388,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede } } - roleName, err := cs.getMetadataRoleName() + roleName, err := cs.getMetadataRoleName(headers) if err != nil { return } - credentials, err := cs.getMetadataSecurityCredentials(roleName) + credentials, err := cs.getMetadataSecurityCredentials(roleName, headers) if err != nil { return } @@ -409,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede return credentials, nil } -func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) { +func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) { var result awsSecurityCredentials req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) @@ -418,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) ( } req.Header.Add("Content-Type", "application/json") + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return result, err @@ -437,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) ( return result, err } -func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { +func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) { if cs.CredVerificationURL == "" { return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") } @@ -447,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { return "", err } + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return "", err diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 95ff9ce..0934389 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -20,6 +20,8 @@ import ( var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC) var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC) +type validateHeaders func(r *http.Request) + func setTime(testTime time.Time) func() time.Time { return func() time.Time { return testTime @@ -28,8 +30,7 @@ func setTime(testTime time.Time) func() time.Time { func setEnvironment(env map[string]string) func(string) string { return func(key string) string { - value, _ := env[key] - return value + return env[key] } } @@ -83,7 +84,7 @@ func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, expectedOutput } } -func TestAwsV4Signature_GetRequest(t *testing.T) { +func TestAWSv4Signature_GetRequest(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com", nil) setDefaultTime(input) @@ -101,7 +102,7 @@ func TestAwsV4Signature_GetRequest(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil) setDefaultTime(input) @@ -119,7 +120,7 @@ func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil) setDefaultTime(input) @@ -137,7 +138,7 @@ func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil) setDefaultTime(input) @@ -155,7 +156,7 @@ func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) { +func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil) setDefaultTime(input) @@ -173,7 +174,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) { +func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil) setDefaultTime(input) @@ -191,7 +192,7 @@ func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) { +func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil) setDefaultTime(input) @@ -209,7 +210,7 @@ func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) { +func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil) setDefaultTime(input) @@ -227,7 +228,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequest(t *testing.T) { +func TestAWSv4Signature_PostRequest(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("ZOO", "zoobar") @@ -247,7 +248,7 @@ func TestAwsV4Signature_PostRequest(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { +func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("zoo", "ZOOBAR") @@ -267,7 +268,7 @@ func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestPhfft(t *testing.T) { +func TestAWSv4Signature_PostRequestPhfft(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("p", "phfft") @@ -287,7 +288,7 @@ func TestAwsV4Signature_PostRequestPhfft(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithBody(t *testing.T) { +func TestAWSv4Signature_PostRequestWithBody(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar")) setDefaultTime(input) input.Header.Add("Content-Type", "application/x-www-form-urlencoded") @@ -307,7 +308,7 @@ func TestAwsV4Signature_PostRequestWithBody(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) { +func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil) setDefaultTime(input) @@ -325,7 +326,7 @@ func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) { +func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) { input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) @@ -343,7 +344,7 @@ func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) { testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) { +func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) { input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) @@ -361,7 +362,7 @@ func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) { testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) { +func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) { requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}" input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams)) input.Header.Add("Content-Type", "application/x-amz-json-1.0") @@ -384,7 +385,7 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { +func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { var requestSigner = &awsRequestSigner{ RegionName: "us-east-2", AwsSecurityCredentials: awsSecurityCredentials{ @@ -414,30 +415,40 @@ type testAwsServer struct { securityCredentialURL string regionURL string regionalCredVerificationURL string + imdsv2SessionTokenUrl string Credentials map[string]string - WriteRolename func(http.ResponseWriter) - WriteSecurityCredentials func(http.ResponseWriter) - WriteRegion func(http.ResponseWriter) + WriteRolename func(http.ResponseWriter, *http.Request) + WriteSecurityCredentials func(http.ResponseWriter, *http.Request) + WriteRegion func(http.ResponseWriter, *http.Request) + WriteIMDSv2SessionToken func(http.ResponseWriter, *http.Request) } -func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer { +func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenUrl string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer { server := &testAwsServer{ url: url, securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename), regionURL: regionURL, regionalCredVerificationURL: regionalCredVerificationURL, + imdsv2SessionTokenUrl: imdsv2SessionTokenUrl, Credentials: credentials, - WriteRolename: func(w http.ResponseWriter) { + WriteRolename: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) w.Write([]byte(rolename)) }, - WriteRegion: func(w http.ResponseWriter) { + WriteRegion: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) w.Write([]byte(region)) }, + WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) + w.Write([]byte(imdsv2SessionToken)) + }, } - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) jsonCredentials, _ := json.Marshal(server.Credentials) w.Write(jsonCredentials) } @@ -450,6 +461,7 @@ func createDefaultAwsTestServer() *testAwsServer { "/latest/meta-data/iam/security-credentials", "/latest/meta-data/placement/availability-zone", "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "", "gcp-aws-role", "us-east-2b", map[string]string{ @@ -457,31 +469,38 @@ func createDefaultAwsTestServer() *testAwsServer { "AccessKeyId": accessKeyID, "Token": securityToken, }, + "", + noHeaderValidation, ) } func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: - server.WriteRolename(w) + server.WriteRolename(w, r) case server.securityCredentialURL: - server.WriteSecurityCredentials(w) + server.WriteSecurityCredentials(w, r) case server.regionURL: - server.WriteRegion(w) + server.WriteRegion(w, r) + case server.imdsv2SessionTokenUrl: + server.WriteIMDSv2SessionToken(w, r) } } -func notFound(w http.ResponseWriter) { +func notFound(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) w.Write([]byte("Not Found")) } +func noHeaderValidation(r *http.Request) {} + func (server *testAwsServer) getCredentialSource(url string) CredentialSource { return CredentialSource{ EnvironmentID: "aws1", URL: url + server.url, RegionURL: url + server.regionURL, RegionalCredVerificationURL: server.regionalCredVerificationURL, + IMDSv2SessionTokenURL: url + server.imdsv2SessionTokenUrl, } } @@ -531,7 +550,7 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security return neturl.QueryEscape(string(str)) } -func TestAwsCredential_BasicRequest(t *testing.T) { +func TestAWSCredential_BasicRequest(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -541,6 +560,9 @@ func TestAwsCredential_BasicRequest(t *testing.T) { oldGetenv := getenv defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) base, err := tfc.parse(context.Background()) if err != nil { @@ -561,11 +583,76 @@ func TestAwsCredential_BasicRequest(t *testing.T) { ) if got, want := out, expected; !reflect.DeepEqual(got, want) { - t.Errorf("subjectToken = %q, want %q", got, want) + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) } } -func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { +func TestAWSCredential_IMDSv2(t *testing.T) { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + server := createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) + ts := httptest.NewServer(server) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) delete(server.Credentials, "Token") @@ -576,6 +663,9 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { oldGetenv := getenv defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) base, err := tfc.parse(context.Background()) if err != nil { @@ -596,11 +686,11 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { ) if got, want := out, expected; !reflect.DeepEqual(got, want) { - t.Errorf("subjectToken = %q, want %q", got, want) + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) } } -func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { +func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -614,6 +704,9 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", }) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) base, err := tfc.parse(context.Background()) if err != nil { @@ -634,11 +727,92 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { ) if got, want := out, expected; !reflect.DeepEqual(got, want) { - t.Errorf("subjectToken = %q, want %q", got, want) + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) } } -func TestAwsCredential_RequestWithBadVersion(t *testing.T) { +func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_DEFAULT_REGION": "us-west-1", + }) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + "AWS_DEFAULT_REGION": "us-east-1", + }) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_RequestWithBadVersion(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -659,7 +833,7 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) { } } -func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) { +func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -686,7 +860,7 @@ func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) { +func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteRegion = notFound @@ -713,10 +887,10 @@ func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) { } } -func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { +func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) } @@ -742,10 +916,10 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { } } -func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { +func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) } @@ -771,7 +945,7 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { } } -func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -798,7 +972,7 @@ func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteRolename = notFound @@ -825,7 +999,7 @@ func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteSecurityCredentials = notFound diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 1a6e93c..83ce9c2 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -7,10 +7,14 @@ package externalaccount import ( "context" "fmt" - "golang.org/x/oauth2" "net/http" + "net/url" + "regexp" "strconv" + "strings" "time" + + "golang.org/x/oauth2" ) // now aliases time.Now for testing @@ -20,36 +24,129 @@ var now = func() time.Time { // Config stores the configuration for fetching tokens with external credentials. type Config struct { - Audience string - SubjectTokenType string - TokenURL string - TokenInfoURL string + // Audience is the Secure Token Service (STS) audience which contains the resource name for the workload + // identity pool or the workforce pool and the provider identifier in that pool. + Audience string + // SubjectTokenType is the STS token type based on the Oauth2.0 token exchange spec + // e.g. `urn:ietf:params:oauth:token-type:jwt`. + SubjectTokenType string + // TokenURL is the STS token exchange endpoint. + TokenURL string + // TokenInfoURL is the token_info endpoint used to retrieve the account related information ( + // user attributes like account identifier, eg. email, username, uid, etc). This is + // needed for gCloud session account identification. + TokenInfoURL string + // ServiceAccountImpersonationURL is the URL for the service account impersonation request. This is only + // required for workload identity pools when APIs to be accessed have not integrated with UberMint. ServiceAccountImpersonationURL string - ClientSecret string - ClientID string - CredentialSource CredentialSource - QuotaProjectID string - Scopes []string + // ClientSecret is currently only required if token_info endpoint also + // needs to be called with the generated GCP access token. When provided, STS will be + // called with additional basic authentication using client_id as username and client_secret as password. + ClientSecret string + // ClientID is only required in conjunction with ClientSecret, as described above. + ClientID string + // CredentialSource contains the necessary information to retrieve the token itself, as well + // as some environmental information. + CredentialSource CredentialSource + // QuotaProjectID is injected by gCloud. If the value is non-empty, the Auth libraries + // will set the x-goog-user-project which overrides the project associated with the credentials. + QuotaProjectID string + // Scopes contains the desired scopes for the returned access token. + Scopes []string + // The optional workforce pool user project number when the credential + // corresponds to a workforce pool and not a workload identity pool. + // The underlying principal must still have serviceusage.services.use IAM + // permission to use the project for billing/quota. + WorkforcePoolUserProject string +} + +// Each element consists of a list of patterns. validateURLs checks for matches +// that include all elements in a given list, in that order. + +var ( + validTokenURLPatterns = []*regexp.Regexp{ + // The complicated part in the middle matches any number of characters that + // aren't period, spaces, or slashes. + regexp.MustCompile(`(?i)^[^\.\s\/\\]+\.sts\.googleapis\.com$`), + regexp.MustCompile(`(?i)^sts\.googleapis\.com$`), + regexp.MustCompile(`(?i)^sts\.[^\.\s\/\\]+\.googleapis\.com$`), + regexp.MustCompile(`(?i)^[^\.\s\/\\]+-sts\.googleapis\.com$`), + } + validImpersonateURLPatterns = []*regexp.Regexp{ + regexp.MustCompile(`^[^\.\s\/\\]+\.iamcredentials\.googleapis\.com$`), + regexp.MustCompile(`^iamcredentials\.googleapis\.com$`), + regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`), + regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`), + } + validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`) +) + +func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool { + parsed, err := url.Parse(input) + if err != nil { + return false + } + if !strings.EqualFold(parsed.Scheme, scheme) { + return false + } + toTest := parsed.Host + + for _, pattern := range patterns { + if pattern.MatchString(toTest) { + return true + } + } + return false +} + +func validateWorkforceAudience(input string) bool { + return validWorkforceAudiencePattern.MatchString(input) } // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. -func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { +func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https") +} + +// tokenSource is a private function that's directly called by some of the tests, +// because the unit test URLs are mocked, and would otherwise fail the +// validity check. +func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) { + valid := validateURL(c.TokenURL, tokenURLValidPats, scheme) + if !valid { + return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") + } + + if c.ServiceAccountImpersonationURL != "" { + valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme) + if !valid { + return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") + } + } + + if c.WorkforcePoolUserProject != "" { + valid := validateWorkforceAudience(c.Audience) + if !valid { + return nil, fmt.Errorf("oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials") + } + } + ts := tokenSource{ ctx: ctx, conf: c, } if c.ServiceAccountImpersonationURL == "" { - return oauth2.ReuseTokenSource(nil, ts) + return oauth2.ReuseTokenSource(nil, ts), nil } scopes := c.Scopes ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"} - imp := impersonateTokenSource{ - ctx: ctx, - url: c.ServiceAccountImpersonationURL, - scopes: scopes, - ts: oauth2.ReuseTokenSource(nil, ts), + imp := ImpersonateTokenSource{ + Ctx: ctx, + URL: c.ServiceAccountImpersonationURL, + Scopes: scopes, + Ts: oauth2.ReuseTokenSource(nil, ts), } - return oauth2.ReuseTokenSource(nil, imp) + return oauth2.ReuseTokenSource(nil, imp), nil } // Subject token file types. @@ -59,13 +156,15 @@ const ( ) type format struct { - // Type is either "text" or "json". When not provided "text" type is assumed. + // Type is either "text" or "json". When not provided "text" type is assumed. Type string `json:"type"` - // SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure. + // SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure. SubjectTokenFieldName string `json:"subject_token_field_name"` } // CredentialSource stores the information necessary to retrieve the credentials for the STS exchange. +// Either the File or the URL field should be filled, depending on the kind of credential in question. +// The EnvironmentID should start with AWS if being used for an AWS credential. type CredentialSource struct { File string `json:"file"` @@ -76,6 +175,7 @@ type CredentialSource struct { RegionURL string `json:"region_url"` RegionalCredVerificationURL string `json:"regional_cred_verification_url"` CredVerificationURL string `json:"cred_verification_url"` + IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"` Format format `json:"format"` } @@ -86,14 +186,20 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { if awsVersion != 1 { return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion) } - return awsCredentialSource{ + + awsCredSource := awsCredentialSource{ EnvironmentID: c.CredentialSource.EnvironmentID, RegionURL: c.CredentialSource.RegionURL, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, CredVerificationURL: c.CredentialSource.URL, TargetResource: c.Audience, ctx: ctx, - }, nil + } + if c.CredentialSource.IMDSv2SessionTokenURL != "" { + awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL + } + + return awsCredSource, nil } } else if c.CredentialSource.File != "" { return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil @@ -107,7 +213,7 @@ type baseCredentialSource interface { subjectToken() (string, error) } -// tokenSource is the source that handles external credentials. +// tokenSource is the source that handles external credentials. It is used to retrieve Tokens. type tokenSource struct { ctx context.Context conf *Config @@ -141,7 +247,15 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { ClientID: conf.ClientID, ClientSecret: conf.ClientSecret, } - stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, nil) + var options map[string]interface{} + // Do not pass workforce_pool_user_project when client authentication is used. + // The client ID is sufficient for determining the user project. + if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" { + options = map[string]interface{}{ + "userProject": conf.WorkforcePoolUserProject, + } + } + stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options) if err != nil { return nil, err } diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 1ebb227..5aa0d46 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -9,8 +9,11 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "time" + + "golang.org/x/oauth2" ) const ( @@ -34,55 +37,64 @@ var testConfig = Config{ } var ( - baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" - baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` - correctAT = "Sample.Access.Token" - expiry int64 = 234852 + baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token" + baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` + workforcePoolRequestBodyWithClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token" + workforcePoolRequestBodyWithoutClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22userProject%22%3A%22myProject%22%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token" + correctAT = "Sample.Access.Token" + expiry int64 = 234852 ) var ( testNow = func() time.Time { return time.Unix(expiry, 0) } ) -func TestToken(t *testing.T) { +type testExchangeTokenServer struct { + url string + authorization string + contentType string + body string + response string +} - targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if got, want := r.URL.String(), "/"; got != want { +func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.Token, error) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got, want := r.URL.String(), tets.url; got != want { t.Errorf("URL.String(): got %v but want %v", got, want) } headerAuth := r.Header.Get("Authorization") - if got, want := headerAuth, "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want { + if got, want := headerAuth, tets.authorization; got != want { t.Errorf("got %v but want %v", got, want) } headerContentType := r.Header.Get("Content-Type") - if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { + if got, want := headerContentType, tets.contentType; got != want { t.Errorf("got %v but want %v", got, want) } body, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %s.", err) } - if got, want := string(body), baseCredsRequestBody; got != want { + if got, want := string(body), tets.body; got != want { t.Errorf("Unexpected exchange payload: got %v but want %v", got, want) } w.Header().Set("Content-Type", "application/json") - w.Write([]byte(baseCredsResponseBody)) + w.Write([]byte(tets.response)) })) - defer targetServer.Close() - - testConfig.TokenURL = targetServer.URL - ourTS := tokenSource{ - ctx: context.Background(), - conf: &testConfig, - } + defer server.Close() + config.TokenURL = server.URL oldNow := now defer func() { now = oldNow }() now = testNow - tok, err := ourTS.Token() - if err != nil { - t.Fatalf("Unexpected error: %e", err) + ts := tokenSource{ + ctx: context.Background(), + conf: config, } + + return ts.Token() +} + +func validateToken(t *testing.T, tok *oauth2.Token) { if got, want := tok.AccessToken, correctAT; got != want { t.Errorf("Unexpected access token: got %v, but wanted %v", got, want) } @@ -90,8 +102,260 @@ func TestToken(t *testing.T) { t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want) } - if got, want := tok.Expiry, now().Add(time.Duration(3600)*time.Second); got != want { + if got, want := tok.Expiry, testNow().Add(time.Duration(3600)*time.Second); got != want { t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want) } - +} + +func TestToken(t *testing.T) { + config := Config{ + Audience: "32555940559.apps.googleusercontent.com", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + ClientSecret: "notsosecret", + ClientID: "rbrgnognrhongo3bi4gb9ghg9g", + CredentialSource: testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + } + + server := testExchangeTokenServer{ + url: "/", + authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", + contentType: "application/x-www-form-urlencoded", + body: baseCredsRequestBody, + response: baseCredsResponseBody, + } + + tok, err := run(t, &config, &server) + + if err != nil { + t.Fatalf("Unexpected error: %e", err) + } + validateToken(t, tok) +} + +func TestWorkforcePoolTokenWithClientID(t *testing.T) { + config := Config{ + Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + ClientSecret: "notsosecret", + ClientID: "rbrgnognrhongo3bi4gb9ghg9g", + CredentialSource: testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + WorkforcePoolUserProject: "myProject", + } + + server := testExchangeTokenServer{ + url: "/", + authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", + contentType: "application/x-www-form-urlencoded", + body: workforcePoolRequestBodyWithClientId, + response: baseCredsResponseBody, + } + + tok, err := run(t, &config, &server) + + if err != nil { + t.Fatalf("Unexpected error: %e", err) + } + validateToken(t, tok) +} + +func TestWorkforcePoolTokenWithoutClientID(t *testing.T) { + config := Config{ + Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + ClientSecret: "notsosecret", + CredentialSource: testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + WorkforcePoolUserProject: "myProject", + } + + server := testExchangeTokenServer{ + url: "/", + authorization: "", + contentType: "application/x-www-form-urlencoded", + body: workforcePoolRequestBodyWithoutClientId, + response: baseCredsResponseBody, + } + + tok, err := run(t, &config, &server) + + if err != nil { + t.Fatalf("Unexpected error: %e", err) + } + validateToken(t, tok) +} + +func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) { + config := Config{ + Audience: "32555940559.apps.googleusercontent.com", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + TokenURL: "https://sts.googleapis.com", + ClientSecret: "notsosecret", + ClientID: "rbrgnognrhongo3bi4gb9ghg9g", + CredentialSource: testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + WorkforcePoolUserProject: "myProject", + } + + _, err := config.TokenSource(context.Background()) + + if err == nil { + t.Fatalf("Expected error but found none") + } + if got, want := err.Error(), "oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials"; got != want { + t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got) + } +} + +func TestValidateURLTokenURL(t *testing.T) { + var urlValidityTests = []struct { + tokURL string + expectSuccess bool + }{ + {"https://east.sts.googleapis.com", true}, + {"https://sts.googleapis.com", true}, + {"https://sts.asfeasfesef.googleapis.com", true}, + {"https://us-east-1-sts.googleapis.com", true}, + {"https://sts.googleapis.com/your/path/here", true}, + {"https://.sts.googleapis.com", false}, + {"https://badsts.googleapis.com", false}, + {"https://sts.asfe.asfesef.googleapis.com", false}, + {"https://sts..googleapis.com", false}, + {"https://-sts.googleapis.com", false}, + {"https://us-ea.st-1-sts.googleapis.com", false}, + {"https://sts.googleapis.com.evil.com/whatever/path", false}, + {"https://us-eas\\t-1.sts.googleapis.com", false}, + {"https:/us-ea/st-1.sts.googleapis.com", false}, + {"https:/us-east 1.sts.googleapis.com", false}, + {"https://", false}, + {"http://us-east-1.sts.googleapis.com", false}, + {"https://us-east-1.sts.googleapis.comevil.com", false}, + } + ctx := context.Background() + for _, tt := range urlValidityTests { + t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = tt.tokURL + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } + for _, el := range urlValidityTests { + el.tokURL = strings.ToUpper(el.tokURL) + } + for _, tt := range urlValidityTests { + t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = tt.tokURL + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } +} + +func TestValidateURLImpersonateURL(t *testing.T) { + var urlValidityTests = []struct { + impURL string + expectSuccess bool + }{ + {"https://east.iamcredentials.googleapis.com", true}, + {"https://iamcredentials.googleapis.com", true}, + {"https://iamcredentials.asfeasfesef.googleapis.com", true}, + {"https://us-east-1-iamcredentials.googleapis.com", true}, + {"https://iamcredentials.googleapis.com/your/path/here", true}, + {"https://.iamcredentials.googleapis.com", false}, + {"https://badiamcredentials.googleapis.com", false}, + {"https://iamcredentials.asfe.asfesef.googleapis.com", false}, + {"https://iamcredentials..googleapis.com", false}, + {"https://-iamcredentials.googleapis.com", false}, + {"https://us-ea.st-1-iamcredentials.googleapis.com", false}, + {"https://iamcredentials.googleapis.com.evil.com/whatever/path", false}, + {"https://us-eas\\t-1.iamcredentials.googleapis.com", false}, + {"https:/us-ea/st-1.iamcredentials.googleapis.com", false}, + {"https:/us-east 1.iamcredentials.googleapis.com", false}, + {"https://", false}, + {"http://us-east-1.iamcredentials.googleapis.com", false}, + {"https://us-east-1.iamcredentials.googleapis.comevil.com", false}, + } + ctx := context.Background() + for _, tt := range urlValidityTests { + t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL + config.ServiceAccountImpersonationURL = tt.impURL + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } + for _, el := range urlValidityTests { + el.impURL = strings.ToUpper(el.impURL) + } + for _, tt := range urlValidityTests { + t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL + config.ServiceAccountImpersonationURL = tt.impURL + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } +} + +func TestWorkforcePoolCreation(t *testing.T) { + var audienceValidatyTests = []struct { + audience string + expectSuccess bool + }{ + {"//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id", true}, + {"//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", true}, + {"//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", true}, + {"identitynamespace:1f12345:my_provider", false}, + {"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/projects/123456/locations/eu/workloadIdentityPools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/workforcePools/providers/provider-id", false}, + {"//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", false}, + {"//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", false}, + } + + ctx := context.Background() + for _, tt := range audienceValidatyTests { + t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + config := testConfig + config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL + config.ServiceAccountImpersonationURL = "https://iamcredentials.googleapis.com" + config.Audience = tt.audience + config.WorkforcePoolUserProject = "myProject" + _, err := config.TokenSource(ctx) + + if tt.expectSuccess && err != nil { + t.Errorf("got %v but want nil", err) + } else if !tt.expectSuccess && err == nil { + t.Errorf("got nil but expected an error") + } + }) + } } diff --git a/google/internal/externalaccount/clientauth.go b/google/internal/externalaccount/clientauth.go index feccf8b..99987ce 100644 --- a/google/internal/externalaccount/clientauth.go +++ b/google/internal/externalaccount/clientauth.go @@ -6,9 +6,10 @@ package externalaccount import ( "encoding/base64" - "golang.org/x/oauth2" "net/http" "net/url" + + "golang.org/x/oauth2" ) // clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. @@ -19,6 +20,9 @@ type clientAuthentication struct { ClientSecret string } +// InjectAuthentication is used to add authentication to a Secure Token Service exchange +// request. It modifies either the passed url.Values or http.Header depending on the desired +// authentication format. func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) { if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil { return diff --git a/google/internal/externalaccount/clientauth_test.go b/google/internal/externalaccount/clientauth_test.go index 38633e3..bfb339d 100644 --- a/google/internal/externalaccount/clientauth_test.go +++ b/google/internal/externalaccount/clientauth_test.go @@ -5,11 +5,12 @@ package externalaccount import ( - "golang.org/x/oauth2" "net/http" "net/url" "reflect" "testing" + + "golang.org/x/oauth2" ) var clientID = "rbrgnognrhongo3bi4gb9ghg9g" diff --git a/google/internal/externalaccount/impersonate.go b/google/internal/externalaccount/impersonate.go index 1d29c46..8251fc8 100644 --- a/google/internal/externalaccount/impersonate.go +++ b/google/internal/externalaccount/impersonate.go @@ -9,11 +9,12 @@ import ( "context" "encoding/json" "fmt" - "golang.org/x/oauth2" "io" "io/ioutil" "net/http" "time" + + "golang.org/x/oauth2" ) // generateAccesstokenReq is used for service account impersonation @@ -28,30 +29,44 @@ type impersonateTokenResponse struct { ExpireTime string `json:"expireTime"` } -type impersonateTokenSource struct { - ctx context.Context - ts oauth2.TokenSource +// ImpersonateTokenSource uses a source credential, stored in Ts, to request an access token to the provided URL. +// Scopes can be defined when the access token is requested. +type ImpersonateTokenSource struct { + // Ctx is the execution context of the impersonation process + // used to perform http call to the URL. Required + Ctx context.Context + // Ts is the source credential used to generate a token on the + // impersonated service account. Required. + Ts oauth2.TokenSource - url string - scopes []string + // URL is the endpoint to call to generate a token + // on behalf the service account. Required. + URL string + // Scopes that the impersonated credential should have. Required. + Scopes []string + // Delegates are the service account email addresses in a delegation chain. + // Each service account must be granted roles/iam.serviceAccountTokenCreator + // on the next service account in the chain. Optional. + Delegates []string } -// Token performs the exchange to get a temporary service account -func (its impersonateTokenSource) Token() (*oauth2.Token, error) { +// Token performs the exchange to get a temporary service account token to allow access to GCP. +func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) { reqBody := generateAccessTokenReq{ - Lifetime: "3600s", - Scope: its.scopes, + Lifetime: "3600s", + Scope: its.Scopes, + Delegates: its.Delegates, } b, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("oauth2/google: unable to marshal request: %v", err) } - client := oauth2.NewClient(its.ctx, its.ts) - req, err := http.NewRequest("POST", its.url, bytes.NewReader(b)) + client := oauth2.NewClient(its.Ctx, its.Ts) + req, err := http.NewRequest("POST", its.URL, bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("oauth2/google: unable to create impersonation request: %v", err) } - req = req.WithContext(its.ctx) + req = req.WithContext(its.Ctx) req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go index 197fe3c..6fed7b9 100644 --- a/google/internal/externalaccount/impersonate_test.go +++ b/google/internal/externalaccount/impersonate_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "regexp" "testing" ) @@ -76,7 +77,11 @@ func TestImpersonation(t *testing.T) { defer targetServer.Close() testImpersonateConfig.TokenURL = targetServer.URL - ourTS := testImpersonateConfig.TokenSource(context.Background()) + allURLs := regexp.MustCompile(".+") + ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http") + if err != nil { + t.Fatalf("Failed to create TokenSource: %v", err) + } oldNow := now defer func() { now = oldNow }() diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index a8a704b..e6fcae5 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -65,6 +65,9 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan defer resp.Body.Close() body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, err + } if c := resp.StatusCode; c < 200 || c > 299 { return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) } diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index 3d498c6..df4d5ff 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -7,12 +7,13 @@ package externalaccount import ( "context" "encoding/json" - "golang.org/x/oauth2" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" + + "golang.org/x/oauth2" ) var auth = clientAuthentication{ @@ -127,6 +128,9 @@ func TestExchangeToken_Opts(t *testing.T) { } var opts map[string]interface{} err = json.Unmarshal([]byte(strOpts[0]), &opts) + if err != nil { + t.Fatalf("Couldn't parse received \"options\" field.") + } if len(opts) < 2 { t.Errorf("Too few options received.") } diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go index 91b8f20..16dca65 100644 --- a/google/internal/externalaccount/urlcredsource.go +++ b/google/internal/externalaccount/urlcredsource.go @@ -9,10 +9,11 @@ import ( "encoding/json" "errors" "fmt" - "golang.org/x/oauth2" "io" "io/ioutil" "net/http" + + "golang.org/x/oauth2" ) type urlCredentialSource struct { diff --git a/google/internal/externalaccount/urlcredsource_test.go b/google/internal/externalaccount/urlcredsource_test.go index 8ade2a2..6a36d0d 100644 --- a/google/internal/externalaccount/urlcredsource_test.go +++ b/google/internal/externalaccount/urlcredsource_test.go @@ -7,7 +7,6 @@ package externalaccount import ( "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "testing" @@ -20,7 +19,6 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) { if r.Method != "GET" { t.Errorf("Unexpected request method, %v is found", r.Method) } - fmt.Println(r.Header) if r.Header.Get("Metadata") != "True" { t.Errorf("Metadata header not properly included.") } diff --git a/google/jwt.go b/google/jwt.go index b0fdb3a..67d97b9 100644 --- a/google/jwt.go +++ b/google/jwt.go @@ -7,6 +7,7 @@ package google import ( "crypto/rsa" "fmt" + "strings" "time" "golang.org/x/oauth2" @@ -24,6 +25,28 @@ import ( // optimization supported by a few Google services. // Unless you know otherwise, you should use JWTConfigFromJSON instead. func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) { + return newJWTSource(jsonKey, audience, nil) +} + +// JWTAccessTokenSourceWithScope uses a Google Developers service account JSON +// key file to read the credentials that authorize and authenticate the +// requests, and returns a TokenSource that does not use any OAuth2 flow but +// instead creates a JWT and sends that as the access token. +// The scope is typically a list of URLs that specifies the scope of the +// credentials. +// +// Note that this is not a standard OAuth flow, but rather an +// optimization supported by a few Google services. +// Unless you know otherwise, you should use JWTConfigFromJSON instead. +func JWTAccessTokenSourceWithScope(jsonKey []byte, scope ...string) (oauth2.TokenSource, error) { + return newJWTSource(jsonKey, "", scope) +} + +func newJWTSource(jsonKey []byte, audience string, scopes []string) (oauth2.TokenSource, error) { + if len(scopes) == 0 && audience == "" { + return nil, fmt.Errorf("google: missing scope/audience for JWT access token") + } + cfg, err := JWTConfigFromJSON(jsonKey) if err != nil { return nil, fmt.Errorf("google: could not parse JSON key: %v", err) @@ -35,6 +58,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token ts := &jwtAccessTokenSource{ email: cfg.Email, audience: audience, + scopes: scopes, pk: pk, pkID: cfg.PrivateKeyID, } @@ -47,6 +71,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token type jwtAccessTokenSource struct { email, audience string + scopes []string pk *rsa.PrivateKey pkID string } @@ -54,12 +79,14 @@ type jwtAccessTokenSource struct { func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) { iat := time.Now() exp := iat.Add(time.Hour) + scope := strings.Join(ts.scopes, " ") cs := &jws.ClaimSet{ - Iss: ts.email, - Sub: ts.email, - Aud: ts.audience, - Iat: iat.Unix(), - Exp: exp.Unix(), + Iss: ts.email, + Sub: ts.email, + Aud: ts.audience, + Scope: scope, + Iat: iat.Unix(), + Exp: exp.Unix(), } hdr := &jws.Header{ Algorithm: "RS256", diff --git a/google/jwt_test.go b/google/jwt_test.go index f844436..043f445 100644 --- a/google/jwt_test.go +++ b/google/jwt_test.go @@ -13,29 +13,21 @@ import ( "encoding/json" "encoding/pem" "strings" + "sync" "testing" "time" "golang.org/x/oauth2/jws" ) -func TestJWTAccessTokenSourceFromJSON(t *testing.T) { - // Generate a key we can use in the test data. - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } +var ( + privateKey *rsa.PrivateKey + jsonKey []byte + once sync.Once +) - // Encode the key and substitute into our example JSON. - enc := pem.EncodeToMemory(&pem.Block{ - Type: "PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }) - enc, err = json.Marshal(string(enc)) - if err != nil { - t.Fatalf("json.Marshal: %v", err) - } - jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1) +func TestJWTAccessTokenSourceFromJSON(t *testing.T) { + setupDummyKey(t) ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience") if err != nil { @@ -89,3 +81,80 @@ func TestJWTAccessTokenSourceFromJSON(t *testing.T) { t.Errorf("Header KeyID = %q, want %q", got, want) } } + +func TestJWTAccessTokenSourceWithScope(t *testing.T) { + setupDummyKey(t) + + ts, err := JWTAccessTokenSourceWithScope(jsonKey, "scope1", "scope2") + if err != nil { + t.Fatalf("JWTAccessTokenSourceWithScope: %v\nJSON: %s", err, string(jsonKey)) + } + + tok, err := ts.Token() + if err != nil { + t.Fatalf("Token: %v", err) + } + + if got, want := tok.TokenType, "Bearer"; got != want { + t.Errorf("TokenType = %q, want %q", got, want) + } + if got := tok.Expiry; tok.Expiry.Before(time.Now()) { + t.Errorf("Expiry = %v, should not be expired", got) + } + + err = jws.Verify(tok.AccessToken, &privateKey.PublicKey) + if err != nil { + t.Errorf("jws.Verify on AccessToken: %v", err) + } + + claim, err := jws.Decode(tok.AccessToken) + if err != nil { + t.Fatalf("jws.Decode on AccessToken: %v", err) + } + + if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want { + t.Errorf("Iss = %q, want %q", got, want) + } + if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want { + t.Errorf("Sub = %q, want %q", got, want) + } + if got, want := claim.Scope, "scope1 scope2"; got != want { + t.Errorf("Aud = %q, want %q", got, want) + } + + // Finally, check the header private key. + parts := strings.Split(tok.AccessToken, ".") + hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0]) + } + var hdr jws.Header + if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil { + t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON) + } + + if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want { + t.Errorf("Header KeyID = %q, want %q", got, want) + } +} + +func setupDummyKey(t *testing.T) { + once.Do(func() { + // Generate a key we can use in the test data. + pk, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + privateKey = pk + // Encode the key and substitute into our example JSON. + enc := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + enc, err = json.Marshal(string(enc)) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + jsonKey = bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1) + }) +}