diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index 099e4b3..11548f0 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -10,7 +10,6 @@ import ( "fmt" "golang.org/x/oauth2" "io" - "io/ioutil" "net/http" "net/url" "strconv" @@ -33,35 +32,41 @@ func ExchangeToken(ctx context.Context, endpoint string, request *STSTokenExchan fmt.Errorf("oauth2/google: failed to marshal additional options: %v", err) return nil, err } - data.Set("options", string(opts)) + data.Set("options", string(opts)) authentication.InjectAuthentication(data, headers) encodedData := data.Encode() + req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(encodedData)) if err != nil { fmt.Errorf("oauth2/google: failed to properly build http request: %v", err) + return nil, err } - for key, _ := range headers { - for _, val := range headers.Values(key) { + for key, list := range headers { + for _, val := range list { req.Header.Add(key, val) } } req.Header.Add("Content-Length", strconv.Itoa(len(encodedData))) resp, err := client.Do(req) + if err != nil { fmt.Errorf("oauth2/google: invalid response from Secure Token Server: %v", err) + return nil, err } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - fmt.Errorf("oauth2/google: invalid body in Secure Token Server response: %v", err) - } + + bodyJson := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)) var stsResp STSTokenExchangeResponse - err = json.Unmarshal(body, &stsResp) - if err != nil { - fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) + for bodyJson.More() { + err = bodyJson.Decode(&stsResp) + if err != nil { + fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) + return nil, err + } } + return &stsResp, nil } diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index 703a96e..496bba6 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -34,7 +34,7 @@ var tokenRequest = STSTokenExchangeRequest{ SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", } -var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" +var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&options=null&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` var expectedToken = STSTokenExchangeResponse{ AccessToken: "Sample.Access.Token", @@ -48,34 +48,37 @@ var expectedToken = STSTokenExchangeResponse{ func TestExchangeToken(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Unexpected request method, %v is found", r.Method) + } if r.URL.String() != "/" { - t.Errorf("Unexpected request URL, %v is found.", r.URL) + t.Errorf("Unexpected request URL, %v is found", r.URL) } headerAuth := r.Header.Get("Authorization") if headerAuth != "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=" { - t.Errorf("Unexpected autohrization header, %v is found.", headerAuth) + t.Errorf("Unexpected autohrization header, %v is found", headerAuth) } headerContentType := r.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + t.Errorf("Unexpected Content-Type header, %v is found", headerContentType) } body, err := ioutil.ReadAll(r.Body) if err != nil { - t.Errorf("Failed reading request body: %s.", err) + t.Errorf("Failed reading request body: %v.", err) } if string(body) != requestbody { - t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + t.Errorf("Unexpected exchange payload, %v is found", string(body)) } w.Header().Set("Content-Type", "application/json") w.Write([]byte(responseBody)) })) - headers := make(map[string][]string) - headers["Content-Type"] = []string{"application/x-www-form-urlencoded"} + headers := http.Header{} + headers.Add("Content-Type", "application/x-www-form-urlencoded") resp, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) if err != nil { - t.Fatalf("ExchangeToken failed with error: %s", err) + t.Fatalf("ExchangeToken failed with error: %v", err) } if diff := cmp.Diff(expectedToken, *resp); diff != "" { @@ -90,10 +93,10 @@ func TestExchangeToken_Err(t *testing.T) { w.Write([]byte("what's wrong with this response?")) })) - headers := make(map[string][]string) - headers["Content-Type"] = []string{"application/x-www-form-urlencoded"} + headers := http.Header{} + headers.Add("Content-Type", "application/x-www-form-urlencoded") _, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) if err == nil { t.Errorf("Expected handled error; instead got nil.") } -} \ No newline at end of file +}