changes requested by codyoss@

This commit is contained in:
Ryan Kohler 2021-01-29 15:34:54 -08:00
parent d6857d1e58
commit ba9ae8ce1b
5 changed files with 213 additions and 99 deletions

View File

@ -5,30 +5,31 @@
package externalaccount package externalaccount
import ( import (
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/oauth2"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"path" "path"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
) )
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsSecurityCredentials struct { type awsSecurityCredentials struct {
AccessKeyId string AccessKeyID string `json:"AccessKeyID"`
SecretAccessKey string SecretAccessKey string `json:"SecretAccessKey"`
SecurityToken string SecurityToken string `json:"Token"`
} }
// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct { type awsRequestSigner struct {
RegionName string RegionName string
AwsSecurityCredentials awsSecurityCredentials AwsSecurityCredentials awsSecurityCredentials
@ -229,7 +230,7 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
} }
} }
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
} }
type awsCredentialSource struct { type awsCredentialSource struct {
@ -240,24 +241,29 @@ type awsCredentialSource struct {
TargetResource string TargetResource string
requestSigner *awsRequestSigner requestSigner *awsRequestSigner
region string region string
ctx context.Context
client *http.Client
} }
type AwsRequestHeader struct { type awsRequestHeader struct {
Key string `json:"key"` Key string `json:"key"`
Value string `json:"value"` Value string `json:"value"`
} }
type AwsRequest struct { type awsRequest struct {
URL string `json:"url"` URL string `json:"url"`
Method string `json:"method"` Method string `json:"method"`
Headers []AwsRequestHeader `json:"headers"` Headers []awsRequestHeader `json:"headers"`
}
func (cs awsCredentialSource) request(req *http.Request) (*http.Response, error) {
if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil)
}
return cs.client.Do(req.WithContext(cs.ctx))
} }
func (cs awsCredentialSource) subjectToken() (string, error) { func (cs awsCredentialSource) subjectToken() (string, error) {
if version, _ := strconv.Atoi(cs.EnvironmentID[3:]); version != 1 {
return "", errors.New(fmt.Sprintf("oauth2/google: aws version '%d' is not supported in the current build.", version))
}
if cs.requestSigner == nil { if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials() awsSecurityCredentials, err := cs.getSecurityCredentials()
if err != nil { if err != nil {
@ -304,13 +310,13 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
# })) # }))
*/ */
awsSignedReq := AwsRequest{ awsSignedReq := awsRequest{
URL: req.URL.String(), URL: req.URL.String(),
Method: req.Method, Method: "POST",
} }
for headerKey, headerList := range req.Header { for headerKey, headerList := range req.Header {
for _, headerValue := range headerList { for _, headerValue := range headerList {
awsSignedReq.Headers = append(awsSignedReq.Headers, AwsRequestHeader{ awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
Key: headerKey, Key: headerKey,
Value: headerValue, Value: headerValue,
}) })
@ -337,10 +343,15 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
} }
if cs.RegionURL == "" { if cs.RegionURL == "" {
return "", errors.New("oauth2/google: Unable to determine AWS region.") return "", errors.New("oauth2/google: unable to determine AWS region")
} }
resp, err := http.Get(cs.RegionURL) req, err := http.NewRequest("GET", cs.RegionURL, nil)
if err != nil {
return "", err
}
resp, err := cs.request(req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -352,17 +363,23 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS region - %s.", string(respBody))) return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
} }
return string(respBody)[:len(respBody)-1], nil // This endpoint will return the region in format: us-east-2b.
// Only the us-east-2 part should be used.
respBodyEnd := 0
if len(respBody) > 1 {
respBodyEnd = len(respBody) - 1
}
return string(respBody[:respBodyEnd]), nil
} }
func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials awsSecurityCredentials, err error) { func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
if accessKeyId := getenv("AWS_ACCESS_KEY_ID"); accessKeyId != "" { if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{ return awsSecurityCredentials{
AccessKeyId: accessKeyId, AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
SecurityToken: getenv("AWS_SESSION_TOKEN"), SecurityToken: getenv("AWS_SESSION_TOKEN"),
}, nil }, nil
@ -371,70 +388,64 @@ func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials aws
roleName, err := cs.getMetadataRoleName() roleName, err := cs.getMetadataRoleName()
if err != nil { if err != nil {
return awsSecurityCredentials{}, err return
} }
credentials, err := cs.getMetadataSecurityCredentials(roleName) credentials, err := cs.getMetadataSecurityCredentials(roleName)
if err != nil { if err != nil {
return awsSecurityCredentials{}, err return
} }
accessKeyId, ok := credentials["AccessKeyId"] if credentials.AccessKeyID == "" {
if !ok { return result, errors.New("oauth2/google: missing AccessKeyId credential")
return awsSecurityCredentials{}, errors.New("oauth2/google: missing AccessKeyId credential.")
} }
secretAccessKey, ok := credentials["SecretAccessKey"] if credentials.SecretAccessKey == "" {
if !ok { return result, errors.New("oauth2/google: missing SecretAccessKey credential")
return awsSecurityCredentials{}, errors.New("oauth2/google: missing SecretAccessKey credential.")
} }
securityToken, _ := credentials["Token"] return credentials, nil
return awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
}, nil
} }
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (map[string]string, error) { func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
if err != nil { if err != nil {
return nil, err return result, err
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req) resp, err := cs.request(req)
if err != nil { if err != nil {
return nil, err return result, err
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, err return result, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return nil, errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS security credentials - %s.", string(respBody))) return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
} }
var result map[string]string
err = json.Unmarshal(respBody, &result) err = json.Unmarshal(respBody, &result)
if err != nil { return result, err
return nil, err
}
return result, nil
} }
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) { func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
if cs.CredVerificationURL == "" { if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: Unable to determine the AWS metadata server security credentials endpoint.") return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
} }
resp, err := http.Get(cs.CredVerificationURL) req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
if err != nil {
return "", err
}
resp, err := cs.request(req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -446,7 +457,7 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS role name - %s.", string(respBody))) return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
} }
return string(respBody), nil return string(respBody), nil

