diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index a5a5423..e917195 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -52,6 +52,13 @@ const ( // The AWS authorization header name for the security session token if available. awsSecurityTokenHeader = "x-amz-security-token" + // The name of the header containing the session token for metadata endpoint calls + awsIMDSv2SessionTokenHeader = "X-aws-ec2-metadata-token" + + awsIMDSv2SessionTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds" + + awsIMDSv2SessionTtl = "300" + // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" @@ -241,6 +248,7 @@ type awsCredentialSource struct { RegionURL string RegionalCredVerificationURL string CredVerificationURL string + IMDSv2SessionTokenURL string TargetResource string requestSigner *awsRequestSigner region string @@ -268,12 +276,22 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSecurityCredentials, err := cs.getSecurityCredentials() + awsSessionToken, err := cs.getAWSSessionToken() if err != nil { return "", err } - if cs.region, err = cs.getRegion(); err != nil { + headers := make(map[string]string) + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } + + awsSecurityCredentials, err := cs.getSecurityCredentials(headers) + if err != nil { + return "", err + } + + if cs.region, err = cs.getRegion(headers); err != nil { return "", err } @@ -340,7 +358,37 @@ func (cs awsCredentialSource) subjectToken() (string, error) { return url.QueryEscape(string(result)), nil } -func (cs *awsCredentialSource) getRegion() (string, error) { +func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { + if cs.IMDSv2SessionTokenURL == "" { + return "", nil + } + + req, err := http.NewRequest("PUT", cs.IMDSv2SessionTokenURL, nil) + if err != nil { + return "", err + } + + req.Header.Add(awsIMDSv2SessionTtlHeader, awsIMDSv2SessionTtl) + + resp, err := cs.doRequest(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("oauth2/google: unable to retrieve AWS session token - %s", string(respBody)) + } + + return string(respBody), nil +} + +func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { return envAwsRegion, nil } @@ -357,6 +405,10 @@ func (cs *awsCredentialSource) getRegion() (string, error) { return "", err } + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return "", err @@ -381,7 +433,7 @@ func (cs *awsCredentialSource) getRegion() (string, error) { return string(respBody[:respBodyEnd]), nil } -func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) { +func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { return awsSecurityCredentials{ @@ -392,12 +444,12 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede } } - roleName, err := cs.getMetadataRoleName() + roleName, err := cs.getMetadataRoleName(headers) if err != nil { return } - credentials, err := cs.getMetadataSecurityCredentials(roleName) + credentials, err := cs.getMetadataSecurityCredentials(roleName, headers) if err != nil { return } @@ -413,7 +465,7 @@ func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCrede return credentials, nil } -func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) { +func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, headers map[string]string) (awsSecurityCredentials, error) { var result awsSecurityCredentials req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) @@ -422,6 +474,10 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) ( } req.Header.Add("Content-Type", "application/json") + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return result, err @@ -441,7 +497,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) ( return result, err } -func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { +func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (string, error) { if cs.CredVerificationURL == "" { return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint") } @@ -451,6 +507,10 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { return "", err } + for name, value := range headers { + req.Header.Add(name, value) + } + resp, err := cs.doRequest(req) if err != nil { return "", err diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 4900189..0934389 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -20,6 +20,8 @@ import ( var defaultTime = time.Date(2011, 9, 9, 23, 36, 0, 0, time.UTC) var secondDefaultTime = time.Date(2020, 8, 11, 6, 55, 22, 0, time.UTC) +type validateHeaders func(r *http.Request) + func setTime(testTime time.Time) func() time.Time { return func() time.Time { return testTime @@ -82,7 +84,7 @@ func testRequestSigner(t *testing.T, rs *awsRequestSigner, input, expectedOutput } } -func TestAwsV4Signature_GetRequest(t *testing.T) { +func TestAWSv4Signature_GetRequest(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com", nil) setDefaultTime(input) @@ -100,7 +102,7 @@ func TestAwsV4Signature_GetRequest(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithRelativePath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/foo/bar/../..", nil) setDefaultTime(input) @@ -118,7 +120,7 @@ func TestAwsV4Signature_GetRequestWithRelativePath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithDotPath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/./", nil) setDefaultTime(input) @@ -136,7 +138,7 @@ func TestAwsV4Signature_GetRequestWithDotPath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) { +func TestAWSv4Signature_GetRequestWithPointlessDotPath(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/./foo", nil) setDefaultTime(input) @@ -154,7 +156,7 @@ func TestAwsV4Signature_GetRequestWithPointlessDotPath(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) { +func TestAWSv4Signature_GetRequestWithUtf8Path(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/%E1%88%B4", nil) setDefaultTime(input) @@ -172,7 +174,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Path(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) { +func TestAWSv4Signature_GetRequestWithDuplicateQuery(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=Zoo&foo=aha", nil) setDefaultTime(input) @@ -190,7 +192,7 @@ func TestAwsV4Signature_GetRequestWithDuplicateQuery(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) { +func TestAWSv4Signature_GetRequestWithMisorderedQuery(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?foo=b&foo=a", nil) setDefaultTime(input) @@ -208,7 +210,7 @@ func TestAwsV4Signature_GetRequestWithMisorderedQuery(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) { +func TestAWSv4Signature_GetRequestWithUtf8Query(t *testing.T) { input, _ := http.NewRequest("GET", "https://host.foo.com/?ሴ=bar", nil) setDefaultTime(input) @@ -226,7 +228,7 @@ func TestAwsV4Signature_GetRequestWithUtf8Query(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequest(t *testing.T) { +func TestAWSv4Signature_PostRequest(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("ZOO", "zoobar") @@ -246,7 +248,7 @@ func TestAwsV4Signature_PostRequest(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { +func TestAWSv4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("zoo", "ZOOBAR") @@ -266,7 +268,7 @@ func TestAwsV4Signature_PostRequestWithCapitalizedHeaderValue(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestPhfft(t *testing.T) { +func TestAWSv4Signature_PostRequestPhfft(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", nil) setDefaultTime(input) input.Header.Add("p", "phfft") @@ -286,7 +288,7 @@ func TestAwsV4Signature_PostRequestPhfft(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithBody(t *testing.T) { +func TestAWSv4Signature_PostRequestWithBody(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/", strings.NewReader("foo=bar")) setDefaultTime(input) input.Header.Add("Content-Type", "application/x-www-form-urlencoded") @@ -306,7 +308,7 @@ func TestAwsV4Signature_PostRequestWithBody(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) { +func TestAWSv4Signature_PostRequestWithQueryString(t *testing.T) { input, _ := http.NewRequest("POST", "https://host.foo.com/?foo=bar", nil) setDefaultTime(input) @@ -324,7 +326,7 @@ func TestAwsV4Signature_PostRequestWithQueryString(t *testing.T) { testRequestSigner(t, defaultRequestSigner, input, output) } -func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) { +func TestAWSv4Signature_GetRequestWithSecurityToken(t *testing.T) { input, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) output, _ := http.NewRequest("GET", "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", nil) @@ -342,7 +344,7 @@ func TestAwsV4Signature_GetRequestWithSecurityToken(t *testing.T) { testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) { +func TestAWSv4Signature_PostRequestWithSecurityToken(t *testing.T) { input, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) output, _ := http.NewRequest("POST", "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", nil) @@ -360,7 +362,7 @@ func TestAwsV4Signature_PostRequestWithSecurityToken(t *testing.T) { testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) { +func TestAWSv4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *testing.T) { requestParams := "{\"KeySchema\":[{\"KeyType\":\"HASH\",\"AttributeName\":\"Id\"}],\"TableName\":\"TestTable\",\"AttributeDefinitions\":[{\"AttributeName\":\"Id\",\"AttributeType\":\"S\"}],\"ProvisionedThroughput\":{\"WriteCapacityUnits\":5,\"ReadCapacityUnits\":5}}" input, _ := http.NewRequest("POST", "https://dynamodb.us-east-2.amazonaws.com/", strings.NewReader(requestParams)) input.Header.Add("Content-Type", "application/x-amz-json-1.0") @@ -383,7 +385,7 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test testRequestSigner(t, requestSignerWithToken, input, output) } -func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { +func TestAWSv4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) { var requestSigner = &awsRequestSigner{ RegionName: "us-east-2", AwsSecurityCredentials: awsSecurityCredentials{ @@ -413,30 +415,40 @@ type testAwsServer struct { securityCredentialURL string regionURL string regionalCredVerificationURL string + imdsv2SessionTokenUrl string Credentials map[string]string - WriteRolename func(http.ResponseWriter) - WriteSecurityCredentials func(http.ResponseWriter) - WriteRegion func(http.ResponseWriter) + WriteRolename func(http.ResponseWriter, *http.Request) + WriteSecurityCredentials func(http.ResponseWriter, *http.Request) + WriteRegion func(http.ResponseWriter, *http.Request) + WriteIMDSv2SessionToken func(http.ResponseWriter, *http.Request) } -func createAwsTestServer(url, regionURL, regionalCredVerificationURL, rolename, region string, credentials map[string]string) *testAwsServer { +func createAwsTestServer(url, regionURL, regionalCredVerificationURL, imdsv2SessionTokenUrl string, rolename, region string, credentials map[string]string, imdsv2SessionToken string, validateHeaders validateHeaders) *testAwsServer { server := &testAwsServer{ url: url, securityCredentialURL: fmt.Sprintf("%s/%s", url, rolename), regionURL: regionURL, regionalCredVerificationURL: regionalCredVerificationURL, + imdsv2SessionTokenUrl: imdsv2SessionTokenUrl, Credentials: credentials, - WriteRolename: func(w http.ResponseWriter) { + WriteRolename: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) w.Write([]byte(rolename)) }, - WriteRegion: func(w http.ResponseWriter) { + WriteRegion: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) w.Write([]byte(region)) }, + WriteIMDSv2SessionToken: func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) + w.Write([]byte(imdsv2SessionToken)) + }, } - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { + validateHeaders(r) jsonCredentials, _ := json.Marshal(server.Credentials) w.Write(jsonCredentials) } @@ -449,6 +461,7 @@ func createDefaultAwsTestServer() *testAwsServer { "/latest/meta-data/iam/security-credentials", "/latest/meta-data/placement/availability-zone", "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "", "gcp-aws-role", "us-east-2b", map[string]string{ @@ -456,31 +469,38 @@ func createDefaultAwsTestServer() *testAwsServer { "AccessKeyId": accessKeyID, "Token": securityToken, }, + "", + noHeaderValidation, ) } func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: - server.WriteRolename(w) + server.WriteRolename(w, r) case server.securityCredentialURL: - server.WriteSecurityCredentials(w) + server.WriteSecurityCredentials(w, r) case server.regionURL: - server.WriteRegion(w) + server.WriteRegion(w, r) + case server.imdsv2SessionTokenUrl: + server.WriteIMDSv2SessionToken(w, r) } } -func notFound(w http.ResponseWriter) { +func notFound(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) w.Write([]byte("Not Found")) } +func noHeaderValidation(r *http.Request) {} + func (server *testAwsServer) getCredentialSource(url string) CredentialSource { return CredentialSource{ EnvironmentID: "aws1", URL: url + server.url, RegionURL: url + server.regionURL, RegionalCredVerificationURL: server.regionalCredVerificationURL, + IMDSv2SessionTokenURL: url + server.imdsv2SessionTokenUrl, } } @@ -530,7 +550,7 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security return neturl.QueryEscape(string(str)) } -func TestAwsCredential_BasicRequest(t *testing.T) { +func TestAWSCredential_BasicRequest(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -567,7 +587,72 @@ func TestAwsCredential_BasicRequest(t *testing.T) { } } -func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { +func TestAWSCredential_IMDSv2(t *testing.T) { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + server := createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) + ts := httptest.NewServer(server) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + oldNow := now + defer func() { now = oldNow }() + now = setTime(defaultTime) + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) delete(server.Credentials, "Token") @@ -605,7 +690,7 @@ func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) { } } -func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { +func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -646,7 +731,7 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { } } -func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) { +func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -686,7 +771,7 @@ func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) { } } -func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) { +func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -727,7 +812,7 @@ func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) { } } -func TestAwsCredential_RequestWithBadVersion(t *testing.T) { +func TestAWSCredential_RequestWithBadVersion(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -748,7 +833,7 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) { } } -func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) { +func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -775,7 +860,7 @@ func TestAwsCredential_RequestWithNoRegionURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) { +func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteRegion = notFound @@ -802,10 +887,10 @@ func TestAwsCredential_RequestWithBadRegionURL(t *testing.T) { } } -func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { +func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) } @@ -831,10 +916,10 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) { } } -func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { +func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) - server.WriteSecurityCredentials = func(w http.ResponseWriter) { + server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) } @@ -860,7 +945,7 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) { } } -func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) @@ -887,7 +972,7 @@ func TestAwsCredential_RequestWithNoCredentialURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteRolename = notFound @@ -914,7 +999,7 @@ func TestAwsCredential_RequestWithBadCredentialURL(t *testing.T) { } } -func TestAwsCredential_RequestWithBadFinalCredentialURL(t *testing.T) { +func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) server.WriteSecurityCredentials = notFound diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index bc3ce53..83ce9c2 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -175,6 +175,7 @@ type CredentialSource struct { RegionURL string `json:"region_url"` RegionalCredVerificationURL string `json:"regional_cred_verification_url"` CredVerificationURL string `json:"cred_verification_url"` + IMDSv2SessionTokenURL string `json:"imdsv2_session_token_url"` Format format `json:"format"` } @@ -185,14 +186,20 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { if awsVersion != 1 { return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion) } - return awsCredentialSource{ + + awsCredSource := awsCredentialSource{ EnvironmentID: c.CredentialSource.EnvironmentID, RegionURL: c.CredentialSource.RegionURL, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, CredVerificationURL: c.CredentialSource.URL, TargetResource: c.Audience, ctx: ctx, - }, nil + } + if c.CredentialSource.IMDSv2SessionTokenURL != "" { + awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL + } + + return awsCredSource, nil } } else if c.CredentialSource.File != "" { return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil