Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Andy Zhao 2022-06-06 12:39:56 -07:00
commit 0127caffc1
24 changed files with 1304 additions and 174 deletions

View File

@ -17,6 +17,12 @@ var Amazon = oauth2.Endpoint{
TokenURL: "https://api.amazon.com/auth/o2/token", TokenURL: "https://api.amazon.com/auth/o2/token",
} }
// Battlenet is the endpoint for Battlenet.
var Battlenet = oauth2.Endpoint{
AuthURL: "https://battle.net/oauth/authorize",
TokenURL: "https://battle.net/oauth/token",
}
// Bitbucket is the endpoint for Bitbucket. // Bitbucket is the endpoint for Bitbucket.
var Bitbucket = oauth2.Endpoint{ var Bitbucket = oauth2.Endpoint{
AuthURL: "https://bitbucket.org/site/oauth2/authorize", AuthURL: "https://bitbucket.org/site/oauth2/authorize",
@ -167,6 +173,12 @@ var StackOverflow = oauth2.Endpoint{
TokenURL: "https://stackoverflow.com/oauth/access_token", TokenURL: "https://stackoverflow.com/oauth/access_token",
} }
// Strava is the endpoint for Strava.
var Strava = oauth2.Endpoint{
AuthURL: "https://www.strava.com/oauth/authorize",
TokenURL: "https://www.strava.com/oauth/token",
}
// Twitch is the endpoint for Twitch. // Twitch is the endpoint for Twitch.
var Twitch = oauth2.Endpoint{ var Twitch = oauth2.Endpoint{
AuthURL: "https://id.twitch.tv/oauth2/authorize", AuthURL: "https://id.twitch.tv/oauth2/authorize",

2
go.mod
View File

@ -4,6 +4,6 @@ go 1.11
require ( require (
cloud.google.com/go v0.65.0 cloud.google.com/go v0.65.0
golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
google.golang.org/appengine v1.6.6 google.golang.org/appengine v1.6.6
) )

7
go.sum
View File

@ -177,8 +177,9 @@ golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -217,11 +218,15 @@ golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -94,20 +94,20 @@ func DefaultTokenSource(ctx context.Context, scope ...string) (oauth2.TokenSourc
// It looks for credentials in the following places, // It looks for credentials in the following places,
// preferring the first location found: // preferring the first location found:
// //
// 1. A JSON file whose path is specified by the // 1. A JSON file whose path is specified by the
// GOOGLE_APPLICATION_CREDENTIALS environment variable. // GOOGLE_APPLICATION_CREDENTIALS environment variable.
// For workload identity federation, refer to // For workload identity federation, refer to
// https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation on // https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation on
// how to generate the JSON configuration file for on-prem/non-Google cloud // how to generate the JSON configuration file for on-prem/non-Google cloud
// platforms. // platforms.
// 2. A JSON file in a location known to the gcloud command-line tool. // 2. A JSON file in a location known to the gcloud command-line tool.
// On Windows, this is %APPDATA%/gcloud/application_default_credentials.json. // On Windows, this is %APPDATA%/gcloud/application_default_credentials.json.
// On other systems, $HOME/.config/gcloud/application_default_credentials.json. // On other systems, $HOME/.config/gcloud/application_default_credentials.json.
// 3. On Google App Engine standard first generation runtimes (<= Go 1.9) it uses // 3. On Google App Engine standard first generation runtimes (<= Go 1.9) it uses
// the appengine.AccessToken function. // the appengine.AccessToken function.
// 4. On Google Compute Engine, Google App Engine standard second generation runtimes // 4. On Google Compute Engine, Google App Engine standard second generation runtimes
// (>= Go 1.11), and Google App Engine flexible environment, it fetches // (>= Go 1.11), and Google App Engine flexible environment, it fetches
// credentials from the metadata server. // credentials from the metadata server.
func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsParams) (*Credentials, error) { func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsParams) (*Credentials, error) {
// Make defensive copy of the slices in params. // Make defensive copy of the slices in params.
params = params.deepCopy() params = params.deepCopy()

View File

@ -4,9 +4,9 @@
// Package google provides support for making OAuth2 authorized and authenticated // Package google provides support for making OAuth2 authorized and authenticated
// HTTP requests to Google APIs. It supports the Web server flow, client-side // HTTP requests to Google APIs. It supports the Web server flow, client-side
// credentials, service accounts, Google Compute Engine service accounts, Google // credentials, service accounts, Google Compute Engine service accounts,
// App Engine service accounts and workload identity federation from non-Google // Google App Engine service accounts and workload identity federation
// cloud platforms. // from non-Google cloud platforms.
// //
// A brief overview of the package follows. For more information, please read // A brief overview of the package follows. For more information, please read
// https://developers.google.com/accounts/docs/OAuth2 // https://developers.google.com/accounts/docs/OAuth2
@ -15,14 +15,14 @@
// For more information on using workload identity federation, refer to // For more information on using workload identity federation, refer to
// https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation. // https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation.
// //
// OAuth2 Configs // # OAuth2 Configs
// //
// Two functions in this package return golang.org/x/oauth2.Config values from Google credential // Two functions in this package return golang.org/x/oauth2.Config values from Google credential
// data. Google supports two JSON formats for OAuth2 credentials: one is handled by ConfigFromJSON, // data. Google supports two JSON formats for OAuth2 credentials: one is handled by ConfigFromJSON,
// the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or // the other by JWTConfigFromJSON. The returned Config can be used to obtain a TokenSource or
// create an http.Client. // create an http.Client.
// //
// Workload Identity Federation // # Workload Identity Federation
// //
// Using workload identity federation, your application can access Google Cloud // Using workload identity federation, your application can access Google Cloud
// resources from Amazon Web Services (AWS), Microsoft Azure or any identity // resources from Amazon Web Services (AWS), Microsoft Azure or any identity
@ -36,9 +36,9 @@
// Follow the detailed instructions on how to configure Workload Identity Federation // Follow the detailed instructions on how to configure Workload Identity Federation
// in various platforms: // in various platforms:
// //
// Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws // Amazon Web Services (AWS): https://cloud.google.com/iam/docs/access-resources-aws
// Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure // Microsoft Azure: https://cloud.google.com/iam/docs/access-resources-azure
// OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc // OIDC identity provider: https://cloud.google.com/iam/docs/access-resources-oidc
// //
// For OIDC providers, the library can retrieve OIDC tokens either from a // For OIDC providers, the library can retrieve OIDC tokens either from a
// local file location (file-sourced credentials) or from a local server // local file location (file-sourced credentials) or from a local server
@ -51,8 +51,7 @@
// return the OIDC token. The response can be in plain text or JSON. // return the OIDC token. The response can be in plain text or JSON.
// Additional required request headers can also be specified. // Additional required request headers can also be specified.
// //
// // # Credentials
// Credentials
// //
// The Credentials type represents Google credentials, including Application Default // The Credentials type represents Google credentials, including Application Default
// Credentials. // Credentials.

View File

@ -0,0 +1,211 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package downscope implements the ability to downscope, or restrict, the
Identity and Access Management permissions that a short-lived Token
can use. Please note that only Google Cloud Storage supports this feature.
For complete documentation, see https://cloud.google.com/iam/docs/downscoping-short-lived-credentials
To downscope permissions of a source credential, you need to define
a Credential Access Boundary. Said Boundary specifies which resources
the newly created credential can access, an upper bound on the permissions
it has over those resources, and optionally attribute-based conditional
access to the aforementioned resources. For more information on IAM
Conditions, see https://cloud.google.com/iam/docs/conditions-overview.
This functionality can be used to provide a third party with
limited access to and permissions on resources held by the owner of the root
credential or internally in conjunction with the principle of least privilege
to ensure that internal services only hold the minimum necessary privileges
for their function.
For example, a token broker can be set up on a server in a private network.
Various workloads (token consumers) in the same network will send authenticated
requests to that broker for downscoped tokens to access or modify specific google
cloud storage buckets. See the NewTokenSource example for an example of how a
token broker would use this package.
The broker will use the functionality in this package to generate a downscoped
token with the requested configuration, and then pass it back to the token
consumer. These downscoped access tokens can then be used to access Google
Storage resources. For instance, you can create a NewClient from the
"cloud.google.com/go/storage" package and pass in option.WithTokenSource(yourTokenSource))
*/
package downscope
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"golang.org/x/oauth2"
)
var (
identityBindingEndpoint = "https://sts.googleapis.com/v1/token"
)
type accessBoundary struct {
AccessBoundaryRules []AccessBoundaryRule `json:"accessBoundaryRules"`
}
// An AvailabilityCondition restricts access to a given Resource.
type AvailabilityCondition struct {
// An Expression specifies the Cloud Storage objects where
// permissions are available. For further documentation, see
// https://cloud.google.com/iam/docs/conditions-overview
Expression string `json:"expression"`
// Title is short string that identifies the purpose of the condition. Optional.
Title string `json:"title,omitempty"`
// Description details about the purpose of the condition. Optional.
Description string `json:"description,omitempty"`
}
// An AccessBoundaryRule Sets the permissions (and optionally conditions)
// that the new token has on given resource.
type AccessBoundaryRule struct {
// AvailableResource is the full resource name of the Cloud Storage bucket that the rule applies to.
// Use the format //storage.googleapis.com/projects/_/buckets/bucket-name.
AvailableResource string `json:"availableResource"`
// AvailablePermissions is a list that defines the upper bound on the available permissions
// for the resource. Each value is the identifier for an IAM predefined role or custom role,
// with the prefix inRole:. For example: inRole:roles/storage.objectViewer.
// Only the permissions in these roles will be available.
AvailablePermissions []string `json:"availablePermissions"`
// An Condition restricts the availability of permissions
// to specific Cloud Storage objects. Optional.
//
// A Condition can be used to make permissions available for specific objects,
// rather than all objects in a Cloud Storage bucket.
Condition *AvailabilityCondition `json:"availabilityCondition,omitempty"`
}
type downscopedTokenResponse struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
// DownscopingConfig specifies the information necessary to request a downscoped token.
type DownscopingConfig struct {
// RootSource is the TokenSource used to create the downscoped token.
// The downscoped token therefore has some subset of the accesses of
// the original RootSource.
RootSource oauth2.TokenSource
// Rules defines the accesses held by the new
// downscoped Token. One or more AccessBoundaryRules are required to
// define permissions for the new downscoped token. Each one defines an
// access (or set of accesses) that the new token has to a given resource.
// There can be a maximum of 10 AccessBoundaryRules.
Rules []AccessBoundaryRule
}
// A downscopingTokenSource is used to retrieve a downscoped token with restricted
// permissions compared to the root Token that is used to generate it.
type downscopingTokenSource struct {
// ctx is the context used to query the API to retrieve a downscoped Token.
ctx context.Context
// config holds the information necessary to generate a downscoped Token.
config DownscopingConfig
}
// NewTokenSource returns a configured downscopingTokenSource.
func NewTokenSource(ctx context.Context, conf DownscopingConfig) (oauth2.TokenSource, error) {
if conf.RootSource == nil {
return nil, fmt.Errorf("downscope: rootSource cannot be nil")
}
if len(conf.Rules) == 0 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules must be at least 1")
}
if len(conf.Rules) > 10 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules may not be greater than 10")
}
for _, val := range conf.Rules {
if val.AvailableResource == "" {
return nil, fmt.Errorf("downscope: all rules must have a nonempty AvailableResource: %+v", val)
}
if len(val.AvailablePermissions) == 0 {
return nil, fmt.Errorf("downscope: all rules must provide at least one permission: %+v", val)
}
}
return downscopingTokenSource{ctx: ctx, config: conf}, nil
}
// Token() uses a downscopingTokenSource to generate an oauth2 Token.
// Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish
// to refresh this token automatically, then initialize a locally defined
// TokenSource struct with the Token held by the StaticTokenSource and wrap
// that TokenSource in an oauth2.ReuseTokenSource.
func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
downscopedOptions := struct {
Boundary accessBoundary `json:"accessBoundary"`
}{
Boundary: accessBoundary{
AccessBoundaryRules: dts.config.Rules,
},
}
tok, err := dts.config.RootSource.Token()
if err != nil {
return nil, fmt.Errorf("downscope: unable to obtain root token: %v", err)
}
b, err := json.Marshal(downscopedOptions)
if err != nil {
return nil, fmt.Errorf("downscope: unable to marshal AccessBoundary payload %v", err)
}
form := url.Values{}
form.Add("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
form.Add("subject_token_type", "urn:ietf:params:oauth:token-type:access_token")
form.Add("requested_token_type", "urn:ietf:params:oauth:token-type:access_token")
form.Add("subject_token", tok.AccessToken)
form.Add("options", string(b))
myClient := oauth2.NewClient(dts.ctx, nil)
resp, err := myClient.PostForm(identityBindingEndpoint, form)
if err != nil {
return nil, fmt.Errorf("unable to generate POST Request %v", err)
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("downscope: unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("downscope: unable to exchange token; %v. Server responded: %s", resp.StatusCode, respBody)
}
var tresp downscopedTokenResponse
err = json.Unmarshal(respBody, &tresp)
if err != nil {
return nil, fmt.Errorf("downscope: unable to unmarshal response body: %v", err)
}
// an exchanged token that is derived from a service account (2LO) has an expired_in value
// a token derived from a users token (3LO) does not.
// The following code uses the time remaining on rootToken for a user as the value for the
// derived token's lifetime
var expiryTime time.Time
if tresp.ExpiresIn > 0 {
expiryTime = time.Now().Add(time.Duration(tresp.ExpiresIn) * time.Second)
} else {
expiryTime = tok.Expiry
}
newToken := &oauth2.Token{
AccessToken: tresp.AccessToken,
TokenType: tresp.TokenType,
Expiry: expiryTime,
}
return newToken, nil
}

View File

@ -0,0 +1,55 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package downscope
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"golang.org/x/oauth2"
)
var (
standardReqBody = "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22accessBoundary%22%3A%7B%22accessBoundaryRules%22%3A%5B%7B%22availableResource%22%3A%22test1%22%2C%22availablePermissions%22%3A%5B%22Perm1%22%2C%22Perm2%22%5D%7D%5D%7D%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&subject_token=Mellon&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token"
standardRespBody = `{"access_token":"Open Sesame","expires_in":432,"issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer"}`
)
func Test_DownscopedTokenSource(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}
if got, want := string(body), standardReqBody; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v,", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(standardRespBody))
}))
new := []AccessBoundaryRule{
{
AvailableResource: "test1",
AvailablePermissions: []string{"Perm1", "Perm2"},
},
}
myTok := oauth2.Token{AccessToken: "Mellon"}
tmpSrc := oauth2.StaticTokenSource(&myTok)
dts := downscopingTokenSource{context.Background(), DownscopingConfig{tmpSrc, new}}
identityBindingEndpoint = ts.URL
_, err := dts.Token()
if err != nil {
t.Fatalf("NewDownscopedTokenSource failed with error: %v", err)
}
}

