diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index ec0a1a9..13353d5 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -61,26 +61,30 @@ var ( validTokenURLPatterns = []*regexp.Regexp{ // The complicated part in the middle matches any number of characters that // aren't period, spaces, or slashes. - regexp.MustCompile("^https://[^\\.\\s\\/\\\\]+\\.sts\\.googleapis\\.com$"), - regexp.MustCompile("^https://sts\\.googleapis\\.com$"), - regexp.MustCompile("^https://sts\\.[^\\.\\s\\/\\\\]+\\.googleapis\\.com$"), - regexp.MustCompile("^https://[^\\.\\s\\/\\\\]+-sts\\.googleapis\\.com$"), + regexp.MustCompile("^[^\\.\\s\\/\\\\]+\\.sts\\.googleapis\\.com$"), + regexp.MustCompile("^sts\\.googleapis\\.com$"), + regexp.MustCompile("^sts\\.[^\\.\\s\\/\\\\]+\\.googleapis\\.com$"), + regexp.MustCompile("^[^\\.\\s\\/\\\\]+-sts\\.googleapis\\.com$"), } validImpersonateURLPatterns = []*regexp.Regexp{ - regexp.MustCompile("^https://[^\\.\\s\\/\\\\]+\\.iamcredentials\\.googleapis\\.com$"), - regexp.MustCompile("^https://iamcredentials\\.googleapis\\.com$"), - regexp.MustCompile("^https://iamcredentials\\.[^\\.\\s\\/\\\\]+\\.googleapis\\.com$"), - regexp.MustCompile("^https://[^\\.\\s\\/\\\\]+-iamcredentials\\.googleapis\\.com$"), + regexp.MustCompile("^[^\\.\\s\\/\\\\]+\\.iamcredentials\\.googleapis\\.com$"), + regexp.MustCompile("^iamcredentials\\.googleapis\\.com$"), + regexp.MustCompile("^iamcredentials\\.[^\\.\\s\\/\\\\]+\\.googleapis\\.com$"), + regexp.MustCompile("^[^\\.\\s\\/\\\\]+-iamcredentials\\.googleapis\\.com$"), } ) -func validateURL(input string, patterns []*regexp.Regexp) bool { +func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool { + fmt.Println(input) parsed, err := url.Parse(input) if err != nil { return false } - path := parsed.Path - toTest := input[0 : len(input)-len(path)] + if parsed.Scheme != scheme { + return false + } + toTest := parsed.Host + fmt.Println(toTest) for _, pattern := range patterns { valid := pattern.MatchString(toTest) @@ -93,22 +97,22 @@ func validateURL(input string, patterns []*regexp.Regexp) bool { // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { - return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns) + return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns, "https") } // tokenSource is a private function that's directly called by some of the tests, // because the unit test URLs are mocked, and would otherwise fail the // validity check. -func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp) (oauth2.TokenSource, error) { +func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp, scheme string) (oauth2.TokenSource, error) { // Check the validity of TokenURL. - valid := validateURL(c.TokenURL, tokenURLValidPats) + valid := validateURL(c.TokenURL, tokenURLValidPats, scheme) if !valid { return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") } // If ServiceAccountImpersonationURL is present, check its validity. if c.ServiceAccountImpersonationURL != "" { - valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats) + valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats, scheme) if !valid { return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") } diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 2a23e5f..3189b16 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -118,7 +118,7 @@ func TestValidateURLTokenURL(t *testing.T) { } for _, tt := range urlValidityTests { t.Run(" "+tt.input, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - valid := validateURL(tt.input, tt.pattern) + valid := validateURL(tt.input, tt.pattern, "https") if valid != tt.result { t.Errorf("got %v, want %v", valid, tt.result) } @@ -147,7 +147,7 @@ func TestValidateURLImpersonateURL(t *testing.T) { } for _, tt := range urlValidityTests { t.Run(" "+tt.input, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - valid := validateURL(tt.input, tt.pattern) + valid := validateURL(tt.input, tt.pattern, "https") if valid != tt.result { t.Errorf("got %v, want %v", valid, tt.result) } diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go index 3d1fd23..d8eccf2 100644 --- a/google/internal/externalaccount/impersonate_test.go +++ b/google/internal/externalaccount/impersonate_test.go @@ -78,9 +78,8 @@ func TestImpersonation(t *testing.T) { defer targetServer.Close() testImpersonateConfig.TokenURL = targetServer.URL - allURLs := regexp.MustCompile(".*") - fmt.Println(allURLs) - ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}) + allURLs := regexp.MustCompile(".+") + ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}, "http") if err != nil { fmt.Println(testImpersonateConfig.TokenURL) t.Fatalf("Failed to create TokenSource: %v", err)