diff --git a/internal/token.go b/internal/token.go index efe7af5..766cfb6 100644 --- a/internal/token.go +++ b/internal/token.go @@ -18,6 +18,7 @@ import ( "time" "golang.org/x/net/context" + "golang.org/x/net/context/ctxhttp" ) // Token represents the crendentials used to authorize @@ -189,7 +190,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, if !bustedAuth { req.SetBasicAuth(clientID, clientSecret) } - r, err := hc.Do(req) + r, err := ctxhttp.Do(ctx, hc, req) if err != nil { return nil, err } diff --git a/internal/token_test.go b/internal/token_test.go index 882de11..c19a462 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -79,3 +79,27 @@ func TestProviderAuthHeaderWorksDomain(t *testing.T) { } } } + +func TestRetrieveTokenWithContexts(t *testing.T) { + const clientID = "client-id" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) + if err != nil { + t.Errorf("RetrieveToken (with background context) = %v; want no error", err) + } + + ctx, cancelfunc := context.WithCancel(context.Background()) + + cancellingts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancelfunc() + })) + defer cancellingts.Close() + + _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}) + if err == nil { + t.Errorf("RetrieveToken (with cancelled context) = nil; want error") + } +}