View File

@ -0,0 +1,57 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package downscope_test
import (
"context"
"fmt"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/downscope"
)
func ExampleNewTokenSource() {
// This shows how to generate a downscoped token. This code would be run on the
// token broker, which holds the root token used to generate the downscoped token.
ctx := context.Background()
// Initializes an accessBoundary with one Rule which restricts the downscoped
// token to only be able to access the bucket "foo" and only grants it the
// permission "storage.objectViewer".
accessBoundary := []downscope.AccessBoundaryRule{
{
AvailableResource: "//storage.googleapis.com/projects/_/buckets/foo",
AvailablePermissions: []string{"inRole:roles/storage.objectViewer"},
},
}
var rootSource oauth2.TokenSource
// This Source can be initialized in multiple ways; the following example uses
// Application Default Credentials.
rootSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/cloud-platform")
dts, err := downscope.NewTokenSource(ctx, downscope.DownscopingConfig{RootSource: rootSource, Rules: accessBoundary})
if err != nil {
fmt.Printf("failed to generate downscoped token source: %v", err)
return
}
tok, err := dts.Token()
if err != nil {
fmt.Printf("failed to generate token: %v", err)
return
}
_ = tok
// You can now pass tok to a token consumer however you wish, such as exposing
// a REST API and sending it over HTTP.
// You can instead use the token held in dts to make
// Google Cloud Storage calls, as follows:
// storageClient, err := storage.NewClient(ctx, option.WithTokenSource(dts))
}

