google: support AWS 3rd party credentials

This commit is contained in:
Ryan Kohler 2021-01-26 08:45:25 -08:00
parent af13f521f1
commit d6857d1e58
3 changed files with 651 additions and 38 deletions

View File

@ -8,38 +8,50 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"sort"
"strconv"
"strings"
"time"
)
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials map[string]string
type awsSecurityCredentials struct {
AccessKeyId string
SecretAccessKey string
SecurityToken string
}
type awsRequestSigner struct {
RegionName string
AwsSecurityCredentials awsSecurityCredentials
}
// getenv aliases os.Getenv for testing
var getenv = os.Getenv
const (
// AWS Signature Version 4 signing algorithm identifier.
// AWS Signature Version 4 signing algorithm identifier.
awsAlgorithm = "AWS4-HMAC-SHA256"
// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
// The termination string for the AWS credential scope value as defined in
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
awsRequestType = "aws4_request"
// The AWS authorization header name for the security session token if available.
// The AWS authorization header name for the security session token if available.
awsSecurityTokenHeader = "x-amz-security-token"
// The AWS authorization header name for the auto-generated date.
// The AWS authorization header name for the auto-generated date.
awsDateHeader = "x-amz-date"
awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatLong = "20060102T150405Z"
awsTimeFormatShort = "20060102"
)
@ -167,8 +179,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
signedRequest.Header.Add("host", requestHost(req))
if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
if rs.AwsSecurityCredentials.SecurityToken != "" {
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
}
if signedRequest.Header.Get("date") == "" {
@ -186,15 +198,6 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
}
func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
if !ok {
return "", errors.New("oauth2/google: missing secret_access_key header")
}
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
if !ok {
return "", errors.New("oauth2/google: missing access_key_id header")
}
canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
dateStamp := timestamp.Format(awsTimeFormatShort)
@ -203,28 +206,248 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
serviceName = splitHost[0]
}
credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
if err != nil {
return "", err
}
requestHash, err := getSha256([]byte(requestString))
if err != nil{
if err != nil {
return "", err
}
stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
signingKey := []byte("AWS4" + secretAccessKey)
signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
for _, signingInput := range []string{
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
} {
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
if err != nil{
if err != nil {
return "", err
}
}
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
}
type awsCredentialSource struct {
EnvironmentID string
RegionURL string
RegionalCredVerificationURL string
CredVerificationURL string
TargetResource string
requestSigner *awsRequestSigner
region string
}
type AwsRequestHeader struct {
Key string `json:"key"`
Value string `json:"value"`
}
type AwsRequest struct {
URL string `json:"url"`
Method string `json:"method"`
Headers []AwsRequestHeader `json:"headers"`
}
func (cs awsCredentialSource) subjectToken() (string, error) {
if version, _ := strconv.Atoi(cs.EnvironmentID[3:]); version != 1 {
return "", errors.New(fmt.Sprintf("oauth2/google: aws version '%d' is not supported in the current build.", version))
}
if cs.requestSigner == nil {
awsSecurityCredentials, err := cs.getSecurityCredentials()
if err != nil {
return "", err
}
if cs.region, err = cs.getRegion(); err != nil {
return "", err
}
cs.requestSigner = &awsRequestSigner{
RegionName: cs.region,
AwsSecurityCredentials: awsSecurityCredentials,
}
}
// Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
if err != nil {
return "", err
}
// The full, canonical resource name of the workload identity pool
// provider, with or without the HTTPS prefix.
// Including this header as part of the signature is recommended to
// ensure data integrity.
if cs.TargetResource != "" {
req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
}
cs.requestSigner.SignRequest(req)
/*
The GCP STS endpoint expects the headers to be formatted as:
# [
# {key: 'x-amz-date', value: '...'},
# {key: 'Authorization', value: '...'},
# ...
# ]
# And then serialized as:
# quote(json.dumps({
# url: '...',
# method: 'POST',
# headers: [{key: 'x-amz-date', value: '...'}, ...]
# }))
*/
awsSignedReq := AwsRequest{
URL: req.URL.String(),
Method: req.Method,
}
for headerKey, headerList := range req.Header {
for _, headerValue := range headerList {
awsSignedReq.Headers = append(awsSignedReq.Headers, AwsRequestHeader{
Key: headerKey,
Value: headerValue,
})
}
}
sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
if headerCompare == 0 {
return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
}
return headerCompare < 0
})
result, err := json.Marshal(awsSignedReq)
if err != nil {
return "", err
}
return string(result), nil
}
func (cs *awsCredentialSource) getRegion() (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
}
if cs.RegionURL == "" {
return "", errors.New("oauth2/google: Unable to determine AWS region.")
}
resp, err := http.Get(cs.RegionURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS region - %s.", string(respBody)))
}
return string(respBody)[:len(respBody)-1], nil
}
func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials awsSecurityCredentials, err error) {
if accessKeyId := getenv("AWS_ACCESS_KEY_ID"); accessKeyId != "" {
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
return awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: getenv("AWS_SESSION_TOKEN"),
}, nil
}
}
roleName, err := cs.getMetadataRoleName()
if err != nil {
return awsSecurityCredentials{}, err
}
credentials, err := cs.getMetadataSecurityCredentials(roleName)
if err != nil {
return awsSecurityCredentials{}, err
}
accessKeyId, ok := credentials["AccessKeyId"]
if !ok {
return awsSecurityCredentials{}, errors.New("oauth2/google: missing AccessKeyId credential.")
}
secretAccessKey, ok := credentials["SecretAccessKey"]
if !ok {
return awsSecurityCredentials{}, errors.New("oauth2/google: missing SecretAccessKey credential.")
}
securityToken, _ := credentials["Token"]
return awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
}, nil
}
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (map[string]string, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
return nil, errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS security credentials - %s.", string(respBody)))
}
var result map[string]string
err = json.Unmarshal(respBody, &result)
if err != nil {
return nil, err
}
return result, nil
}
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
if cs.CredVerificationURL == "" {
return "", errors.New("oauth2/google: Unable to determine the AWS metadata server security credentials endpoint.")
}
resp, err := http.Get(cs.CredVerificationURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS role name - %s.", string(respBody)))
}
return string(respBody), nil
}

