diff --git a/internal/token.go b/internal/token.go index b4723fc..58901bd 100644 --- a/internal/token.go +++ b/internal/token.go @@ -55,12 +55,18 @@ type Token struct { } // tokenJSON is the struct representing the HTTP response from OAuth2 -// providers returning a token in JSON form. +// providers returning a token or error in JSON form. +// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 type tokenJSON struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` RefreshToken string `json:"refresh_token"` ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number + // error fields + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` } func (e *tokenJSON) expiry() (t time.Time) { @@ -236,21 +242,29 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } - if code := r.StatusCode; code < 200 || code > 299 { - return nil, &RetrieveError{ - Response: r, - Body: body, - } + + failureStatus := r.StatusCode < 200 || r.StatusCode > 299 + retrieveError := &RetrieveError{ + Response: r, + Body: body, + // attempt to populate error detail below } var token *Token content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch content { case "application/x-www-form-urlencoded", "text/plain": + // some endpoints return a query string vals, err := url.ParseQuery(string(body)) if err != nil { - return nil, err + if failureStatus { + return nil, retrieveError + } + return nil, fmt.Errorf("oauth2: cannot parse response: %v", err) } + retrieveError.ErrorCode = vals.Get("error") + retrieveError.ErrorDescription = vals.Get("error_description") + retrieveError.ErrorURI = vals.Get("error_uri") token = &Token{ AccessToken: vals.Get("access_token"), TokenType: vals.Get("token_type"), @@ -265,8 +279,14 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { default: var tj tokenJSON if err = json.Unmarshal(body, &tj); err != nil { - return nil, err + if failureStatus { + return nil, retrieveError + } + return nil, fmt.Errorf("oauth2: cannot parse json: %v", err) } + retrieveError.ErrorCode = tj.ErrorCode + retrieveError.ErrorDescription = tj.ErrorDescription + retrieveError.ErrorURI = tj.ErrorURI token = &Token{ AccessToken: tj.AccessToken, TokenType: tj.TokenType, @@ -276,17 +296,37 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } + // according to spec, servers should respond status 400 in error case + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + // but some unorthodox servers respond 200 in error case + if failureStatus || retrieveError.ErrorCode != "" { + return nil, retrieveError + } if token.AccessToken == "" { return nil, errors.New("oauth2: server response missing access_token") } return token, nil } +// mirrors oauth2.RetrieveError type RetrieveError struct { - Response *http.Response - Body []byte + Response *http.Response + Body []byte + ErrorCode string + ErrorDescription string + ErrorURI string } func (r *RetrieveError) Error() string { + if r.ErrorCode != "" { + s := fmt.Sprintf("oauth2: %q", r.ErrorCode) + if r.ErrorDescription != "" { + s += fmt.Sprintf(" %q", r.ErrorDescription) + } + if r.ErrorURI != "" { + s += fmt.Sprintf(" %q", r.ErrorURI) + } + return s + } return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) } diff --git a/oauth2_test.go b/oauth2_test.go index b7975e1..a351776 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -484,6 +484,7 @@ func TestTokenRetrieveError(t *testing.T) { t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) } w.Header().Set("Content-type", "application/json") + // "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2 w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"error": "invalid_grant"}`)) })) @@ -493,15 +494,47 @@ func TestTokenRetrieveError(t *testing.T) { if err == nil { t.Fatalf("got no error, expected one") } - _, ok := err.(*RetrieveError) + re, ok := err.(*RetrieveError) if !ok { t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) } - // Test error string for backwards compatibility - expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) + expected := `oauth2: "invalid_grant"` if errStr := err.Error(); errStr != expected { t.Fatalf("got %#v, expected %#v", errStr, expected) } + expected = "invalid_grant" + if re.ErrorCode != expected { + t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) + } +} + +// TestTokenRetrieveError200 tests handling of unorthodox server that returns 200 in error case +func TestTokenRetrieveError200(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + } + w.Header().Set("Content-type", "application/json") + w.Write([]byte(`{"error": "invalid_grant"}`)) + })) + defer ts.Close() + conf := newConf(ts.URL) + _, err := conf.Exchange(context.Background(), "exchange-code") + if err == nil { + t.Fatalf("got no error, expected one") + } + re, ok := err.(*RetrieveError) + if !ok { + t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) + } + expected := `oauth2: "invalid_grant"` + if errStr := err.Error(); errStr != expected { + t.Fatalf("got %#v, expected %#v", errStr, expected) + } + expected = "invalid_grant" + if re.ErrorCode != expected { + t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) + } } func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { diff --git a/token.go b/token.go index 7c64006..5ffce97 100644 --- a/token.go +++ b/token.go @@ -175,14 +175,31 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) } // RetrieveError is the error returned when the token endpoint returns a -// non-2XX HTTP status code. +// non-2XX HTTP status code or populates RFC 6749's 'error' parameter. +// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 type RetrieveError struct { Response *http.Response // Body is the body that was consumed by reading Response.Body. // It may be truncated. Body []byte + // ErrorCode is RFC 6749's 'error' parameter. + ErrorCode string + // ErrorDescription is RFC 6749's 'error_description' parameter. + ErrorDescription string + // ErrorURI is RFC 6749's 'error_uri' parameter. + ErrorURI string } func (r *RetrieveError) Error() string { + if r.ErrorCode != "" { + s := fmt.Sprintf("oauth2: %q", r.ErrorCode) + if r.ErrorDescription != "" { + s += fmt.Sprintf(" %q", r.ErrorDescription) + } + if r.ErrorURI != "" { + s += fmt.Sprintf(" %q", r.ErrorURI) + } + return s + } return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body) }