View File

@ -92,9 +92,10 @@ func JWTConfigFromJSON(jsonKey []byte, scope ...string) (*jwt.Config, error) {
// JSON key file types. // JSON key file types.
const ( const (
serviceAccountKey = "service_account" serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user" userCredentialsKey = "authorized_user"
externalAccountKey = "external_account" externalAccountKey = "external_account"
impersonatedServiceAccount = "impersonated_service_account"
) )
// credentialsFile is the unmarshalled representation of a credentials file. // credentialsFile is the unmarshalled representation of a credentials file.
@ -121,8 +122,13 @@ type credentialsFile struct {
TokenURLExternal string `json:"token_url"` TokenURLExternal string `json:"token_url"`
TokenInfoURL string `json:"token_info_url"` TokenInfoURL string `json:"token_info_url"`
ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"`
Delegates []string `json:"delegates"`
CredentialSource externalaccount.CredentialSource `json:"credential_source"` CredentialSource externalaccount.CredentialSource `json:"credential_source"`
QuotaProjectID string `json:"quota_project_id"` QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"`
} }
func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config { func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config {
@ -133,6 +139,7 @@ func (f *credentialsFile) jwtConfig(scopes []string, subject string) *jwt.Config
Scopes: scopes, Scopes: scopes,
TokenURL: f.TokenURL, TokenURL: f.TokenURL,
Subject: subject, // This is the user email to impersonate Subject: subject, // This is the user email to impersonate
Audience: f.Audience,
} }
if cfg.TokenURL == "" { if cfg.TokenURL == "" {
cfg.TokenURL = JWTTokenURL cfg.TokenURL = JWTTokenURL
@ -176,8 +183,26 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
CredentialSource: f.CredentialSource, CredentialSource: f.CredentialSource,
QuotaProjectID: f.QuotaProjectID, QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes, Scopes: params.Scopes,
WorkforcePoolUserProject: f.WorkforcePoolUserProject,
} }
return cfg.TokenSource(ctx), nil return cfg.TokenSource(ctx)
case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")
}
ts, err := f.SourceCredentials.tokenSource(ctx, params)
if err != nil {
return nil, err
}
imp := externalaccount.ImpersonateTokenSource{
Ctx: ctx,
URL: f.ServiceAccountImpersonationURL,
Scopes: params.Scopes,
Ts: ts,
Delegates: f.Delegates,
}
return oauth2.ReuseTokenSource(nil, imp), nil
case "": case "":
return nil, errors.New("missing 'type' field in credentials") return nil, errors.New("missing 'type' field in credentials")
default: default:

View File

