forked from Mirrors/oauth2
changes requested by codyoss@
This commit is contained in:
parent
d6857d1e58
commit
ba9ae8ce1b
|
@ -5,30 +5,31 @@
|
||||||
package externalaccount
|
package externalaccount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
|
|
||||||
type awsSecurityCredentials struct {
|
type awsSecurityCredentials struct {
|
||||||
AccessKeyId string
|
AccessKeyID string `json:"AccessKeyID"`
|
||||||
SecretAccessKey string
|
SecretAccessKey string `json:"SecretAccessKey"`
|
||||||
SecurityToken string
|
SecurityToken string `json:"Token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
|
||||||
type awsRequestSigner struct {
|
type awsRequestSigner struct {
|
||||||
RegionName string
|
RegionName string
|
||||||
AwsSecurityCredentials awsSecurityCredentials
|
AwsSecurityCredentials awsSecurityCredentials
|
||||||
|
@ -229,7 +230,7 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
|
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type awsCredentialSource struct {
|
type awsCredentialSource struct {
|
||||||
|
@ -240,24 +241,29 @@ type awsCredentialSource struct {
|
||||||
TargetResource string
|
TargetResource string
|
||||||
requestSigner *awsRequestSigner
|
requestSigner *awsRequestSigner
|
||||||
region string
|
region string
|
||||||
|
ctx context.Context
|
||||||
|
client *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
type AwsRequestHeader struct {
|
type awsRequestHeader struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AwsRequest struct {
|
type awsRequest struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
Headers []AwsRequestHeader `json:"headers"`
|
Headers []awsRequestHeader `json:"headers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs awsCredentialSource) request(req *http.Request) (*http.Response, error) {
|
||||||
|
if cs.client == nil {
|
||||||
|
cs.client = oauth2.NewClient(cs.ctx, nil)
|
||||||
|
}
|
||||||
|
return cs.client.Do(req.WithContext(cs.ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs awsCredentialSource) subjectToken() (string, error) {
|
func (cs awsCredentialSource) subjectToken() (string, error) {
|
||||||
if version, _ := strconv.Atoi(cs.EnvironmentID[3:]); version != 1 {
|
|
||||||
return "", errors.New(fmt.Sprintf("oauth2/google: aws version '%d' is not supported in the current build.", version))
|
|
||||||
}
|
|
||||||
|
|
||||||
if cs.requestSigner == nil {
|
if cs.requestSigner == nil {
|
||||||
awsSecurityCredentials, err := cs.getSecurityCredentials()
|
awsSecurityCredentials, err := cs.getSecurityCredentials()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -304,13 +310,13 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
|
||||||
# }))
|
# }))
|
||||||
*/
|
*/
|
||||||
|
|
||||||
awsSignedReq := AwsRequest{
|
awsSignedReq := awsRequest{
|
||||||
URL: req.URL.String(),
|
URL: req.URL.String(),
|
||||||
Method: req.Method,
|
Method: "POST",
|
||||||
}
|
}
|
||||||
for headerKey, headerList := range req.Header {
|
for headerKey, headerList := range req.Header {
|
||||||
for _, headerValue := range headerList {
|
for _, headerValue := range headerList {
|
||||||
awsSignedReq.Headers = append(awsSignedReq.Headers, AwsRequestHeader{
|
awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
|
||||||
Key: headerKey,
|
Key: headerKey,
|
||||||
Value: headerValue,
|
Value: headerValue,
|
||||||
})
|
})
|
||||||
|
@ -337,10 +343,15 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cs.RegionURL == "" {
|
if cs.RegionURL == "" {
|
||||||
return "", errors.New("oauth2/google: Unable to determine AWS region.")
|
return "", errors.New("oauth2/google: unable to determine AWS region")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.Get(cs.RegionURL)
|
req, err := http.NewRequest("GET", cs.RegionURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cs.request(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -352,17 +363,23 @@ func (cs *awsCredentialSource) getRegion() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS region - %s.", string(respBody)))
|
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(respBody)[:len(respBody)-1], nil
|
// This endpoint will return the region in format: us-east-2b.
|
||||||
|
// Only the us-east-2 part should be used.
|
||||||
|
respBodyEnd := 0
|
||||||
|
if len(respBody) > 1 {
|
||||||
|
respBodyEnd = len(respBody) - 1
|
||||||
|
}
|
||||||
|
return string(respBody[:respBodyEnd]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials awsSecurityCredentials, err error) {
|
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
|
||||||
if accessKeyId := getenv("AWS_ACCESS_KEY_ID"); accessKeyId != "" {
|
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
|
||||||
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
|
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
|
||||||
return awsSecurityCredentials{
|
return awsSecurityCredentials{
|
||||||
AccessKeyId: accessKeyId,
|
AccessKeyID: accessKeyID,
|
||||||
SecretAccessKey: secretAccessKey,
|
SecretAccessKey: secretAccessKey,
|
||||||
SecurityToken: getenv("AWS_SESSION_TOKEN"),
|
SecurityToken: getenv("AWS_SESSION_TOKEN"),
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -371,70 +388,64 @@ func (cs *awsCredentialSource) getSecurityCredentials() (securityCredentials aws
|
||||||
|
|
||||||
roleName, err := cs.getMetadataRoleName()
|
roleName, err := cs.getMetadataRoleName()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return awsSecurityCredentials{}, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
credentials, err := cs.getMetadataSecurityCredentials(roleName)
|
credentials, err := cs.getMetadataSecurityCredentials(roleName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return awsSecurityCredentials{}, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessKeyId, ok := credentials["AccessKeyId"]
|
if credentials.AccessKeyID == "" {
|
||||||
if !ok {
|
return result, errors.New("oauth2/google: missing AccessKeyId credential")
|
||||||
return awsSecurityCredentials{}, errors.New("oauth2/google: missing AccessKeyId credential.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
secretAccessKey, ok := credentials["SecretAccessKey"]
|
if credentials.SecretAccessKey == "" {
|
||||||
if !ok {
|
return result, errors.New("oauth2/google: missing SecretAccessKey credential")
|
||||||
return awsSecurityCredentials{}, errors.New("oauth2/google: missing SecretAccessKey credential.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
securityToken, _ := credentials["Token"]
|
return credentials, nil
|
||||||
|
|
||||||
return awsSecurityCredentials{
|
|
||||||
AccessKeyId: accessKeyId,
|
|
||||||
SecretAccessKey: secretAccessKey,
|
|
||||||
SecurityToken: securityToken,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (map[string]string, error) {
|
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
|
||||||
|
var result awsSecurityCredentials
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
|
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return result, err
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := cs.request(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return result, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS security credentials - %s.", string(respBody)))
|
return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]string
|
|
||||||
err = json.Unmarshal(respBody, &result)
|
err = json.Unmarshal(respBody, &result)
|
||||||
if err != nil {
|
return result, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
|
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
|
||||||
if cs.CredVerificationURL == "" {
|
if cs.CredVerificationURL == "" {
|
||||||
return "", errors.New("oauth2/google: Unable to determine the AWS metadata server security credentials endpoint.")
|
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.Get(cs.CredVerificationURL)
|
req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cs.request(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -446,7 +457,7 @@ func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return "", errors.New(fmt.Sprintf("oauth2/google: Unable to retrieve AWS role name - %s.", string(respBody)))
|
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(respBody), nil
|
return string(respBody), nil
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package externalaccount
|
package externalaccount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -34,7 +35,7 @@ func setEnvironment(env map[string]string) func(string) string {
|
||||||
var defaultRequestSigner = &awsRequestSigner{
|
var defaultRequestSigner = &awsRequestSigner{
|
||||||
RegionName: "us-east-1",
|
RegionName: "us-east-1",
|
||||||
AwsSecurityCredentials: awsSecurityCredentials{
|
AwsSecurityCredentials: awsSecurityCredentials{
|
||||||
AccessKeyId: "AKIDEXAMPLE",
|
AccessKeyID: "AKIDEXAMPLE",
|
||||||
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
|
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -48,7 +49,7 @@ const (
|
||||||
var requestSignerWithToken = &awsRequestSigner{
|
var requestSignerWithToken = &awsRequestSigner{
|
||||||
RegionName: "us-east-2",
|
RegionName: "us-east-2",
|
||||||
AwsSecurityCredentials: awsSecurityCredentials{
|
AwsSecurityCredentials: awsSecurityCredentials{
|
||||||
AccessKeyId: accessKeyId,
|
AccessKeyID: accessKeyId,
|
||||||
SecretAccessKey: secretAccessKey,
|
SecretAccessKey: secretAccessKey,
|
||||||
SecurityToken: securityToken,
|
SecurityToken: securityToken,
|
||||||
},
|
},
|
||||||
|
@ -386,7 +387,7 @@ func TestAwsV4Signature_PostRequestWithAmzDateButNoSecurityToken(t *testing.T) {
|
||||||
var requestSigner = &awsRequestSigner{
|
var requestSigner = &awsRequestSigner{
|
||||||
RegionName: "us-east-2",
|
RegionName: "us-east-2",
|
||||||
AwsSecurityCredentials: awsSecurityCredentials{
|
AwsSecurityCredentials: awsSecurityCredentials{
|
||||||
AccessKeyId: accessKeyId,
|
AccessKeyID: accessKeyId,
|
||||||
SecretAccessKey: secretAccessKey,
|
SecretAccessKey: secretAccessKey,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -489,26 +490,24 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security
|
||||||
signer := &awsRequestSigner{
|
signer := &awsRequestSigner{
|
||||||
RegionName: region,
|
RegionName: region,
|
||||||
AwsSecurityCredentials: awsSecurityCredentials{
|
AwsSecurityCredentials: awsSecurityCredentials{
|
||||||
AccessKeyId: accessKeyId,
|
AccessKeyID: accessKeyId,
|
||||||
SecretAccessKey: secretAccessKey,
|
SecretAccessKey: secretAccessKey,
|
||||||
SecurityToken: securityToken,
|
SecurityToken: securityToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
signer.SignRequest(req)
|
signer.SignRequest(req)
|
||||||
|
|
||||||
result := AwsRequest{
|
result := awsRequest{
|
||||||
URL: url,
|
URL: url,
|
||||||
Method: "POST",
|
Method: "POST",
|
||||||
Headers: []AwsRequestHeader{
|
Headers: []awsRequestHeader{
|
||||||
AwsRequestHeader{
|
{
|
||||||
Key: "Authorization",
|
Key: "Authorization",
|
||||||
Value: req.Header.Get("Authorization"),
|
Value: req.Header.Get("Authorization"),
|
||||||
},
|
}, {
|
||||||
AwsRequestHeader{
|
|
||||||
Key: "Host",
|
Key: "Host",
|
||||||
Value: req.Header.Get("Host"),
|
Value: req.Header.Get("Host"),
|
||||||
},
|
}, {
|
||||||
AwsRequestHeader{
|
|
||||||
Key: "X-Amz-Date",
|
Key: "X-Amz-Date",
|
||||||
Value: req.Header.Get("X-Amz-Date"),
|
Value: req.Header.Get("X-Amz-Date"),
|
||||||
},
|
},
|
||||||
|
@ -516,13 +515,13 @@ func getExpectedSubjectToken(url, region, accessKeyId, secretAccessKey, security
|
||||||
}
|
}
|
||||||
|
|
||||||
if securityToken != "" {
|
if securityToken != "" {
|
||||||
result.Headers = append(result.Headers, AwsRequestHeader{
|
result.Headers = append(result.Headers, awsRequestHeader{
|
||||||
Key: "X-Amz-Security-Token",
|
Key: "X-Amz-Security-Token",
|
||||||
Value: securityToken,
|
Value: securityToken,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
result.Headers = append(result.Headers, AwsRequestHeader{
|
result.Headers = append(result.Headers, awsRequestHeader{
|
||||||
Key: "X-Goog-Cloud-Target-Resource",
|
Key: "X-Goog-Cloud-Target-Resource",
|
||||||
Value: testFileConfig.Audience,
|
Value: testFileConfig.Audience,
|
||||||
})
|
})
|
||||||
|
@ -542,7 +541,12 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
out, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -560,6 +564,41 @@ func TestAwsCredential_BasicRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAwsCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||||
|
server := createDefaultAwsTestServer()
|
||||||
|
ts := httptest.NewServer(server)
|
||||||
|
delete(server.Credentials, "Token")
|
||||||
|
|
||||||
|
tfc := testFileConfig
|
||||||
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
|
oldGetenv := getenv
|
||||||
|
defer func() { getenv = oldGetenv }()
|
||||||
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := getExpectedSubjectToken(
|
||||||
|
"https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
|
||||||
|
"us-east-2",
|
||||||
|
accessKeyId,
|
||||||
|
secretAccessKey,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
if got, want := out, expected; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
|
func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
|
@ -575,7 +614,12 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
|
|
||||||
out, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -605,12 +649,11 @@ func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
_, err := tfc.parse(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("parse() should have failed")
|
||||||
}
|
}
|
||||||
|
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build"; !reflect.DeepEqual(got, want) {
|
||||||
if got, want := err.Error(), "oauth2/google: aws version '3' is not supported in the current build."; !reflect.DeepEqual(got, want) {
|
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -627,12 +670,17 @@ func TestAwsCredential_RequestWithNoRegionUrl(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: Unable to determine AWS region."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: unable to determine AWS region"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -649,12 +697,17 @@ func TestAwsCredential_RequestWithBadRegionUrl(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS region - Not Found."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS region - Not Found"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -673,12 +726,17 @@ func TestAwsCredential_RequestWithMissingCredential(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: missing AccessKeyId credential"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -697,12 +755,17 @@ func TestAwsCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: missing SecretAccessKey credential"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -719,12 +782,17 @@ func TestAwsCredential_RequestWithNoCredentialUrl(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: Unable to determine the AWS metadata server security credentials endpoint."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: unable to determine the AWS metadata server security credentials endpoint"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -741,12 +809,17 @@ func TestAwsCredential_RequestWithBadCredentialUrl(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS role name - Not Found."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS role name - Not Found"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -763,12 +836,17 @@ func TestAwsCredential_RequestWithBadFinalCredentialUrl(t *testing.T) {
|
||||||
defer func() { getenv = oldGetenv }()
|
defer func() { getenv = oldGetenv }()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
|
|
||||||
_, err := tfc.parse(nil).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = base.subjectToken()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("retrieveSubjectToken() should have failed")
|
t.Fatalf("retrieveSubjectToken() should have failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, want := err.Error(), "oauth2/google: Unable to retrieve AWS security credentials - Not Found."; !reflect.DeepEqual(got, want) {
|
if got, want := err.Error(), "oauth2/google: unable to retrieve AWS security credentials - Not Found"; !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("subjectToken = %q, want %q", got, want)
|
t.Errorf("subjectToken = %q, want %q", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,24 +67,28 @@ type CredentialSource struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse determines the type of CredentialSource needed
|
// parse determines the type of CredentialSource needed
|
||||||
func (c *Config) parse(ctx context.Context) baseCredentialSource {
|
func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
|
||||||
if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
|
if len(c.CredentialSource.EnvironmentID) > 3 && c.CredentialSource.EnvironmentID[:3] == "aws" {
|
||||||
if _, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
|
if awsVersion, err := strconv.Atoi(c.CredentialSource.EnvironmentID[3:]); err == nil {
|
||||||
|
if awsVersion != 1 {
|
||||||
|
return nil, fmt.Errorf("oauth2/google: aws version '%d' is not supported in the current build", awsVersion)
|
||||||
|
}
|
||||||
return awsCredentialSource{
|
return awsCredentialSource{
|
||||||
EnvironmentID: c.CredentialSource.EnvironmentID,
|
EnvironmentID: c.CredentialSource.EnvironmentID,
|
||||||
RegionURL: c.CredentialSource.RegionURL,
|
RegionURL: c.CredentialSource.RegionURL,
|
||||||
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
|
RegionalCredVerificationURL: c.CredentialSource.RegionalCredVerificationURL,
|
||||||
CredVerificationURL: c.CredentialSource.URL,
|
CredVerificationURL: c.CredentialSource.URL,
|
||||||
TargetResource: c.Audience,
|
TargetResource: c.Audience,
|
||||||
}
|
ctx: ctx,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.CredentialSource.File != "" {
|
if c.CredentialSource.File != "" {
|
||||||
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}
|
return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil
|
||||||
} else if c.CredentialSource.URL != "" {
|
} else if c.CredentialSource.URL != "" {
|
||||||
return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}
|
return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}, nil
|
||||||
}
|
}
|
||||||
return nil
|
return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
|
||||||
}
|
}
|
||||||
|
|
||||||
type baseCredentialSource interface {
|
type baseCredentialSource interface {
|
||||||
|
@ -101,11 +105,12 @@ type tokenSource struct {
|
||||||
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
conf := ts.conf
|
conf := ts.conf
|
||||||
|
|
||||||
credSource := conf.parse(ts.ctx)
|
credSource, err := conf.parse(ts.ctx)
|
||||||
if credSource == nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: unable to parse credential source")
|
return nil, err
|
||||||
}
|
}
|
||||||
subjectToken, err := credSource.subjectToken()
|
subjectToken, err := credSource.subjectToken()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,12 @@ func TestRetrieveFileSubjectToken(t *testing.T) {
|
||||||
tfc.CredentialSource = test.cs
|
tfc.CredentialSource = test.cs
|
||||||
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
out, err := tfc.parse(context.Background()).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Method subjectToken() errored.")
|
t.Errorf("Method subjectToken() errored.")
|
||||||
} else if test.want != out {
|
} else if test.want != out {
|
||||||
|
|
|
@ -28,7 +28,12 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = cs
|
tfc.CredentialSource = cs
|
||||||
|
|
||||||
out, err := tfc.parse(context.Background()).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
t.Fatalf("retrieveSubjectToken() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -51,7 +56,12 @@ func TestRetrieveURLSubjectToken_Untyped(t *testing.T) {
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = cs
|
tfc.CredentialSource = cs
|
||||||
|
|
||||||
out, err := tfc.parse(context.Background()).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to retrieve URL subject token: %v", err)
|
t.Fatalf("Failed to retrieve URL subject token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -82,7 +92,12 @@ func TestRetrieveURLSubjectToken_JSON(t *testing.T) {
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = cs
|
tfc.CredentialSource = cs
|
||||||
|
|
||||||
out, err := tfc.parse(context.Background()).subjectToken()
|
base, err := tfc.parse(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse() failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := base.subjectToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue