diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 53a96b6..4afb631 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -92,6 +92,9 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { } tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v) if err != nil { + if rErr, ok := err.(*internal.RetrieveError); ok { + return nil, (*oauth2.RetrieveError)(rErr) + } return nil, err } t := &oauth2.Token{ diff --git a/internal/token.go b/internal/token.go index 600dbe6..8a10204 100644 --- a/internal/token.go +++ b/internal/token.go @@ -200,7 +200,10 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } if code := r.StatusCode; code < 200 || code > 299 { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) + return nil, &RetrieveError{ + Response: r, + Body: body, + } } var token *Token @@ -249,3 +252,12 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, } return token, nil } + +type RetrieveError struct { + Response *http.Response + Body []byte +} + +func (r *RetrieveError) Error() string { + return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) +} diff --git a/oauth2_test.go b/oauth2_test.go index 09293ed..14f041d 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -419,6 +419,32 @@ func TestFetchWithNoRefreshToken(t *testing.T) { } } +func TestTokenRetrieveError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + } + w.Header().Set("Content-type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "invalid_grant"}`)) + })) + defer ts.Close() + conf := newConf(ts.URL) + _, err := conf.Exchange(context.Background(), "exchange-code") + if err == nil { + t.Fatalf("got no error, expected one") + } + _, ok := err.(*RetrieveError) + if !ok { + t.Fatalf("got %T error, expected *RetrieveError", err) + } + // Test error string for backwards compatibility + expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) + if errStr := err.Error(); errStr != expected { + t.Fatalf("got %#v, expected %#v", errStr, expected) + } +} + func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/token.go b/token.go index d8534de..34db8cd 100644 --- a/token.go +++ b/token.go @@ -5,6 +5,7 @@ package oauth2 import ( + "fmt" "net/http" "net/url" "strconv" @@ -152,7 +153,23 @@ func tokenFromInternal(t *internal.Token) *Token { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) if err != nil { + if rErr, ok := err.(*internal.RetrieveError); ok { + return nil, (*RetrieveError)(rErr) + } return nil, err } return tokenFromInternal(tk), nil } + +// RetrieveError is the error returned when the token endpoint returns a +// non-2XX HTTP status code. +type RetrieveError struct { + Response *http.Response + // Body is the body that was consumed by reading Response.Body. + // It may be truncated. + Body []byte +} + +func (r *RetrieveError) Error() string { + return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) +}