@ -37,7 +37,8 @@ var jwtJSONKey = []byte(`{
"client_email": "gopher@developer.gserviceaccount.com", "client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com", "client_id": "gopher.apps.googleusercontent.com",
"token_uri": "https://accounts.google.com/o/gophers/token", "token_uri": "https://accounts.google.com/o/gophers/token",
"type": "service_account" "type": "service_account",
"audience": "https://testservice.googleapis.com/"
}`) }`)
var jwtJSONKeyNoTokenURL = []byte(`{ var jwtJSONKeyNoTokenURL = []byte(`{
@ -48,6 +49,15 @@ var jwtJSONKeyNoTokenURL = []byte(`{
"type": "service_account" "type": "service_account"
}`) }`)
var jwtJSONKeyNoAudience = []byte(`{
"private_key_id": "268f54e43a1af97cfc71731688434f45aca15c8b",
"private_key": "super secret key",
"client_email": "gopher@developer.gserviceaccount.com",
"client_id": "gopher.apps.googleusercontent.com",
"token_uri": "https://accounts.google.com/o/gophers/token",
"type": "service_account"
}`)
func TestConfigFromJSON(t *testing.T) { func TestConfigFromJSON(t *testing.T) {
conf, err := ConfigFromJSON(webJSONKey, "scope1", "scope2") conf, err := ConfigFromJSON(webJSONKey, "scope1", "scope2")
if err != nil { if err != nil {
@ -103,6 +113,9 @@ func TestJWTConfigFromJSON(t *testing.T) {
if got, want := conf.TokenURL, "https://accounts.google.com/o/gophers/token"; got != want { if got, want := conf.TokenURL, "https://accounts.google.com/o/gophers/token"; got != want {
t.Errorf("TokenURL = %q; want %q", got, want) t.Errorf("TokenURL = %q; want %q", got, want)
} }
if got, want := conf.Audience, "https://testservice.googleapis.com/"; got != want {
t.Errorf("Audience = %q; want %q", got, want)
}
} }
func TestJWTConfigFromJSONNoTokenURL(t *testing.T) { func TestJWTConfigFromJSONNoTokenURL(t *testing.T) {
@ -114,3 +127,13 @@ func TestJWTConfigFromJSONNoTokenURL(t *testing.T) {
t.Errorf("TokenURL = %q; want %q", got, want) t.Errorf("TokenURL = %q; want %q", got, want)
} }
} }
func TestJWTConfigFromJSONNoAudience(t *testing.T) {
conf, err := JWTConfigFromJSON(jwtJSONKeyNoAudience, "scope1", "scope2")
if err != nil {
t.Fatal(err)
}
if got, want := conf.Audience, ""; got != want {
t.Errorf("Audience = %q; want %q", got, want)
}
}

View File

@ -13,7 +13,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/oauth2"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -23,6 +22,8 @@ import (
"sort" "sort"
"strings" "strings"
"time" "time"
"golang.org/x/oauth2"
) )
type awsSecurityCredentials struct { type awsSecurityCredentials struct {
@ -51,6 +52,13 @@ const (
// The AWS authorization header name for the security session token if available. // The AWS authorization header name for the security session token if available.
awsSecurityTokenHeader = "x-amz-security-token" awsSecurityTokenHeader = "x-amz-security-token"
// The name of the header containing the session token for metadata endpoint calls
awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token"
awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
awsIMDSv2SessionTtl = "300"
// The AWS authorization header name for the auto-generated date. // The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date" awsDateHeader = "x-amz-date"
@ -240,6 +248,7 @@ type awsCredentialSource struct {
RegionURL string RegionURL string
RegionalCredVerificationURL string RegionalCredVerificationURL string
CredVerificationURL string CredVerificationURL string
IMDSv2SessionTokenURL string
TargetResource string TargetResource string
requestSigner *awsRequestSigner requestSigner *awsRequestSigner
region string region string
@ -267,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro
func (cs awsCredentialSource) subjectToken() (string, error) { func (cs awsCredentialSource) subjectToken() (string, error) {
if cs.requestSigner == nil { if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials() awsSessionToken, err := cs.getAWSSessionToken()
if err != nil { if err != nil {
return "", err return "", err
} }
if cs.region, err = cs.getRegion(); err != nil { headers := make(map[string]string)
if awsSessionToken != "" {
headers[awsIMDSv2SessionTokenHeader] = awsSessionToken
}
awsSecurityCredentials, err := cs.getSecurityCredentials(headers)
if err != nil {
return "", err
}
if cs.region, err = cs.getRegion(headers); err != nil {
return "", err return "", err
} }
@ -339,10 +358,43 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
return url.QueryEscape(string(result)), nil return url.QueryEscape(string(result)), nil
} }
func (cs *awsCredentialSource) getRegion() (string, error) { func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
if cs.IMDSv2SessionTokenURL == "" {
return "", nil
}
req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil)
if err != nil {
return "", err
}
req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl)
resp, err := cs.doRequest(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody))
}
return string(respBody), nil
}
func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil return envAwsRegion, nil
} }
if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
}
if cs.RegionURL == "" { if cs.RegionURL == "" {
return "", errors.New("oauth2/google: unable to determine AWS region") return "", errors.New("oauth2/google: unable to determine AWS region")
@ -353,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
return "", err return "", err
} }
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req) resp, err := cs.doRequest(req)
if err != nil { if err != nil {
return "", err return "", err
@ -377,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
return string(respBody[:respBodyEnd]), nil return string(respBody[:respBodyEnd]), nil
} }
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) { func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) {
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{ return awsSecurityCredentials{
@ -388,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
} }
} }
roleName, err := cs.getMetadataRoleName() roleName, err := cs.getMetadataRoleName(headers)
if err != nil { if err != nil {
return return
} }
credentials, err := cs.getMetadataSecurityCredentials(roleName) credentials, err := cs.getMetadataSecurityCredentials(roleName, headers)
if err != nil { if err != nil {
return return
} }
@ -409,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede
return credentials, nil return credentials, nil
} }
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) { func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials var result awsSecurityCredentials
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
@ -418,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req) resp, err := cs.doRequest(req)
if err != nil { if err != nil {
return result, err return result, err
@ -437,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (
return result, err return result, err
} }
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) {
if cs.CredVerificationURL == "" { if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
} }
@ -447,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
return "", err return "", err
} }
for name, value := range headers {
req.Header.Add(name, value)
}
resp, err := cs.doRequest(req) resp, err := cs.doRequest(req)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -20,6 +20,8 @@ import (
var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC) var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC)
var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC) var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC)
type validateHeaders func(r *http.Request)
func setTime(testTime time.Time) func() time.Time { func setTime(testTime time.Time) func() time.Time {
return func() time.Time { return func() time.Time {
return testTime return testTime
@ -28,8 +30,7 @@ func setTime(testTime time.Time) func() time.Time {
func setEnvironment(env map[string]string) func(string) string { func setEnvironment(env map[string]string) func(string) string {
return func(key string) string { return func(key string) string {
value, _ := env[key] return env[key]
return value
} }
} }
@ -83,7 +84,7 @@ func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, expectedOutput
} }
} }
func TestAwsV4Signature_GetRequest(t *testing.T) { func TestAWSv4Signature_GetRequest(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com", nil) input, _ := http.NewRequest("GET", "https://host.foo.com", nil)
setDefaultTime(input) setDefaultTime(input)
@ -101,7 +102,7 @@ func TestAwsV4Signature_GetRequest(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) { func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil)
setDefaultTime(input) setDefaultTime(input)
@ -119,7 +120,7 @@ func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) { func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil)
setDefaultTime(input) setDefaultTime(input)
@ -137,7 +138,7 @@ func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) { func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil)
setDefaultTime(input) setDefaultTime(input)
@ -155,7 +156,7 @@ func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) { func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil)
setDefaultTime(input) setDefaultTime(input)
@ -173,7 +174,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) { func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil)
setDefaultTime(input) setDefaultTime(input)
@ -191,7 +192,7 @@ func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) { func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil)
setDefaultTime(input) setDefaultTime(input)
@ -209,7 +210,7 @@ func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) { func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) {
input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil) input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil)
setDefaultTime(input) setDefaultTime(input)
@ -227,7 +228,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_PostRequest(t *testing.T) { func TestAWSv4Signature_PostRequest(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input) setDefaultTime(input)
input.Header.Add("ZOO", "zoobar") input.Header.Add("ZOO", "zoobar")
@ -247,7 +248,7 @@ func TestAwsV4Signature_PostRequest(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input) setDefaultTime(input)
input.Header.Add("zoo", "ZOOBAR") input.Header.Add("zoo", "ZOOBAR")
@ -267,7 +268,7 @@ func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_PostRequestPhfft(t *testing.T) { func TestAWSv4Signature_PostRequestPhfft(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) input, _ := http.NewRequest("POST", "https://host.foo.com/", nil)
setDefaultTime(input) setDefaultTime(input)
input.Header.Add("p", "phfft") input.Header.Add("p", "phfft")
@ -287,7 +288,7 @@ func TestAwsV4Signature_PostRequestPhfft(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_PostRequestWithBody(t *testing.T) { func TestAWSv4Signature_PostRequestWithBody(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar")) input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar"))
setDefaultTime(input) setDefaultTime(input)
input.Header.Add("Content-Type", "application/x-www-form-urlencoded") input.Header.Add("Content-Type", "application/x-www-form-urlencoded")
@ -307,7 +308,7 @@ func TestAwsV4Signature_PostRequestWithBody(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) { func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) {
input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil) input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil)
setDefaultTime(input) setDefaultTime(input)
@ -325,7 +326,7 @@ func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) {
testRequestSigner(t, defaultRequestSigner, input, output) testRequestSigner(t, defaultRequestSigner, input, output)
} }
func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) { func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) {
input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil)
@ -343,7 +344,7 @@ func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) {
testRequestSigner(t, requestSignerWithToken, input, output) testRequestSigner(t, requestSignerWithToken, input, output)
} }
func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) { func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) {
input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil)
@ -361,7 +362,7 @@ func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) {
testRequestSigner(t, requestSignerWithToken, input, output) testRequestSigner(t, requestSignerWithToken, input, output)
} }
func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) { func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) {
requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}" requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}"
input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams)) input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams))
input.Header.Add("Content-Type", "application/x-amz-json-1.0") input.Header.Add("Content-Type", "application/x-amz-json-1.0")
@ -384,7 +385,7 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
testRequestSigner(t, requestSignerWithToken, input, output) testRequestSigner(t, requestSignerWithToken, input, output)
} }
func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{ var requestSigner = &awsRequestSigner{
RegionName: "us-east-2", RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: awsSecurityCredentials{
@ -414,30 +415,40 @@ type testAwsServer struct {
securityCredentialURL string securityCredentialURL string
regionURL string regionURL string
regionalCredVerificationURL string regionalCredVerificationURL string
imdsv2SessionTokenUrl string
Credentials map[string]string Credentials map[string]string
WriteRolename func(http.ResponseWriter) WriteRolename func(http.ResponseWriter, *http.Request)
WriteSecurityCredentials func(http.ResponseWriter) WriteSecurityCredentials func(http.ResponseWriter, *http.Request)
WriteRegion func(http.ResponseWriter) WriteRegion func(http.ResponseWriter, *http.Request)
WriteIMDSv2SessionToken func(http.ResponseWriter, *http.Request)
} }
func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer { func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenUrl string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer {
server := &testAwsServer{ server := &testAwsServer{
url: url, url: url,
securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename), securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename),
regionURL: regionURL, regionURL: regionURL,
regionalCredVerificationURL: regionalCredVerificationURL, regionalCredVerificationURL: regionalCredVerificationURL,
imdsv2SessionTokenUrl: imdsv2SessionTokenUrl,
Credentials: credentials, Credentials: credentials,
WriteRolename: func(w http.ResponseWriter) { WriteRolename: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(rolename)) w.Write([]byte(rolename))
}, },
WriteRegion: func(w http.ResponseWriter) { WriteRegion: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(region)) w.Write([]byte(region))
}, },
WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
w.Write([]byte(imdsv2SessionToken))
},
} }
server.WriteSecurityCredentials = func(w http.ResponseWriter) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
validateHeaders(r)
jsonCredentials, _ := json.Marshal(server.Credentials) jsonCredentials, _ := json.Marshal(server.Credentials)
w.Write(jsonCredentials) w.Write(jsonCredentials)
} }
@ -450,6 +461,7 @@ func createDefaultAwsTestServer() *testAwsServer {
"/latest/meta-data/iam/security-credentials", "/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone", "/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"",
"gcp-aws-role", "gcp-aws-role",
"us-east-2b", "us-east-2b",
map[string]string{ map[string]string{
@ -457,31 +469,38 @@ func createDefaultAwsTestServer() *testAwsServer {
"AccessKeyId": accessKeyID, "AccessKeyId": accessKeyID,
"Token": securityToken, "Token": securityToken,
}, },
"",
noHeaderValidation,
) )
} }
func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch p := r.URL.Path; p { switch p := r.URL.Path; p {
case server.url: case server.url:
server.WriteRolename(w) server.WriteRolename(w, r)
case server.securityCredentialURL: case server.securityCredentialURL:
server.WriteSecurityCredentials(w) server.WriteSecurityCredentials(w, r)
case server.regionURL: case server.regionURL:
server.WriteRegion(w) server.WriteRegion(w, r)
case server.imdsv2SessionTokenUrl:
server.WriteIMDSv2SessionToken(w, r)
} }
} }
func notFound(w http.ResponseWriter) { func notFound(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
w.Write([]byte("Not Found")) w.Write([]byte("Not Found"))
} }
func noHeaderValidation(r *http.Request) {}
func (server *testAwsServer) getCredentialSource(url string) CredentialSource { func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
return CredentialSource{ return CredentialSource{
EnvironmentID: "aws1", EnvironmentID: "aws1",
URL: url + server.url, URL: url + server.url,
RegionURL: url + server.regionURL, RegionURL: url + server.regionURL,
RegionalCredVerificationURL: server.regionalCredVerificationURL, RegionalCredVerificationURL: server.regionalCredVerificationURL,
IMDSv2SessionTokenURL: url + server.imdsv2SessionTokenUrl,
} }
} }
@ -531,7 +550,7 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
return neturl.QueryEscape(string(str)) return neturl.QueryEscape(string(str))
} }
func TestAwsCredential_BasicRequest(t *testing.T) { func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -541,6 +560,9 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -561,11 +583,76 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
) )
if got, want := out, expected; !reflect.DeepEqual(got, want) { if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
} }
} }
func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_IMDSv2(t *testing.T) {
validateSessionTokenHeaders := func(r *http.Request) {
if r.URL.Path == "/latest/api/token" {
headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader)
if headerValue != awsIMDSv2SessionTtl {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl)
}
} else {
headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader)
if headerValue != "sessiontoken" {
t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken")
}
}
}
server := createAwsTestServer(
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"/latest/api/token",
"gcp-aws-role",
"us-east-2b",
map[string]string{
"SecretAccessKey": secretAccessKey,
"AccessKeyId": accessKeyID,
"Token": securityToken,
},
"sessiontoken",
validateSessionTokenHeaders,
)
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyID,
secretAccessKey,
securityToken,
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
delete(server.Credentials, "Token") delete(server.Credentials, "Token")
@ -576,6 +663,9 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -596,11 +686,11 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
) )
if got, want := out, expected; !reflect.DeepEqual(got, want) { if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
} }
} }
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -614,6 +704,9 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -634,11 +727,92 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
) )
if got, want := out, expected; !reflect.DeepEqual(got, want) { if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
} }
} }
func TestAwsCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_DEFAULT_REGION": "us-west-1",
})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1",
})
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime)
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = \n%q\n want \n%q", got, want)
}
}
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -659,7 +833,7 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -686,7 +860,7 @@ func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
server.WriteRegion = notFound server.WriteRegion = notFound
@ -713,10 +887,10 @@ func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}")) w.Write([]byte("{}"))
} }
@ -742,10 +916,10 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
} }
@ -771,7 +945,7 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -798,7 +972,7 @@ func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
server.WriteRolename = notFound server.WriteRolename = notFound
@ -825,7 +999,7 @@ func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) {
} }
} }
func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
server.WriteSecurityCredentials = notFound server.WriteSecurityCredentials = notFound

