diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index e917195..6318a23 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -267,6 +267,49 @@ type awsRequest struct { Headers []awsRequestHeader `json:"headers"` } +func (cs awsCredentialSource) validateMetadataServers() error { + if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil { + return err + } + if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil { + return err + } + return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url") +} + +var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"} + +func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool { + if metadataUrl == "" { + // Zero value means use default, which is valid. + return true + } + + u, err := url.Parse(metadataUrl) + if err != nil { + // Unparseable URL means invalid + return false + } + + for _, validHostname := range validHostnames { + if u.Hostname() == validHostname { + // If it's one of the valid hostnames, everything is good + return true + } + } + + // hostname not found in our allowlist, so not valid + return false +} + +func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error { + if !cs.isValidMetadataServer(metadataUrl) { + return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName) + } + + return nil +} + func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) { if cs.client == nil { cs.client = oauth2.NewClient(cs.ctx, nil) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 0934389..30a003a 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -553,16 +553,25 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security func TestAWSCredential_BasicRequest(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -618,16 +627,25 @@ func TestAWSCredential_IMDSv2(t *testing.T) { validateSessionTokenHeaders, ) ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -655,17 +673,26 @@ func TestAWSCredential_IMDSv2(t *testing.T) { func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } delete(server.Credentials, "Token") tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() - getenv = setEnvironment(map[string]string{}) oldNow := now - defer func() { now = oldNow }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{}) now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -693,20 +720,29 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) { func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -734,20 +770,29 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", - "AWS_DEFAULT_REGION": "us-west-1", + "AWS_REGION": "us-west-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -774,21 +819,30 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) { func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{ "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", "AWS_REGION": "us-west-1", "AWS_DEFAULT_REGION": "us-east-1", }) - oldNow := now - defer func() { now = oldNow }() now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -815,16 +869,25 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) { func TestAWSCredential_RequestWithBadVersion(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.EnvironmentID = "aws3" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} - _, err := tfc.parse(context.Background()) + _, err = tfc.parse(context.Background()) if err == nil { t.Fatalf("parse() should have failed") } @@ -836,14 +899,23 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) { func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.RegionURL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -863,14 +935,23 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) { func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRegion = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -890,6 +971,10 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) { func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("{}")) } @@ -898,8 +983,13 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -919,6 +1009,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) { func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`)) } @@ -927,8 +1021,13 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -948,14 +1047,23 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) { func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) tfc.CredentialSource.URL = "" oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -975,14 +1083,23 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteRolename = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1002,14 +1119,23 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) { func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { server := createDefaultAwsTestServer() ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } server.WriteSecurityCredentials = notFound tfc := testFileConfig tfc.CredentialSource = server.getCredentialSource(ts.URL) oldGetenv := getenv - defer func() { getenv = oldGetenv }() + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + validHostnames = oldValidHostnames + }() getenv = setEnvironment(map[string]string{}) + validHostnames = []string{tsURL.Hostname()} base, err := tfc.parse(context.Background()) if err != nil { @@ -1025,3 +1151,88 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { t.Errorf("subjectToken = %q, want %q", got, want) } } + +func TestAWSCredential_Validations(t *testing.T) { + var metadataServerValidityTests = []struct { + name string + credSource CredentialSource + errText string + }{ + { + name: "No Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "", + URL: "", + IMDSv2SessionTokenURL: "", + }, + }, { + name: "IPv4 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + }, { + name: "IPv6 Metadata Server URLs", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://[fd00:ec2::254]/latest/meta-data/placement/availability-zone", + URL: "http://[fd00:ec2::254]/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://[fd00:ec2::254]/latest/api/token", + }, + }, { + name: "Faulty RegionURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://abc.com/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/placement/availability-zone for region_url", + }, { + name: "Faulty CredVerificationURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://abc.com/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://169.254.169.254/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/meta-data/iam/security-credentials for url", + }, { + name: "Faulty IMDSv2SessionTokenURL", + credSource: CredentialSource{ + EnvironmentID: "aws1", + RegionURL: "http://169.254.169.254/latest/meta-data/placement/availability-zone", + URL: "http://169.254.169.254/latest/meta-data/iam/security-credentials", + IMDSv2SessionTokenURL: "http://abc.com/latest/api/token", + }, + errText: "oauth2/google: invalid hostname http://abc.com/latest/api/token for imdsv2_session_token_url", + }, + } + + for _, tt := range metadataServerValidityTests { + t.Run(tt.name, func(t *testing.T) { + tfc := testFileConfig + tfc.CredentialSource = tt.credSource + + oldGetenv := getenv + defer func() { getenv = oldGetenv }() + getenv = setEnvironment(map[string]string{}) + + _, err := tfc.parse(context.Background()) + if err != nil { + if tt.errText == "" { + t.Errorf("Didn't expect an error, but got %v", err) + } else if tt.errText != err.Error() { + t.Errorf("Expected %v, but got %v", tt.errText, err) + } + } else { + if tt.errText != "" { + t.Errorf("Expected error %v, but got none", tt.errText) + } + } + }) + } +} diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 9fc3553..3eab8df 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -213,6 +213,10 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL } + if err := awsCredSource.validateMetadataServers(); err != nil { + return nil, err + } + return awsCredSource, nil } } else if c.CredentialSource.File != "" {