diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index ae96ac3..0f8c31c 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -5,30 +5,31 @@ package externalaccount import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" + "golang.org/x/oauth2" "io" "io/ioutil" "net/http" "os" "path" "sort" - "strconv" "strings" "time" ) -// RequestSigner is a utility class to sign http requests using a AWS V4 signature. type awsSecurityCredentials struct { - AccessKeyId string - SecretAccessKey string - SecurityToken string + AccessKeyID string `json:"AccessKeyID"` + SecretAccessKey string `json:"SecretAccessKey"` + SecurityToken string `json:"Token"` } +// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature. type awsRequestSigner struct { RegionName string AwsSecurityCredentials awsSecurityCredentials @@ -229,7 +230,7 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp } } - return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil + return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil } type awsCredentialSource struct { @@ -240,24 +241,29 @@ type awsCredentialSource struct { TargetResource string requestSigner *awsRequestSigner region string + ctx context.Context + client *http.Client } -type AwsRequestHeader struct { +type awsRequestHeader struct { Key string `json:"key"` Value string `json:"value"` } -type AwsRequest struct { +type awsRequest struct { URL string `json:"url"` Method string `json:"method"` - Headers []AwsRequestHeader `json:"headers"` + Headers []awsRequestHeader `json:"headers"` +} + +func (cs awsCredentialSource) request(req *http.Request) (*http.Response, error) { + if cs.client == nil { + cs.client = oauth2.NewClient(cs.ctx, nil) + } + return cs.client.Do(req.WithContext(cs.ctx)) } func (cs awsCredentialSource) subjectToken() (string, error) { - if version, _ := strconv.Atoi(cs.EnvironmentID[3:]); version != 1 { - return "", errors.New(fmt.Sprintf("oauth2/google: aws version '%d' is not supported in the current build.", version)) - } - if cs.requestSigner == nil { awsSecurityCredentials, err := cs.getSecurityCredentials() if err != nil { @@ -304,13 +310,13 @@ func (cs awsCredentialSource) subjectToken() (string, error) { # })) */ - awsSignedReq := AwsRequest{ + awsSignedReq := awsRequest{ URL: req.URL.String(), - Method: req.Method, + Method: "POST", } for headerKey, headerList := range req.Header { for _, headerValue := range headerList { - awsSignedReq.Headers = append(awsSignedReq.Headers, AwsRequestHeader{ + awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{ Key: headerKey, Value: headerValue, }) @@ -337,10 +343,15 @@ func (cs *awsCredentialSource) getRegion() (string, error) { } if cs.RegionURL == "" { - return "", errors.New("oauth2/google: Unable to determine AWS region.") + return "", errors.New("oauth2/google: unable to determine AWS region") } - resp, err := http.Get(cs.RegionURL) + req, err := http.NewRequest("GET", cs.RegionURL, nil) + if err != nil { + return "", err + } + + resp, err := cs.request(req) if err != nil { return "", err } @@ -352,17 +363,23 @@ func (cs *awsCredentialSource) getRegion() (string, error) { } if resp.StatusCode != 200 { - return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS region - %s.", string(respBody))) + return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody)) } - return string(respBody)[:len(respBody)-1], nil + // This endpoint will return the region in format: us-east-2b. + // Only the us-east-2 part should be used. + respBodyEnd := 0 + if len(respBody) > 1 { + respBodyEnd = len(respBody) - 1 + } + return string(respBody[:respBodyEnd]), nil } -func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials awsSecurityCredentials, err error) { - if accessKeyId := getenv("AWS_ACCESS_KEY_ID"); accessKeyId != "" { +func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) { + if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { return awsSecurityCredentials{ - AccessKeyId: accessKeyId, + AccessKeyID: accessKeyID, SecretAccessKey: secretAccessKey, SecurityToken: getenv("AWS_SESSION_TOKEN"), }, nil @@ -371,70 +388,64 @@ func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials aws roleName, err := cs.getMetadataRoleName() if err != nil { - return awsSecurityCredentials{}, err + return } credentials, err := cs.getMetadataSecurityCredentials(roleName) if err != nil { - return awsSecurityCredentials{}, err + return } - accessKeyId, ok := credentials["AccessKeyId"] - if !ok { - return awsSecurityCredentials{}, errors.New("oauth2/google: missing AccessKeyId credential.") + if credentials.AccessKeyID == "" { + return result, errors.New("oauth2/google: missing AccessKeyId credential") } - secretAccessKey, ok := credentials["SecretAccessKey"] - if !ok { - return awsSecurityCredentials{}, errors.New("oauth2/google: missing SecretAccessKey credential.") + if credentials.SecretAccessKey == "" { + return result, errors.New("oauth2/google: missing SecretAccessKey credential") } - securityToken, _ := credentials["Token"] - - return awsSecurityCredentials{ - AccessKeyId: accessKeyId, - SecretAccessKey: secretAccessKey, - SecurityToken: securityToken, - }, nil + return credentials, nil } -func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (map[string]string, error) { +func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) { + var result awsSecurityCredentials + req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) if err != nil { - return nil, err + return result, err } req.Header.Add("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := cs.request(req) if err != nil { - return nil, err + return result, err } defer resp.Body.Close() respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - return nil, err + return result, err } if resp.StatusCode != 200 { - return nil, errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS security credentials - %s.", string(respBody))) + return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody)) } - var result map[string]string err = json.Unmarshal(respBody, &result) - if err != nil { - return nil, err - } - - return result, nil + return result, err } func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { if cs.CredVerificationURL == "" { - return "", errors.New("oauth2/google: Unable to determine the AWS metadata server security credentials endpoint.") + return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") } - resp, err := http.Get(cs.CredVerificationURL) + req, err := http.NewRequest("GET", cs.CredVerificationURL, nil) + if err != nil { + return "", err + } + + resp, err := cs.request(req) if err != nil { return "", err } @@ -446,7 +457,7 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { } if resp.StatusCode != 200 { - return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS role name - %s.", string(respBody))) + return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody)) } return string(respBody), nil diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 18081f2..08f366d 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -5,6 +5,7 @@ package externalaccount import ( + "context" "encoding/json" "fmt" "net/http" @@ -34,7 +35,7 @@ func setEnvironment(env map[string]string) func(string) string { var defaultRequestSigner = &awsRequestSigner{ RegionName: "us-east-1", AwsSecurityCredentials: awsSecurityCredentials{ - AccessKeyId: "AKIDEXAMPLE", + AccessKeyID: "AKIDEXAMPLE", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", }, } @@ -48,7 +49,7 @@ const ( var requestSignerWithToken = &awsRequestSigner{ RegionName: "us-east-2", AwsSecurityCredentials: awsSecurityCredentials{ - AccessKeyId: accessKeyId, + AccessKeyID: accessKeyId, SecretAccessKey: secretAccessKey, SecurityToken: securityToken, }, @@ -386,7 +387,7 @@ func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { var requestSigner = &awsRequestSigner{ RegionName: "us-east-2", AwsSecurityCredentials: awsSecurityCredentials{ - AccessKeyId: accessKeyId, + AccessKeyID: accessKeyId, SecretAccessKey: secretAccessKey, }, } @@ -489,26 +490,24 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security signer := &awsRequestSigner{ RegionName: region, AwsSecurityCredentials: awsSecurityCredentials{ - AccessKeyId: accessKeyId, + AccessKeyID: accessKeyId, SecretAccessKey: secretAccessKey, SecurityToken: securityToken, }, } signer.SignRequest(req) - result := AwsRequest{ + result := awsRequest{ URL: url, Method: "POST", - Headers: []AwsRequestHeader{ - AwsRequestHeader{ + Headers: []awsRequestHeader{ + { Key: "Authorization", Value: req.Header.Get("Authorization"), - }, - AwsRequestHeader{ + }, { Key: "Host", Value: req.Header.Get("Host"), - }, - AwsRequestHeader{ + }, { Key: "X-Amz-Date", Value: req.Header.Get("X-Amz-Date"), }, @@ -516,13 +515,13 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security } if securityToken != "" { - result.Headers = append(result.Headers, AwsRequestHeader{ + result.Headers = append(result.Headers, awsRequestHeader{ Key: "X-Amz-Security-Token", Value: securityToken, }) } - result.Headers = append(result.Headers, AwsRequestHeader{ + result.Headers = append(result.Headers, awsRequestHeader{ Key: "X-Goog-Cloud-Target-Resource", Value: testFileConfig.Audience, }) @@ -542,7 +541,12 @@ func TestAwsCredential_BasicRequest(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - out, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Fatalf("retrieveSubjectToken() failed: %v", err) } @@ -560,6 +564,41 @@ func TestAwsCredential_BasicRequest(t *testing.T) { } } +func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + delete(server.Credentials, "Token") + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyId, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = %q, want %q", got, want) + } +} + func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -575,7 +614,12 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { "AWS_REGION": "us-west-1", }) - out, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Fatalf("retrieveSubjectToken() failed: %v", err) } @@ -605,12 +649,11 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + _, err := tfc.parse(context.Background()) if err == nil { - t.Fatalf("retrieveSubjectToken() should have failed") + t.Fatalf("parse() should have failed") } - - if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -627,12 +670,17 @@ func TestAwsCredential_RequestWithNoRegionUrl(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: Unable to determine AWS region."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: unable to determine AWS region"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -649,12 +697,17 @@ func TestAwsCredential_RequestWithBadRegionUrl(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS region - Not Found."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -673,12 +726,17 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -697,12 +755,17 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -719,12 +782,17 @@ func TestAwsCredential_RequestWithNoCredentialUrl(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: Unable to determine the AWS metadata server security credentials endpoint."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -741,12 +809,17 @@ func TestAwsCredential_RequestWithBadCredentialUrl(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS role name - Not Found."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } @@ -763,12 +836,17 @@ func TestAwsCredential_RequestWithBadFinalCredentialUrl(t *testing.T) { defer func() { getenv = oldGetenv }() getenv = setEnvironment(map[string]string{}) - _, err := tfc.parse(nil).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + _, err = base.subjectToken() if err == nil { t.Fatalf("retrieveSubjectToken() should have failed") } - if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS security credentials - Not Found."; !reflect.DeepEqual(got, want) { + if got, want := err.Error(), "oauth2/google: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) { t.Errorf("subjectToken = %q, want %q", got, want) } } diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 63d52f7..c343a22 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -67,24 +67,28 @@ type CredentialSource struct { } // parse determines the type of CredentialSource needed -func (c *Config) parse(ctx context.Context) baseCredentialSource { +func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" { - if _, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil { + if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil { + if awsVersion != 1 { + return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion) + } return awsCredentialSource{ EnvironmentID: c.CredentialSource.EnvironmentID, RegionURL: c.CredentialSource.RegionURL, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, CredVerificationURL: c.CredentialSource.URL, TargetResource: c.Audience, - } + ctx: ctx, + }, nil } } if c.CredentialSource.File != "" { - return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format} + return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil } else if c.CredentialSource.URL != "" { - return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx} + return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}, nil } - return nil + return nil, fmt.Errorf("oauth2/google: unable to parse credential source") } type baseCredentialSource interface { @@ -101,11 +105,12 @@ type tokenSource struct { func (ts tokenSource) Token() (*oauth2.Token, error) { conf := ts.conf - credSource := conf.parse(ts.ctx) - if credSource == nil { - return nil, fmt.Errorf("oauth2/google: unable to parse credential source") + credSource, err := conf.parse(ts.ctx) + if err != nil { + return nil, err } subjectToken, err := credSource.subjectToken() + if err != nil { return nil, err } diff --git a/google/internal/externalaccount/filecredsource_test.go b/google/internal/externalaccount/filecredsource_test.go index 56dd71e..ebd2bb7 100644 --- a/google/internal/externalaccount/filecredsource_test.go +++ b/google/internal/externalaccount/filecredsource_test.go @@ -56,7 +56,12 @@ func TestRetrieveFileSubjectToken(t *testing.T) { tfc.CredentialSource = test.cs t.Run(test.name, func(t *testing.T) { - out, err := tfc.parse(context.Background()).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Errorf("Method subjectToken() errored.") } else if test.want != out { diff --git a/google/internal/externalaccount/urlcredsource_test.go b/google/internal/externalaccount/urlcredsource_test.go index 592610f..1b78e68 100644 --- a/google/internal/externalaccount/urlcredsource_test.go +++ b/google/internal/externalaccount/urlcredsource_test.go @@ -28,7 +28,12 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) { tfc := testFileConfig tfc.CredentialSource = cs - out, err := tfc.parse(context.Background()).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Fatalf("retrieveSubjectToken() failed: %v", err) } @@ -51,7 +56,12 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) { tfc := testFileConfig tfc.CredentialSource = cs - out, err := tfc.parse(context.Background()).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Fatalf("Failed to retrieve URL subject token: %v", err) } @@ -82,7 +92,12 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) { tfc := testFileConfig tfc.CredentialSource = cs - out, err := tfc.parse(context.Background()).subjectToken() + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() if err != nil { t.Fatalf("%v", err) }