diff --git a/google/internal/externalaccount/sts_exchange.go b/google/internal/externalaccount/sts_exchange.go index 2e20682..ccff3ad 100644 --- a/google/internal/externalaccount/sts_exchange.go +++ b/google/internal/externalaccount/sts_exchange.go @@ -1,8 +1,15 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package externalaccount import ( + "context" "encoding/json" "fmt" + "golang.org/x/oauth2" + "io" "io/ioutil" "net/http" "net/url" @@ -10,9 +17,9 @@ import ( "strings" ) -func ExchangeToken(endpoint string, request *STSTokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*STSTokenExchangeResponse, error) { +func ExchangeToken(ctx context.Context, endpoint string, request *STSTokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*STSTokenExchangeResponse, error) { - client := &http.Client{} + client := oauth2.NewClient(ctx, nil) data := url.Values{} data.Set("audience", request.Audience) @@ -21,32 +28,39 @@ func ExchangeToken(endpoint string, request *STSTokenExchangeRequest, authentica data.Set("subject_token_type", request.SubjectTokenType) data.Set("subject_token", request.SubjectToken) data.Set("scope", strings.Join(request.Scope, " ")) + opts, err := json.Marshal(options) + if err == nil { + data.Set("options", string(opts)) + } else { + fmt.Errorf("oauth2/google: failed to marshal additional options: %v", err) + } authentication.InjectAuthentication(data, headers) - req, err := http.NewRequest("POST", endpoint, strings.NewReader(data.Encode())) + encodedData := data.Encode() + req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData)) if err != nil { - fmt.Errorf("oauth2/google: failed to properly build http request") + fmt.Errorf("oauth2/google: failed to properly build http request: %v", err) } for key, _ := range headers { for _, val := range headers.Values(key) { req.Header.Add(key, val) } } - req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + req.Header.Add("Content-Length", strconv.Itoa(len(encodedData))) resp, err := client.Do(req) if err != nil { - fmt.Errorf("oauth2/google: invalid response from Secure Token Server: #{err}") + fmt.Errorf("oauth2/google: invalid response from Secure Token Server: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - fmt.Errorf("oauth2/google: invalid body in Secute Token Server response: #{err}") + fmt.Errorf("oauth2/google: invalid body in Secure Token Server response: %v", err) } var stsResp STSTokenExchangeResponse err = json.Unmarshal(body, &stsResp) if err != nil { - fmt.Println(err) + fmt.Errorf("oauth2/google: failed to unmarshal response body from STS server: %v", err) } return &stsResp, nil } diff --git a/google/internal/externalaccount/sts_exchange_test.go b/google/internal/externalaccount/sts_exchange_test.go index 451b86a..ca409dc 100644 --- a/google/internal/externalaccount/sts_exchange_test.go +++ b/google/internal/externalaccount/sts_exchange_test.go @@ -1,6 +1,11 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package externalaccount import ( + "context" "github.com/google/go-cmp/cmp" "golang.org/x/oauth2" "io/ioutil" @@ -68,7 +73,7 @@ func TestExchangeToken(t *testing.T) { headers := make(map[string][]string) headers["Content-Type"] = []string{"application/x-www-form-urlencoded"} - resp, err := ExchangeToken(ts.URL, &tokenRequest, auth, headers, nil) + resp, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) if err != nil { t.Errorf("ExchangeToken failed with error: %s", err) } @@ -84,5 +89,11 @@ func TestExchangeToken_Err(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.Write([]byte("what's wrong with this response?")) })) - + + headers := make(map[string][]string) + headers["Content-Type"] = []string{"application/x-www-form-urlencoded"} + _, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) + if err == nil { + t.Errorf("Expected handled error; instead got nil.") + } } \ No newline at end of file