google: fix error formatting, add ctx argument

Change-Id: If664a760005cc209c7398f2efce5ff324ebc4746
This commit is contained in:
Patrick Jones 2020-11-03 13:44:17 -08:00
parent 5688b03dbc
commit 01ea6373cb
2 changed files with 36 additions and 11 deletions

View File

@ -1,8 +1,15 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount package externalaccount
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"golang.org/x/oauth2"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -10,9 +17,9 @@ import (
"strings" "strings"
) )
func ExchangeToken(endpoint string, request *STSTokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*STSTokenExchangeResponse, error) { func ExchangeToken(ctx context.Context, endpoint string, request *STSTokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*STSTokenExchangeResponse, error) {
client := &http.Client{} client := oauth2.NewClient(ctx, nil)
data := url.Values{} data := url.Values{}
data.Set("audience", request.Audience) data.Set("audience", request.Audience)
@ -21,32 +28,39 @@ func ExchangeToken(endpoint string, request *STSTokenExchangeRequest, authentica
data.Set("subject_token_type", request.SubjectTokenType) data.Set("subject_token_type", request.SubjectTokenType)
data.Set("subject_token", request.SubjectToken) data.Set("subject_token", request.SubjectToken)
data.Set("scope", strings.Join(request.Scope, " ")) data.Set("scope", strings.Join(request.Scope, " "))
opts, err := json.Marshal(options)
if err == nil {
data.Set("options", string(opts))
} else {
fmt.Errorf("oauth2/google: failed to marshal additional options: %v", err)
}
authentication.InjectAuthentication(data, headers) authentication.InjectAuthentication(data, headers)
req, err := http.NewRequest("POST", endpoint, strings.NewReader(data.Encode())) encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil { if err != nil {
fmt.Errorf("oauth2/google: failed to properly build http request") fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
} }
for key, _ := range headers { for key, _ := range headers {
for _, val := range headers.Values(key) { for _, val := range headers.Values(key) {
req.Header.Add(key, val) req.Header.Add(key, val)
} }
} }
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) req.Header.Add("Content-Length", strconv.Itoa(len(encodedData)))
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
fmt.Errorf("oauth2/google: invalid response from Secure Token Server: #{err}") fmt.Errorf("oauth2/google: invalid response from Secure Token Server: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
fmt.Errorf("oauth2/google: invalid body in Secute Token Server response: #{err}") fmt.Errorf("oauth2/google: invalid body in Secure Token Server response: %v", err)
} }
var stsResp STSTokenExchangeResponse var stsResp STSTokenExchangeResponse
err = json.Unmarshal(body, &stsResp) err = json.Unmarshal(body, &stsResp)
if err != nil { if err != nil {
fmt.Println(err) fmt.Errorf("oauth2/google: failed to unmarshal response body from STS server: %v", err)
} }
return &stsResp, nil return &stsResp, nil
} }

View File

@ -1,6 +1,11 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccount package externalaccount
import ( import (
"context"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"io/ioutil" "io/ioutil"
@ -68,7 +73,7 @@ func TestExchangeToken(t *testing.T) {
headers := make(map[string][]string) headers := make(map[string][]string)
headers["Content-Type"] = []string{"application/x-www-form-urlencoded"} headers["Content-Type"] = []string{"application/x-www-form-urlencoded"}
resp, err := ExchangeToken(ts.URL, &tokenRequest, auth, headers, nil) resp, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
if err != nil { if err != nil {
t.Errorf("ExchangeToken failed with error: %s", err) t.Errorf("ExchangeToken failed with error: %s", err)
} }
@ -85,4 +90,10 @@ func TestExchangeToken_Err(t *testing.T) {
w.Write([]byte("what's wrong with this response?")) w.Write([]byte("what's wrong with this response?"))
})) }))
headers := make(map[string][]string)
headers["Content-Type"] = []string{"application/x-www-form-urlencoded"}
_, err := ExchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
} }