View File

@ -7,10 +7,14 @@ package externalaccount
import ( import (
"context" "context"
"fmt" "fmt"
"golang.org/x/oauth2"
"net/http" "net/http"
"net/url"
"regexp"
"strconv" "strconv"
"strings"
"time" "time"
"golang.org/x/oauth2"
) )
// now aliases time.Now for testing // now aliases time.Now for testing
@ -20,36 +24,129 @@ var now = func() time.Time {
// Config stores the configuration for fetching tokens with external credentials. // Config stores the configuration for fetching tokens with external credentials.
type Config struct { type Config struct {
Audience string // Audience is the Secure Token Service (STS) audience which contains the resource name for the workload
SubjectTokenType string // identity pool or the workforce pool and the provider identifier in that pool.
TokenURL string Audience string
TokenInfoURL string // SubjectTokenType is the STS token type based on the Oauth2.0 token exchange spec
// e.g. `urn:ietf:params:oauth:token-type:jwt`.
SubjectTokenType string
// TokenURL is the STS token exchange endpoint.
TokenURL string
// TokenInfoURL is the token_info endpoint used to retrieve the account related information (
// user attributes like account identifier, eg. email, username, uid, etc). This is
// needed for gCloud session account identification.
TokenInfoURL string
// ServiceAccountImpersonationURL is the URL for the service account impersonation request. This is only
// required for workload identity pools when APIs to be accessed have not integrated with UberMint.
ServiceAccountImpersonationURL string ServiceAccountImpersonationURL string
ClientSecret string // ClientSecret is currently only required if token_info endpoint also
ClientID string // needs to be called with the generated GCP access token. When provided, STS will be
CredentialSource CredentialSource // called with additional basic authentication using client_id as username and client_secret as password.
QuotaProjectID string ClientSecret string
Scopes []string // ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// CredentialSource contains the necessary information to retrieve the token itself, as well
// as some environmental information.
CredentialSource CredentialSource
// QuotaProjectID is injected by gCloud. If the value is non-empty, the Auth libraries
// will set the x-goog-user-project which overrides the project associated with the credentials.
QuotaProjectID string
// Scopes contains the desired scopes for the returned access token.
Scopes []string
// The optional workforce pool user project number when the credential
// corresponds to a workforce pool and not a workload identity pool.
// The underlying principal must still have serviceusage.services.use IAM
// permission to use the project for billing/quota.
WorkforcePoolUserProject string
}
// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.
var (
validTokenURLPatterns = []*regexp.Regexp{
// The complicated part in the middle matches any number of characters that
// aren't period, spaces, or slashes.
regexp.MustCompile(`(?i)^[^\.\s\/\\]+\.sts\.googleapis\.com$`),
regexp.MustCompile(`(?i)^sts\.googleapis\.com$`),
regexp.MustCompile(`(?i)^sts\.[^\.\s\/\\]+\.googleapis\.com$`),
regexp.MustCompile(`(?i)^[^\.\s\/\\]+-sts\.googleapis\.com$`),
}
validImpersonateURLPatterns = []*regexp.Regexp{
regexp.MustCompile(`^[^\.\s\/\\]+\.iamcredentials\.googleapis\.com$`),
regexp.MustCompile(`^iamcredentials\.googleapis\.com$`),
regexp.MustCompile(`^iamcredentials\.[^\.\s\/\\]+\.googleapis\.com$`),
regexp.MustCompile(`^[^\.\s\/\\]+-iamcredentials\.googleapis\.com$`),
}
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host
for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}
func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
} }
// TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https")
}
// tokenSource is a private function that's directly called by some of the tests,
// because the unit test URLs are mocked, and would otherwise fail the
// validity check.
func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) {
valid := validateURL(c.TokenURL, tokenURLValidPats, scheme)
if !valid {
return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource")
}
if c.ServiceAccountImpersonationURL != "" {
valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme)
if !valid {
return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource")
}
}
if c.WorkforcePoolUserProject != "" {
valid := validateWorkforceAudience(c.Audience)
if !valid {
return nil, fmt.Errorf("oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials")
}
}
ts := tokenSource{ ts := tokenSource{
ctx: ctx, ctx: ctx,
conf: c, conf: c,
} }
if c.ServiceAccountImpersonationURL == "" { if c.ServiceAccountImpersonationURL == "" {
return oauth2.ReuseTokenSource(nil, ts) return oauth2.ReuseTokenSource(nil, ts), nil
} }
scopes := c.Scopes scopes := c.Scopes
ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"} ts.conf.Scopes = []string{"https://www.googleapis.com/auth/cloud-platform"}
imp := impersonateTokenSource{ imp := ImpersonateTokenSource{
ctx: ctx, Ctx: ctx,
url: c.ServiceAccountImpersonationURL, URL: c.ServiceAccountImpersonationURL,
scopes: scopes, Scopes: scopes,
ts: oauth2.ReuseTokenSource(nil, ts), Ts: oauth2.ReuseTokenSource(nil, ts),
} }
return oauth2.ReuseTokenSource(nil, imp) return oauth2.ReuseTokenSource(nil, imp), nil
} }
// Subject token file types. // Subject token file types.
@ -59,13 +156,15 @@ const (
) )
type format struct { type format struct {
// Type is either "text" or "json". When not provided "text" type is assumed. // Type is either "text" or "json". When not provided "text" type is assumed.
Type string `json:"type"` Type string `json:"type"`
// SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure. // SubjectTokenFieldName is only required for JSON format. This would be "access_token" for azure.
SubjectTokenFieldName string `json:"subject_token_field_name"` SubjectTokenFieldName string `json:"subject_token_field_name"`
} }
// CredentialSource stores the information necessary to retrieve the credentials for the STS exchange. // CredentialSource stores the information necessary to retrieve the credentials for the STS exchange.
// Either the File or the URL field should be filled, depending on the kind of credential in question.
// The EnvironmentID should start with AWS if being used for an AWS credential.
type CredentialSource struct { type CredentialSource struct {
File string `json:"file"` File string `json:"file"`
@ -76,6 +175,7 @@ type CredentialSource struct {
RegionURL string `json:"region_url"` RegionURL string `json:"region_url"`
RegionalCredVerificationURL string `json:"regional_cred_verification_url"` RegionalCredVerificationURL string `json:"regional_cred_verification_url"`
CredVerificationURL string `json:"cred_verification_url"` CredVerificationURL string `json:"cred_verification_url"`
IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"`
Format format `json:"format"` Format format `json:"format"`
} }
@ -86,14 +186,20 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
if awsVersion != 1 { if awsVersion != 1 {
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion) return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
} }
return awsCredentialSource{
awsCredSource := awsCredentialSource{
EnvironmentID: c.CredentialSource.EnvironmentID, EnvironmentID: c.CredentialSource.EnvironmentID,
RegionURL: c.CredentialSource.RegionURL, RegionURL: c.CredentialSource.RegionURL,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL, CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience, TargetResource: c.Audience,
ctx: ctx, ctx: ctx,
}, nil }
if c.CredentialSource.IMDSv2SessionTokenURL != "" {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
}
return awsCredSource, nil
} }
} else if c.CredentialSource.File != "" { } else if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
@ -107,7 +213,7 @@ type baseCredentialSource interface {
subjectToken() (string, error) subjectToken() (string, error)
} }
// tokenSource is the source that handles external credentials. // tokenSource is the source that handles external credentials. It is used to retrieve Tokens.
type tokenSource struct { type tokenSource struct {
ctx context.Context ctx context.Context
conf *Config conf *Config
@ -141,7 +247,15 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
ClientID: conf.ClientID, ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret, ClientSecret: conf.ClientSecret,
} }
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, nil) var options map[string]interface{}
// Do not pass workforce_pool_user_project when client authentication is used.
// The client ID is sufficient for determining the user project.
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
options = map[string]interface{}{
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,8 +9,11 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"golang.org/x/oauth2"
) )
const ( const (
@ -34,55 +37,64 @@ var testConfig = Config{
} }
var ( var (
baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" baseCredsRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` baseCredsResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
correctAT = "Sample.Access.Token" workforcePoolRequestBodyWithClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
expiry int64 = 234852 workforcePoolRequestBodyWithoutClientId = "audience=%2F%2Fiam.googleapis.com%2Flocations%2Feu%2FworkforcePools%2Fpool-id%2Fproviders%2Fprovider-id&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=%7B%22userProject%22%3A%22myProject%22%7D&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=street123&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aid_token"
correctAT = "Sample.Access.Token"
expiry int64 = 234852
) )
var ( var (
testNow = func() time.Time { return time.Unix(expiry, 0) } testNow = func() time.Time { return time.Unix(expiry, 0) }
) )
func TestToken(t *testing.T) { type testExchangeTokenServer struct {
url string
authorization string
contentType string
body string
response string
}
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.Token, error) {
if got, want := r.URL.String(), "/"; got != want { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), tets.url; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want) t.Errorf("URL.String(): got %v but want %v", got, want)
} }
headerAuth := r.Header.Get("Authorization") headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want { if got, want := headerAuth, tets.authorization; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
headerContentType := r.Header.Get("Content-Type") headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { if got, want := headerContentType, tets.contentType; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %s.", err) t.Fatalf("Failed reading request body: %s.", err)
} }
if got, want := string(body), baseCredsRequestBody; got != want { if got, want := string(body), tets.body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want) t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(baseCredsResponseBody)) w.Write([]byte(tets.response))
})) }))
defer targetServer.Close() defer server.Close()
config.TokenURL = server.URL
testConfig.TokenURL = targetServer.URL
ourTS := tokenSource{
ctx: context.Background(),
conf: &testConfig,
}
oldNow := now oldNow := now
defer func() { now = oldNow }() defer func() { now = oldNow }()
now = testNow now = testNow
tok, err := ourTS.Token() ts := tokenSource{
if err != nil { ctx: context.Background(),
t.Fatalf("Unexpected error: %e", err) conf: config,
} }
return ts.Token()
}
func validateToken(t *testing.T, tok *oauth2.Token) {
if got, want := tok.AccessToken, correctAT; got != want { if got, want := tok.AccessToken, correctAT; got != want {
t.Errorf("Unexpected access token: got %v, but wanted %v", got, want) t.Errorf("Unexpected access token: got %v, but wanted %v", got, want)
} }
@ -90,8 +102,260 @@ func TestToken(t *testing.T) {
t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want) t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want)
} }
if got, want := tok.Expiry, now().Add(time.Duration(3600)*time.Second); got != want { if got, want := tok.Expiry, testNow().Add(time.Duration(3600)*time.Second); got != want {
t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want) t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want)
} }
}
func TestToken(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
body: baseCredsRequestBody,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestWorkforcePoolTokenWithClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=",
contentType: "application/x-www-form-urlencoded",
body: workforcePoolRequestBodyWithClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestWorkforcePoolTokenWithoutClientID(t *testing.T) {
config := Config{
Audience: "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
ClientSecret: "notsosecret",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
server := testExchangeTokenServer{
url: "/",
authorization: "",
contentType: "application/x-www-form-urlencoded",
body: workforcePoolRequestBodyWithoutClientId,
response: baseCredsResponseBody,
}
tok, err := run(t, &config, &server)
if err != nil {
t.Fatalf("Unexpected error: %e", err)
}
validateToken(t, tok)
}
func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {
config := Config{
Audience: "32555940559.apps.googleusercontent.com",
SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token",
TokenURL: "https://sts.googleapis.com",
ClientSecret: "notsosecret",
ClientID: "rbrgnognrhongo3bi4gb9ghg9g",
CredentialSource: testBaseCredSource,
Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"},
WorkforcePoolUserProject: "myProject",
}
_, err := config.TokenSource(context.Background())
if err == nil {
t.Fatalf("Expected error but found none")
}
if got, want := err.Error(), "oauth2/google: workforce_pool_user_project should not be set for non-workforce pool credentials"; got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
}
func TestValidateURLTokenURL(t *testing.T) {
var urlValidityTests = []struct {
tokURL string
expectSuccess bool
}{
{"https://east.sts.googleapis.com", true},
{"https://sts.googleapis.com", true},
{"https://sts.asfeasfesef.googleapis.com", true},
{"https://us-east-1-sts.googleapis.com", true},
{"https://sts.googleapis.com/your/path/here", true},
{"https://.sts.googleapis.com", false},
{"https://badsts.googleapis.com", false},
{"https://sts.asfe.asfesef.googleapis.com", false},
{"https://sts..googleapis.com", false},
{"https://-sts.googleapis.com", false},
{"https://us-ea.st-1-sts.googleapis.com", false},
{"https://sts.googleapis.com.evil.com/whatever/path", false},
{"https://us-eas\\t-1.sts.googleapis.com", false},
{"https:/us-ea/st-1.sts.googleapis.com", false},
{"https:/us-east 1.sts.googleapis.com", false},
{"https://", false},
{"http://us-east-1.sts.googleapis.com", false},
{"https://us-east-1.sts.googleapis.comevil.com", false},
}
ctx := context.Background()
for _, tt := range urlValidityTests {
t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = tt.tokURL
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
for _, el := range urlValidityTests {
el.tokURL = strings.ToUpper(el.tokURL)
}
for _, tt := range urlValidityTests {
t.Run(" "+tt.tokURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = tt.tokURL
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
}
func TestValidateURLImpersonateURL(t *testing.T) {
var urlValidityTests = []struct {
impURL string
expectSuccess bool
}{
{"https://east.iamcredentials.googleapis.com", true},
{"https://iamcredentials.googleapis.com", true},
{"https://iamcredentials.asfeasfesef.googleapis.com", true},
{"https://us-east-1-iamcredentials.googleapis.com", true},
{"https://iamcredentials.googleapis.com/your/path/here", true},
{"https://.iamcredentials.googleapis.com", false},
{"https://badiamcredentials.googleapis.com", false},
{"https://iamcredentials.asfe.asfesef.googleapis.com", false},
{"https://iamcredentials..googleapis.com", false},
{"https://-iamcredentials.googleapis.com", false},
{"https://us-ea.st-1-iamcredentials.googleapis.com", false},
{"https://iamcredentials.googleapis.com.evil.com/whatever/path", false},
{"https://us-eas\\t-1.iamcredentials.googleapis.com", false},
{"https:/us-ea/st-1.iamcredentials.googleapis.com", false},
{"https:/us-east 1.iamcredentials.googleapis.com", false},
{"https://", false},
{"http://us-east-1.iamcredentials.googleapis.com", false},
{"https://us-east-1.iamcredentials.googleapis.comevil.com", false},
}
ctx := context.Background()
for _, tt := range urlValidityTests {
t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
config.ServiceAccountImpersonationURL = tt.impURL
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
for _, el := range urlValidityTests {
el.impURL = strings.ToUpper(el.impURL)
}
for _, tt := range urlValidityTests {
t.Run(" "+tt.impURL, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
config.ServiceAccountImpersonationURL = tt.impURL
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
}
func TestWorkforcePoolCreation(t *testing.T) {
var audienceValidatyTests = []struct {
audience string
expectSuccess bool
}{
{"//iam.googleapis.com/locations/global/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", true},
{"//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", true},
{"identitynamespace:1f12345:my_provider", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/eu/workloadIdentityPools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/workforcePools/providers/provider-id", false},
{"//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", false},
{"//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", false},
}
ctx := context.Background()
for _, tt := range audienceValidatyTests {
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
config.ServiceAccountImpersonationURL = "https://iamcredentials.googleapis.com"
config.Audience = tt.audience
config.WorkforcePoolUserProject = "myProject"
_, err := config.TokenSource(ctx)
if tt.expectSuccess && err != nil {
t.Errorf("got %v but want nil", err)
} else if !tt.expectSuccess && err == nil {
t.Errorf("got nil but expected an error")
}
})
}
} }

View File

@ -6,9 +6,10 @@ package externalaccount
import ( import (
"encoding/base64" "encoding/base64"
"golang.org/x/oauth2"
"net/http" "net/http"
"net/url" "net/url"
"golang.org/x/oauth2"
) )
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. // clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
@ -19,6 +20,9 @@ type clientAuthentication struct {
ClientSecret string ClientSecret string
} }
// InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) { func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil { if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return return

View File

@ -5,11 +5,12 @@
package externalaccount package externalaccount
import ( import (
"golang.org/x/oauth2"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"testing" "testing"
"golang.org/x/oauth2"
) )
var clientID = "rbrgnognrhongo3bi4gb9ghg9g" var clientID = "rbrgnognrhongo3bi4gb9ghg9g"

View File

@ -9,11 +9,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"golang.org/x/oauth2"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"time" "time"
"golang.org/x/oauth2"
) )
// generateAccesstokenReq is used for service account impersonation // generateAccesstokenReq is used for service account impersonation
@ -28,30 +29,44 @@ type impersonateTokenResponse struct {
ExpireTime string `json:"expireTime"` ExpireTime string `json:"expireTime"`
} }
type impersonateTokenSource struct { // ImpersonateTokenSource uses a source credential, stored in Ts, to request an access token to the provided URL.
ctx context.Context // Scopes can be defined when the access token is requested.
ts oauth2.TokenSource type ImpersonateTokenSource struct {
// Ctx is the execution context of the impersonation process
// used to perform http call to the URL. Required
Ctx context.Context
// Ts is the source credential used to generate a token on the
// impersonated service account. Required.
Ts oauth2.TokenSource
url string // URL is the endpoint to call to generate a token
scopes []string // on behalf the service account. Required.
URL string
// Scopes that the impersonated credential should have. Required.
Scopes []string
// Delegates are the service account email addresses in a delegation chain.
// Each service account must be granted roles/iam.serviceAccountTokenCreator
// on the next service account in the chain. Optional.
Delegates []string
} }
// Token performs the exchange to get a temporary service account // Token performs the exchange to get a temporary service account token to allow access to GCP.
func (its impersonateTokenSource) Token() (*oauth2.Token, error) { func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) {
reqBody := generateAccessTokenReq{ reqBody := generateAccessTokenReq{
Lifetime: "3600s", Lifetime: "3600s",
Scope: its.scopes, Scope: its.Scopes,
Delegates: its.Delegates,
} }
b, err := json.Marshal(reqBody) b, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: unable to marshal request: %v", err) return nil, fmt.Errorf("oauth2/google: unable to marshal request: %v", err)
} }
client := oauth2.NewClient(its.ctx, its.ts) client := oauth2.NewClient(its.Ctx, its.Ts)
req, err := http.NewRequest("POST", its.url, bytes.NewReader(b)) req, err := http.NewRequest("POST", its.URL, bytes.NewReader(b))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: unable to create impersonation request: %v", err) return nil, fmt.Errorf("oauth2/google: unable to create impersonation request: %v", err)
} }
req = req.WithContext(its.ctx) req = req.WithContext(its.Ctx)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req) resp, err := client.Do(req)