View File

@ -5,7 +5,10 @@
package externalaccount
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
@ -21,24 +24,33 @@ func setTime(testTime time.Time) func() time.Time {
}
}
func setEnvironment(env map[string]string) func(string) string {
return func(key string) string {
value, _ := env[key]
return value
}
}
var defaultRequestSigner = &awsRequestSigner{
RegionName: "us-east-1",
AwsSecurityCredentials: map[string]string{
"access_key_id": "AKIDEXAMPLE",
"secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: "AKIDEXAMPLE",
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
},
}
const accessKeyId = "ASIARD4OQDT6A77FR3CL"
const secretAccessKey = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"
const securityToken = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="
const (
accessKeyId = "ASIARD4OQDT6A77FR3CL"
secretAccessKey = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"
securityToken = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="
)
var requestSignerWithToken = &awsRequestSigner{
RegionName: "us-east-2",
AwsSecurityCredentials: map[string]string{
"access_key_id": accessKeyId,
"secret_access_key": secretAccessKey,
"security_token": securityToken,
AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
},
}
@ -373,9 +385,9 @@ func TestAwsV4Signature_PostRequestWithSecurityTokenAndAdditionalHeaders(t *test
func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
var requestSigner = &awsRequestSigner{
RegionName: "us-east-2",
AwsSecurityCredentials: map[string]string{
"access_key_id": accessKeyId,
"secret_access_key": secretAccessKey,
AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
},
}
@ -394,3 +406,369 @@ func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
testRequestSigner(t, requestSigner, input, output)
}
type testAwsServer struct {
url string
securityCredentialUrl string
regionUrl string
regionalCredVerificationUrl string
Credentials map[string]string
WriteRolename func(http.ResponseWriter)
WriteSecurityCredentials func(http.ResponseWriter)
WriteRegion func(http.ResponseWriter)
}
func createAwsTestServer(url, regionUrl, regionalCredVerificationUrl, rolename, region string, credentials map[string]string) *testAwsServer {
server := &testAwsServer{
url: url,
securityCredentialUrl: fmt.Sprintf("%s/%s", url, rolename),
regionUrl: regionUrl,
regionalCredVerificationUrl: regionalCredVerificationUrl,
Credentials: credentials,
WriteRolename: func(w http.ResponseWriter) {
w.Write([]byte(rolename))
},
WriteRegion: func(w http.ResponseWriter) {
w.Write([]byte(region))
},
}
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
jsonCredentials, _ := json.Marshal(server.Credentials)
w.Write(jsonCredentials)
}
return server
}
func createDefaultAwsTestServer() *testAwsServer {
return createAwsTestServer(
"/latest/meta-data/iam/security-credentials",
"/latest/meta-data/placement/availability-zone",
"https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"gcp-aws-role",
"us-east-2b",
map[string]string{
"SecretAccessKey": secretAccessKey,
"AccessKeyId": accessKeyId,
"Token": securityToken,
},
)
}
func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch p := r.URL.Path; p {
case server.url:
server.WriteRolename(w)
case server.securityCredentialUrl:
server.WriteSecurityCredentials(w)
case server.regionUrl:
server.WriteRegion(w)
}
}
func notFound(w http.ResponseWriter) {
w.WriteHeader(404)
w.Write([]byte("Not Found"))
}
func (server *testAwsServer) getCredentialSource(url string) CredentialSource {
return CredentialSource{
EnvironmentID: "aws1",
URL: url + server.url,
RegionURL: url + server.regionUrl,
RegionalCredVerificationURL: server.regionalCredVerificationUrl,
}
}
func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, securityToken string) string {
req, _ := http.NewRequest("POST", url, nil)
req.Header.Add("x-goog-cloud-target-resource", testFileConfig.Audience)
signer := &awsRequestSigner{
RegionName: region,
AwsSecurityCredentials: awsSecurityCredentials{
AccessKeyId: accessKeyId,
SecretAccessKey: secretAccessKey,
SecurityToken: securityToken,
},
}
signer.SignRequest(req)
result := AwsRequest{
URL: url,
Method: "POST",
Headers: []AwsRequestHeader{
AwsRequestHeader{
Key: "Authorization",
Value: req.Header.Get("Authorization"),
},
AwsRequestHeader{
Key: "Host",
Value: req.Header.Get("Host"),
},
AwsRequestHeader{
Key: "X-Amz-Date",
Value: req.Header.Get("X-Amz-Date"),
},
},
}
if securityToken != "" {
result.Headers = append(result.Headers, AwsRequestHeader{
Key: "X-Amz-Security-Token",
Value: securityToken,
})
}
result.Headers = append(result.Headers, AwsRequestHeader{
Key: "X-Goog-Cloud-Target-Resource",
Value: testFileConfig.Audience,
})
str, _ := json.Marshal(result)
return string(str)
}
func TestAwsCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
out, err := tfc.parse(nil).subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-east-2",
accessKeyId,
secretAccessKey,
securityToken,
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
})
out, err := tfc.parse(nil).subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithNoRegionUrl(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: Unable to determine AWS region."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithBadRegionUrl(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteRegion = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS region - Not Found."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
w.Write([]byte("{}"))
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = func(w http.ResponseWriter) {
w.Write([]byte("{\"AccessKeyId\":\"FOOBARBAS\"}"))
}
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithNoCredentialUrl(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = ""
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: Unable to determine the AWS metadata server security credentials endpoint."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithBadCredentialUrl(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteRolename = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS role name - Not Found."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithBadFinalCredentialUrl(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
server.WriteSecurityCredentials = notFound
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(nil).subjectToken()
if err == nil {
t.Fatalf("retrieveSubjectToken() should have failed")
}
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS security credentials - Not Found."; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"golang.org/x/oauth2"
"net/http"
"strconv"
"time"
)
@ -67,6 +68,17 @@ type CredentialSource struct {
// parse determines the type of CredentialSource needed
func (c *Config) parse(ctx context.Context) baseCredentialSource {
if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
if _, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
return awsCredentialSource{
EnvironmentID: c.CredentialSource.EnvironmentID,
RegionURL: c.CredentialSource.RegionURL,
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
CredVerificationURL: c.CredentialSource.URL,
TargetResource: c.Audience,
}
}
}
if c.CredentialSource.File != "" {
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}
} else if c.CredentialSource.URL != "" {