google/internal/externalaccount: Adding metadata verification

Change-Id: I4d664862b7b287131c1481b238ebd0875f7c233b
GitHub-Last-Rev: 74bcc33f5e
GitHub-Pull-Request: golang/oauth2#608
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/449975
Run-TryBot: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Ryan Kohler 2022-11-17 21:52:54 +00:00 committed by Gopher Robot
parent 68a41d64f9
commit ec4a9b2ff2
3 changed files with 286 additions and 28 deletions

View File

@ -267,6 +267,49 @@ type awsRequest struct {
Headers []awsRequestHeader `json:"headers"` Headers []awsRequestHeader `json:"headers"`
} }
func (cs awsCredentialSource) validateMetadataServers() error {
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
return err
}
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
return err
}
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
}
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
if metadataUrl == "" {
// Zero value means use default, which is valid.
return true
}
u, err := url.Parse(metadataUrl)
if err != nil {
// Unparseable URL means invalid
return false
}
for _, validHostname := range validHostnames {
if u.Hostname() == validHostname {
// If it's one of the valid hostnames, everything is good
return true
}
}
// hostname not found in our allowlist, so not valid
return false
}
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
if !cs.isValidMetadataServer(metadataUrl) {
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
}
return nil
}
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
if cs.client == nil { if cs.client == nil {
cs.client = oauth2.NewClient(cs.ctx, nil) cs.client = oauth2.NewClient(cs.ctx, nil)

View File

@ -553,16 +553,25 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
func TestAWSCredential_BasicRequest(t *testing.T) { func TestAWSCredential_BasicRequest(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
oldNow := now oldNow := now
defer func() { now = oldNow }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -618,16 +627,25 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
validateSessionTokenHeaders, validateSessionTokenHeaders,
) )
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
oldNow := now oldNow := now
defer func() { now = oldNow }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -655,17 +673,26 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
delete(server.Credentials, "Token") delete(server.Credentials, "Token")
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
oldNow := now oldNow := now
defer func() { now = oldNow }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{})
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -693,20 +720,29 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -734,20 +770,29 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_DEFAULT_REGION": "us-west-1", "AWS_REGION": "us-west-1",
}) })
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -774,21 +819,30 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldNow := now
oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
now = oldNow
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{ getenv = setEnvironment(map[string]string{
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
"AWS_REGION": "us-west-1", "AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1", "AWS_DEFAULT_REGION": "us-east-1",
}) })
oldNow := now
defer func() { now = oldNow }()
now = setTime(defaultTime) now = setTime(defaultTime)
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -815,16 +869,25 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.EnvironmentID = "aws3" tfc.CredentialSource.EnvironmentID = "aws3"
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
_, err := tfc.parse(context.Background()) _, err = tfc.parse(context.Background())
if err == nil { if err == nil {
t.Fatalf("parse() should have failed") t.Fatalf("parse() should have failed")
} }
@ -836,14 +899,23 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.RegionURL = "" tfc.CredentialSource.RegionURL = ""
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -863,14 +935,23 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRegion = notFound server.WriteRegion = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -890,6 +971,10 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}")) w.Write([]byte("{}"))
} }
@ -898,8 +983,13 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -919,6 +1009,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
} }
@ -927,8 +1021,13 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -948,14 +1047,23 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
tfc.CredentialSource.URL = "" tfc.CredentialSource.URL = ""
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -975,14 +1083,23 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteRolename = notFound server.WriteRolename = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -1002,14 +1119,23 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
server := createDefaultAwsTestServer() server := createDefaultAwsTestServer()
ts := httptest.NewServer(server) ts := httptest.NewServer(server)
tsURL, err := neturl.Parse(ts.URL)
if err != nil {
t.Fatalf("couldn't parse httptest servername")
}
server.WriteSecurityCredentials = notFound server.WriteSecurityCredentials = notFound
tfc := testFileConfig tfc := testFileConfig
tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource = server.getCredentialSource(ts.URL)
oldGetenv := getenv oldGetenv := getenv
defer func() { getenv = oldGetenv }() oldValidHostnames := validHostnames
defer func() {
getenv = oldGetenv
validHostnames = oldValidHostnames
}()
getenv = setEnvironment(map[string]string{}) getenv = setEnvironment(map[string]string{})
validHostnames = []string{tsURL.Hostname()}
base, err := tfc.parse(context.Background()) base, err := tfc.parse(context.Background())
if err != nil { if err != nil {
@ -1025,3 +1151,88 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
t.Errorf("subjectToken = %q, want %q", got, want) t.Errorf("subjectToken = %q, want %q", got, want)
} }
} }
func TestAWSCredential_Validations(t *testing.T) {
var metadataServerValidityTests = []struct {
name string
credSource CredentialSource
errText string
}{
{
name: "No Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "",
URL: "",
IMDSv2SessionTokenURL: "",
},
}, {
name: "IPv4 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
}, {
name: "IPv6 Metadata Server URLs",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone",
URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token",
},
}, {
name: "Faulty RegionURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url",
}, {
name: "Faulty CredVerificationURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://abc.com/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url",
}, {
name: "Faulty IMDSv2SessionTokenURL",
credSource: CredentialSource{
EnvironmentID: "aws1",
RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone",
URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials",
IMDSv2SessionTokenURL: "http://abc.com/latest/api/token",
},
errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url",
},
}
for _, tt := range metadataServerValidityTests {
t.Run(tt.name, func(t *testing.T) {
tfc := testFileConfig
tfc.CredentialSource = tt.credSource
oldGetenv := getenv
defer func() { getenv = oldGetenv }()
getenv = setEnvironment(map[string]string{})
_, err := tfc.parse(context.Background())
if err != nil {
if tt.errText == "" {
t.Errorf("Didn't expect an error, but got %v", err)
} else if tt.errText != err.Error() {
t.Errorf("Expected %v, but got %v", tt.errText, err)
}
} else {
if tt.errText != "" {
t.Errorf("Expected error %v, but got none", tt.errText)
}
}
})
}
}

View File

@ -213,6 +213,10 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
} }
if err := awsCredSource.validateMetadataServers(); err != nil {
return nil, err
}
return awsCredSource, nil return awsCredSource, nil
} }
} else if c.CredentialSource.File != "" { } else if c.CredentialSource.File != "" {