View File

@ -9,6 +9,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp"
"testing" "testing"
) )
@ -76,7 +77,11 @@ func TestImpersonation(t *testing.T) {
defer targetServer.Close() defer targetServer.Close()
testImpersonateConfig.TokenURL = targetServer.URL testImpersonateConfig.TokenURL = targetServer.URL
ourTS := testImpersonateConfig.TokenSource(context.Background()) allURLs := regexp.MustCompile(".+")
ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http")
if err != nil {
t.Fatalf("Failed to create TokenSource: %v", err)
}
oldNow := now oldNow := now
defer func() { now = oldNow }() defer func() { now = oldNow }()

View File

@ -65,6 +65,9 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, err
}
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
} }

View File

@ -7,12 +7,13 @@ package externalaccount
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"golang.org/x/oauth2"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"golang.org/x/oauth2"
) )
var auth = clientAuthentication{ var auth = clientAuthentication{
@ -127,6 +128,9 @@ func TestExchangeToken_Opts(t *testing.T) {
} }
var opts map[string]interface{} var opts map[string]interface{}
err = json.Unmarshal([]byte(strOpts[0]), &opts) err = json.Unmarshal([]byte(strOpts[0]), &opts)
if err != nil {
t.Fatalf("Couldn't parse received \"options\" field.")
}
if len(opts) < 2 { if len(opts) < 2 {
t.Errorf("Too few options received.") t.Errorf("Too few options received.")
} }

View File

@ -9,10 +9,11 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/oauth2"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"golang.org/x/oauth2"
) )
type urlCredentialSource struct { type urlCredentialSource struct {

View File

@ -7,7 +7,6 @@ package externalaccount
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -20,7 +19,6 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
if r.Method != "GET" { if r.Method != "GET" {
t.Errorf("Unexpected request method, %v is found", r.Method) t.Errorf("Unexpected request method, %v is found", r.Method)
} }
fmt.Println(r.Header)
if r.Header.Get("Metadata") != "True" { if r.Header.Get("Metadata") != "True" {
t.Errorf("Metadata header not properly included.") t.Errorf("Metadata header not properly included.")
} }

View File

@ -7,6 +7,7 @@ package google
import ( import (
"crypto/rsa" "crypto/rsa"
"fmt" "fmt"
"strings"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -24,6 +25,28 @@ import (
// optimization supported by a few Google services. // optimization supported by a few Google services.
// Unless you know otherwise, you should use JWTConfigFromJSON instead. // Unless you know otherwise, you should use JWTConfigFromJSON instead.
func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) { func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.TokenSource, error) {
return newJWTSource(jsonKey, audience, nil)
}
// JWTAccessTokenSourceWithScope uses a Google Developers service account JSON
// key file to read the credentials that authorize and authenticate the
// requests, and returns a TokenSource that does not use any OAuth2 flow but
// instead creates a JWT and sends that as the access token.
// The scope is typically a list of URLs that specifies the scope of the
// credentials.
//
// Note that this is not a standard OAuth flow, but rather an
// optimization supported by a few Google services.
// Unless you know otherwise, you should use JWTConfigFromJSON instead.
func JWTAccessTokenSourceWithScope(jsonKey []byte, scope ...string) (oauth2.TokenSource, error) {
return newJWTSource(jsonKey, "", scope)
}
func newJWTSource(jsonKey []byte, audience string, scopes []string) (oauth2.TokenSource, error) {
if len(scopes) == 0 && audience == "" {
return nil, fmt.Errorf("google: missing scope/audience for JWT access token")
}
cfg, err := JWTConfigFromJSON(jsonKey) cfg, err := JWTConfigFromJSON(jsonKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("google: could not parse JSON key: %v", err) return nil, fmt.Errorf("google: could not parse JSON key: %v", err)
@ -35,6 +58,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token
ts := &jwtAccessTokenSource{ ts := &jwtAccessTokenSource{
email: cfg.Email, email: cfg.Email,
audience: audience, audience: audience,
scopes: scopes,
pk: pk, pk: pk,
pkID: cfg.PrivateKeyID, pkID: cfg.PrivateKeyID,
} }
@ -47,6 +71,7 @@ func JWTAccessTokenSourceFromJSON(jsonKey []byte, audience string) (oauth2.Token
type jwtAccessTokenSource struct { type jwtAccessTokenSource struct {
email, audience string email, audience string
scopes []string
pk *rsa.PrivateKey pk *rsa.PrivateKey
pkID string pkID string
} }
@ -54,12 +79,14 @@ type jwtAccessTokenSource struct {
func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) { func (ts *jwtAccessTokenSource) Token() (*oauth2.Token, error) {
iat := time.Now() iat := time.Now()
exp := iat.Add(time.Hour) exp := iat.Add(time.Hour)
scope := strings.Join(ts.scopes, " ")
cs := &jws.ClaimSet{ cs := &jws.ClaimSet{
Iss: ts.email, Iss: ts.email,
Sub: ts.email, Sub: ts.email,
Aud: ts.audience, Aud: ts.audience,
Iat: iat.Unix(), Scope: scope,
Exp: exp.Unix(), Iat: iat.Unix(),
Exp: exp.Unix(),
} }
hdr := &jws.Header{ hdr := &jws.Header{
Algorithm: "RS256", Algorithm: "RS256",

View File

@ -13,29 +13,21 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"golang.org/x/oauth2/jws" "golang.org/x/oauth2/jws"
) )
func TestJWTAccessTokenSourceFromJSON(t *testing.T) { var (
// Generate a key we can use in the test data. privateKey *rsa.PrivateKey
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) jsonKey []byte
if err != nil { once sync.Once
t.Fatal(err) )
}
// Encode the key and substitute into our example JSON. func TestJWTAccessTokenSourceFromJSON(t *testing.T) {
enc := pem.EncodeToMemory(&pem.Block{ setupDummyKey(t)
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
enc, err = json.Marshal(string(enc))
if err != nil {
t.Fatalf("json.Marshal: %v", err)
}
jsonKey := bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1)
ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience") ts, err := JWTAccessTokenSourceFromJSON(jsonKey, "audience")
if err != nil { if err != nil {
@ -89,3 +81,80 @@ func TestJWTAccessTokenSourceFromJSON(t *testing.T) {
t.Errorf("Header KeyID = %q, want %q", got, want) t.Errorf("Header KeyID = %q, want %q", got, want)
} }
} }
func TestJWTAccessTokenSourceWithScope(t *testing.T) {
setupDummyKey(t)
ts, err := JWTAccessTokenSourceWithScope(jsonKey, "scope1", "scope2")
if err != nil {
t.Fatalf("JWTAccessTokenSourceWithScope: %v\nJSON: %s", err, string(jsonKey))
}
tok, err := ts.Token()
if err != nil {
t.Fatalf("Token: %v", err)
}
if got, want := tok.TokenType, "Bearer"; got != want {
t.Errorf("TokenType = %q, want %q", got, want)
}
if got := tok.Expiry; tok.Expiry.Before(time.Now()) {
t.Errorf("Expiry = %v, should not be expired", got)
}
err = jws.Verify(tok.AccessToken, &privateKey.PublicKey)
if err != nil {
t.Errorf("jws.Verify on AccessToken: %v", err)
}
claim, err := jws.Decode(tok.AccessToken)
if err != nil {
t.Fatalf("jws.Decode on AccessToken: %v", err)
}
if got, want := claim.Iss, "gopher@developer.gserviceaccount.com"; got != want {
t.Errorf("Iss = %q, want %q", got, want)
}
if got, want := claim.Sub, "gopher@developer.gserviceaccount.com"; got != want {
t.Errorf("Sub = %q, want %q", got, want)
}
if got, want := claim.Scope, "scope1 scope2"; got != want {
t.Errorf("Aud = %q, want %q", got, want)
}
// Finally, check the header private key.
parts := strings.Split(tok.AccessToken, ".")
hdrJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
t.Fatalf("base64 DecodeString: %v\nString: %q", err, parts[0])
}
var hdr jws.Header
if err := json.Unmarshal([]byte(hdrJSON), &hdr); err != nil {
t.Fatalf("json.Unmarshal: %v (%q)", err, hdrJSON)
}
if got, want := hdr.KeyID, "268f54e43a1af97cfc71731688434f45aca15c8b"; got != want {
t.Errorf("Header KeyID = %q, want %q", got, want)
}
}
func setupDummyKey(t *testing.T) {
once.Do(func() {
// Generate a key we can use in the test data.
pk, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
privateKey = pk
// Encode the key and substitute into our example JSON.
enc := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
enc, err = json.Marshal(string(enc))
if err != nil {
t.Fatalf("json.Marshal: %v", err)
}
jsonKey = bytes.Replace(jwtJSONKey, []byte(`"super secret key"`), enc, 1)
})
}