google: check additional AWS variable

AWS_DEFAULT_REGION should have been checked as a backup to AWS_REGION but wasn't.  Also removed a redundant print statement in a test case.

Change-Id: Ia6e13eb20f509110a81e3071228283c43a1e9283
GitHub-Last-Rev: 1a10bcc079
GitHub-Pull-Request: golang/oauth2#486
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/302789
Reviewed-by: Cody Oss <codyoss@google.com>
Trust: Cody Oss <codyoss@google.com>
Trust: Tyler Bui-Palsulich <tbp@google.com>
This commit is contained in:
gIthuriel 2021-06-22 16:39:14 +00:00 committed by Cody Oss
parent d04028783c
commit 14747e66f6
3 changed files with 77 additions and 2 deletions

View File

@ -342,6 +342,8 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
func (cs *awsCredentialSource) getRegion() (string, error) { func (cs *awsCredentialSource) getRegion() (string, error) {
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
return envAwsRegion, nil return envAwsRegion, nil
} if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" {
return envAwsRegion, nil
} }
if cs.RegionURL == "" { if cs.RegionURL == "" {

View File

@ -638,6 +638,81 @@ func TestAwsCredential_BasicRequestWithEnv(t *testing.T) {
} }
} }
func TestAwsCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_DEFAULT_REGION": "us-west-1",
})
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer()
ts := httptest.NewServer(server)
tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1",
})
base, err := tfc.parse(context.Background())
if err != nil {
t.Fatalf("parse() failed %v", err)
}
out, err := base.subjectToken()
if err != nil {
t.Fatalf("retrieveSubjectToken() failed: %v", err)
}
expected := getExpectedSubjectToken(
"https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"us-west-1",
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"",
)
if got, want := out, expected; !reflect.DeepEqual(got, want) {
t.Errorf("subjectToken = %q, want %q", got, want)
}
}
func TestAwsCredential_RequestWithBadVersion(t *testing.T) { func TestAwsCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)

View File

@ -7,7 +7,6 @@ package externalaccount
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -20,7 +19,6 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) {
if r.Method != "GET" { if r.Method != "GET" {
t.Errorf("Unexpected request method, %v is found", r.Method) t.Errorf("Unexpected request method, %v is found", r.Method)
} }
fmt.Println(r.Header)
if r.Header.Get("Metadata") != "True" { if r.Header.Get("Metadata") != "True" {
t.Errorf("Metadata header not properly included.") t.Errorf("Metadata header not properly included.")
} }