diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 57a5870..2eb5c8e 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -96,7 +96,7 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) { } else if c.CredentialSource.File != "" { return fileCredentialSource{File: c.CredentialSource.File, Format: c.CredentialSource.Format}, nil } else if c.CredentialSource.URL != "" { - return urlCredentialSource{URL: c.CredentialSource.URL, Format: c.CredentialSource.Format, ctx: ctx}, nil + return urlCredentialSource{URL: c.CredentialSource.URL, Headers: c.CredentialSource.Headers, Format: c.CredentialSource.Format, ctx: ctx}, nil } return nil, fmt.Errorf("oauth2/google: unable to parse credential source") } diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index 1a1c9b4..fbb477d 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "net/url" "strconv" @@ -63,9 +64,12 @@ func ExchangeToken(ctx context.Context, endpoint string, request *STSTokenExchan } defer resp.Body.Close() - bodyJson := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)) + body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if c := resp.StatusCode; c < 200 || c > 299 { + return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) + } var stsResp STSTokenExchangeResponse - err = bodyJson.Decode(&stsResp) + err = json.Unmarshal(body, &stsResp) if err != nil { return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) diff --git a/google/internal/externalaccount/urlcredsource.go b/google/internal/externalaccount/urlcredsource.go index b0d5d35..91b8f20 100644 --- a/google/internal/externalaccount/urlcredsource.go +++ b/google/internal/externalaccount/urlcredsource.go @@ -39,15 +39,18 @@ func (cs urlCredentialSource) subjectToken() (string, error) { } defer resp.Body.Close() - tokenBytes, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return "", fmt.Errorf("oauth2/google: invalid body in subject token URL query: %v", err) } + if c := resp.StatusCode; c < 200 || c > 299 { + return "", fmt.Errorf("oauth2/google: status code %d: %s", c, respBody) + } switch cs.Format.Type { case "json": jsonData := make(map[string]interface{}) - err = json.Unmarshal(tokenBytes, &jsonData) + err = json.Unmarshal(respBody, &jsonData) if err != nil { return "", fmt.Errorf("oauth2/google: failed to unmarshal subject token file: %v", err) } @@ -61,9 +64,9 @@ func (cs urlCredentialSource) subjectToken() (string, error) { } return token, nil case "text": - return string(tokenBytes), nil + return string(respBody), nil case "": - return string(tokenBytes), nil + return string(respBody), nil default: return "", errors.New("oauth2/google: invalid credential_source file format type") } diff --git a/google/internal/externalaccount/urlcredsource_test.go b/google/internal/externalaccount/urlcredsource_test.go index 1b78e68..6874f11 100644 --- a/google/internal/externalaccount/urlcredsource_test.go +++ b/google/internal/externalaccount/urlcredsource_test.go @@ -7,6 +7,7 @@ package externalaccount import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -19,11 +20,18 @@ func TestRetrieveURLSubjectToken_Text(t *testing.T) { if r.Method != "GET" { t.Errorf("Unexpected request method, %v is found", r.Method) } + fmt.Println(r.Header) + if r.Header.Get("Metadata") != "True" { + t.Errorf("Metadata header not properly included.") + } w.Write([]byte("testTokenValue")) })) + heads := make(map[string]string) + heads["Metadata"] = "True" cs := CredentialSource{ URL: ts.URL, Format: format{Type: fileTypeText}, + Headers: heads, } tfc := testFileConfig tfc.CredentialSource = cs