View File

@ -5,6 +5,7 @@
package externalaccount package externalaccount
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -34,7 +35,7 @@ func setEnvironment(env map[string]string) func(string) string {
var defaultRequestSigner = &awsRequestSigner{ var defaultRequestSigner = &awsRequestSigner{
RegionName: "us-east-1", RegionName: "us-east-1",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: "AKIDEXAMPLE", AccessKeyID: "AKIDEXAMPLE",
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
}, },
} }
@ -48,7 +49,7 @@ const (
var requestSignerWithToken = &awsRequestSigner{ var requestSignerWithToken = &awsRequestSigner{
RegionName: "us-east-2", RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId, AccessKeyID: accessKeyId,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
SecurityToken: securityToken, SecurityToken: securityToken,
}, },
@ -386,7 +387,7 @@ func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{ var requestSigner = &awsRequestSigner{
RegionName: "us-east-2", RegionName: "us-east-2",
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId, AccessKeyID: accessKeyId,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
}, },
} }
@ -489,26 +490,24 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security
signer := &awsRequestSigner{ signer := &awsRequestSigner{
RegionName: region, RegionName: region,
AwsSecurityCredentials: awsSecurityCredentials{ AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId, AccessKeyID: accessKeyId,
SecretAccessKey: secretAccessKey, SecretAccessKey: secretAccessKey,
SecurityToken: securityToken, SecurityToken: securityToken,
}, },
} }
signer.SignRequest(req) signer.SignRequest(req)
result := AwsRequest{ result := awsRequest{
URL: url, URL: url,
Method: "POST", Method: "POST",
Headers: []AwsRequestHeader{ Headers: []awsRequestHeader{
AwsRequestHeader{ {
Key: "Authorization", Key: "Authorization",
Value: req.Header.Get("Authorization"), Value: req.Header.Get("Authorization"),
}, }, {
AwsRequestHeader{
Key: "Host", Key: "Host",
Value: req.Header.Get("Host"), Value: req.Header.Get("Host"),
}, }, {
AwsRequestHeader{
Key: "X-Amz-Date", Key: "X-Amz-Date",
Value: req.Header.Get("X-Amz-Date"), Value: req.Header.Get("X-Amz-Date"),
}, },
@ -516,13 +515,13 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security
} }
if securityToken != "" { if securityToken != "" {
result.Headers = append(result.Headers, AwsRequestHeader{ result.Headers = append(result.Headers, awsRequestHeader{
Key: "X-Amz-Security-Token", Key: "X-Amz-Security-Token",
Value: securityToken, Value: securityToken,
}) })
} }
result.Headers = append(result.Headers, AwsRequestHeader{ result.Headers = append(result.Headers, awsRequestHeader{
Key: "X-Goog-Cloud-Target-Resource", Key: "X-Goog-Cloud-Target-Resource",
Value: testFileConfig.Audience, Value: testFileConfig.Audience,
}) })
@ -542,7 +541,12 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
out, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err) t.Fatalf("retrieveSubjectToken() failed: %v", err)
} }
@ -560,6 +564,41 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
} }
} }
func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
delete(server.Credentials, "Token")
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyId,
secretAccessKey,
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) { func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
@ -575,7 +614,12 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
out, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err) t.Fatalf("retrieveSubjectToken() failed: %v", err)
} }
@ -605,12 +649,11 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() _, err := tfc.parse(context.Background())
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("parse() should have failed")
} }
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) {
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -627,12 +670,17 @@ func TestAwsCredential_RequestWithNoRegionUrl(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: Unable to determine AWS region."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: unable to determine AWS region"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -649,12 +697,17 @@ func TestAwsCredential_RequestWithBadRegionUrl(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS region - Not Found."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -673,12 +726,17 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -697,12 +755,17 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -719,12 +782,17 @@ func TestAwsCredential_RequestWithNoCredentialUrl(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: Unable to determine the AWS metadata server security credentials endpoint."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -741,12 +809,17 @@ func TestAwsCredential_RequestWithBadCredentialUrl(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS role name - Not Found."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
@ -763,12 +836,17 @@ func TestAwsCredential_RequestWithBadFinalCredentialUrl(t *testing.T) {
defer func() { getenv = oldGetenv }() defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
_, err = base.subjectToken()
if err == nil { if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed") t.Fatalf("retrieveSubjectToken() should have failed")
} }
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS security credentials - Not Found."; !reflect.DeepEqual(got, want) { if got, want := err.Error(), "oauth2/google: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }

View File

@ -67,24 +67,28 @@ type CredentialSource struct {
} }
// parse determines the type of CredentialSource needed // parse determines the type of CredentialSource needed
func (c *Config) parse(ctx context.Context) baseCredentialSource { func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" { if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
if _, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil { if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
if awsVersion != 1 {
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
}
return awsCredentialSource{ return awsCredentialSource{
EnvironmentID: c.CredentialSource.EnvironmentID, EnvironmentID: c.CredentialSource.EnvironmentID,
RegionURL: c.CredentialSource.RegionURL, RegionURL: c.CredentialSource.RegionURL,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL, RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL, CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience, TargetResource: c.Audience,
} ctx: ctx,
}, nil
} }
} }
if c.CredentialSource.File != "" { if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format} return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
} else if c.CredentialSource.URL != "" { } else if c.CredentialSource.URL != "" {
return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx} return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}, nil
} }
return nil return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
} }
type baseCredentialSource interface { type baseCredentialSource interface {
@ -101,11 +105,12 @@ type tokenSource struct {
func (ts tokenSource) Token() (*oauth2.Token, error) { func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf conf := ts.conf
credSource := conf.parse(ts.ctx) credSource, err := conf.parse(ts.ctx)
if credSource == nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: unable to parse credential source") return nil, err
} }
subjectToken, err := credSource.subjectToken() subjectToken, err := credSource.subjectToken()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -56,7 +56,12 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
tfc.CredentialSource = test.cs tfc.CredentialSource = test.cs
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
out, err := tfc.parse(context.Background()).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Errorf("Method subjectToken() errored.") t.Errorf("Method subjectToken() errored.")
} else if test.want != out { } else if test.want != out {

View File

@ -28,7 +28,12 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = cs
out, err := tfc.parse(context.Background()).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err) t.Fatalf("retrieveSubjectToken() failed: %v", err)
} }
@ -51,7 +56,12 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = cs
out, err := tfc.parse(context.Background()).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Fatalf("Failed to retrieve URL subject token: %v", err) t.Fatalf("Failed to retrieve URL subject token: %v", err)
} }
@ -82,7 +92,12 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = cs tfc.CredentialSource = cs
out, err := tfc.parse(context.Background()).subjectToken() base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }