From 109292283c9baf749e643df6a2109af4ed9fe624 Mon Sep 17 00:00:00 2001 From: Patrick Jones Date: Mon, 9 Aug 2021 14:38:40 -0700 Subject: [PATCH] made some changes --- google/internal/externalaccount/aws_test.go | 3 +- .../externalaccount/basecredentials.go | 64 ++++++++----------- .../externalaccount/basecredentials_test.go | 29 ++++++--- .../externalaccount/impersonate_test.go | 5 +- .../externalaccount/sts_exchange_test.go | 3 + 5 files changed, 56 insertions(+), 48 deletions(-) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index d5de415..b1f592c 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -28,8 +28,7 @@ func setTime(testTime time.Time) func() time.Time { func setEnvironment(env map[string]string) func(string) string { return func(key string) string { - value, _ := env[key] - return value + return env[key] } } diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 6fedabc..426f4d6 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -55,62 +55,52 @@ type Config struct { // Each element consists of a list of patterns. validateURLs checks for matches // that include all elements in a given list, in that order. + var ( - validTokenURLPatterns = []string{ - "https://[^\\.]+\\.sts\\.googleapis\\.com", - "https://sts\\.googleapis\\.com", - "https://sts\\.[^\\.]+\\.googleapis\\.com", - "https://[^\\.]+-sts\\.googleapis\\.com", + validTokenURLPatterns = []*regexp.Regexp{ + regexp.MustCompile("https://[^\\.]+\\.sts\\.googleapis\\.com"), + regexp.MustCompile("https://sts\\.googleapis\\.com"), + regexp.MustCompile("https://sts\\.[^\\.]+\\.googleapis\\.com"), + regexp.MustCompile("https://[^\\.]+-sts\\.googleapis\\.com"), } - validImpersonateURLPatterns = []string{ - "https://[^\\.]+\\.iamcredentials\\.googleapis\\.com", - "https://iamcredentials\\.googleapis\\.com", - "https://iamcredentials\\.[^\\.]+\\.googleapis\\.com", - "https://[^\\.]+-iamcredentials\\.googleapis\\.com", + validImpersonateURLPatterns = []*regexp.Regexp{ + regexp.MustCompile("https://[^\\.]+\\.iamcredentials\\.googleapis\\.com"), + regexp.MustCompile("https://iamcredentials\\.googleapis\\.com"), + regexp.MustCompile("https://iamcredentials\\.[^\\.]+\\.googleapis\\.com"), + regexp.MustCompile("https://[^\\.]+-iamcredentials\\.googleapis\\.com"), } ) -func validateURL(input string, patterns []string) (bool, error) { +func validateURL(input string, patterns []*regexp.Regexp) bool { for _, pattern := range patterns { - valid, err := regexp.MatchString(pattern, input) - if err != nil { - return false, err - } + valid := pattern.MatchString(input) if valid { - return true, nil + return true } } - return false, nil + return false } // TokenSource Returns an external account TokenSource struct. This is to be called by package google to construct a google.Credentials. func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { - return c.tokenSource(ctx, false) + return c.tokenSource(ctx, validTokenURLPatterns, validImpersonateURLPatterns) } // tokenSource is a private function that's directly called by some of the tests, // because the unit test URLs are mocked, and would otherwise fail the // validity check. -func (c *Config) tokenSource(ctx context.Context, testing bool) (oauth2.TokenSource, error) { - if !testing { - // Check the validity of TokenURL. - valid, err := validateURL(c.TokenURL, validTokenURLPatterns) - if err != nil { - return nil, err - } - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") - } +func (c *Config) tokenSource(ctx context.Context, tokenURLValidPats []*regexp.Regexp, impersonateURLValidPats []*regexp.Regexp) (oauth2.TokenSource, error) { + // Check the validity of TokenURL. + valid := validateURL(c.TokenURL, tokenURLValidPats) + if !valid { + return nil, fmt.Errorf("oauth2/google: invalid TokenURL provided while constructing tokenSource") + } - // If ServiceAccountImpersonationURL is present, check its validity. - if c.ServiceAccountImpersonationURL != "" { - valid, err := validateURL(c.ServiceAccountImpersonationURL, validImpersonateURLPatterns) - if err != nil { - return nil, err - } - if !valid { - return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") - } + // If ServiceAccountImpersonationURL is present, check its validity. + if c.ServiceAccountImpersonationURL != "" { + valid := validateURL(c.ServiceAccountImpersonationURL, impersonateURLValidPats) + if !valid { + return nil, fmt.Errorf("oauth2/google: invalid ServiceAccountImpersonationURL provided while constructing tokenSource") } } diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index a1342cf..34c5542 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "regexp" "testing" "time" ) @@ -96,10 +97,10 @@ func TestToken(t *testing.T) { } -func TestValidateURL(t *testing.T) { +func TestValidateURLTokenURL(t *testing.T) { var urlValidityTests = []struct { input string - pattern []string + pattern []*regexp.Regexp result bool }{ {"https://east.sts.googleapis.com", validTokenURLPatterns, true}, @@ -112,7 +113,23 @@ func TestValidateURL(t *testing.T) { {"https://sts..googleapis.com", validTokenURLPatterns, false}, {"https://-sts.googleapis.com", validTokenURLPatterns, false}, {"https://us-ea.st-1-sts.googleapis.com", validTokenURLPatterns, false}, - // Repeat for iamcredentials as well + } + for _, tt := range urlValidityTests { + t.Run(" "+tt.input, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. + valid := validateURL(tt.input, tt.pattern) + if valid != tt.result { + t.Errorf("got %v, want %v", valid, tt.result) + } + }) + } +} + +func TestValidateURLImpersonateURL (t *testing.T) { + var urlValidityTests = []struct { + input string + pattern []*regexp.Regexp + result bool + }{ {"https://east.iamcredentials.googleapis.com", validImpersonateURLPatterns, true}, {"https://iamcredentials.googleapis.com", validImpersonateURLPatterns, true}, {"https://iamcredentials.asfeasfesef.googleapis.com", validImpersonateURLPatterns, true}, @@ -126,11 +143,7 @@ func TestValidateURL(t *testing.T) { } for _, tt := range urlValidityTests { t.Run(" "+tt.input, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability. - valid, err := validateURL(tt.input, tt.pattern) - if err != nil { - t.Errorf("validateURL returned an error: %v", err) - return - } + valid := validateURL(tt.input, tt.pattern) if valid != tt.result { t.Errorf("got %v, want %v", valid, tt.result) } diff --git a/google/internal/externalaccount/impersonate_test.go b/google/internal/externalaccount/impersonate_test.go index acf08d2..3d1fd23 100644 --- a/google/internal/externalaccount/impersonate_test.go +++ b/google/internal/externalaccount/impersonate_test.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "regexp" "testing" ) @@ -77,7 +78,9 @@ func TestImpersonation(t *testing.T) { defer targetServer.Close() testImpersonateConfig.TokenURL = targetServer.URL - ourTS, err := testImpersonateConfig.tokenSource(context.Background(), true) + allURLs := regexp.MustCompile(".*") + fmt.Println(allURLs) + ourTS, err := testImpersonateConfig.tokenSource(context.Background(), []*regexp.Regexp{allURLs}, []*regexp.Regexp{allURLs}) if err != nil { fmt.Println(testImpersonateConfig.TokenURL) t.Fatalf("Failed to create TokenSource: %v", err) diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index d7507a4..df4d5ff 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -128,6 +128,9 @@ func TestExchangeToken_Opts(t *testing.T) { } var opts map[string]interface{} err = json.Unmarshal([]byte(strOpts[0]), &opts) + if err != nil { + t.Fatalf("Couldn't parse received \"options\" field.") + } if len(opts) < 2 { t.Errorf("Too few options received.") }