diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 0812964..7a0b9ed 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -42,6 +42,11 @@ type Config struct { // EndpointParams specifies additional parameters for requests to the token endpoint. EndpointParams url.Values + + // AuthStyle optionally specifies how the endpoint wants the + // client ID & client secret sent. The zero value means to + // auto-detect. + AuthStyle oauth2.AuthStyle } // Token uses client credentials to retrieve a token. @@ -97,7 +102,8 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { } v[k] = p } - tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v) + + tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*oauth2.RetrieveError)(rErr) diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 35cbd54..02a1c89 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -6,11 +6,14 @@ package clientcredentials import ( "context" + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" + + "golang.org/x/oauth2/internal" ) func newConf(serverURL string) *Config { @@ -111,21 +114,25 @@ func TestTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { + internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return } if r.URL.String() != "/token" { - t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + t.Errorf("Unexpected token refresh request URL: %q", r.URL) } headerContentType := r.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type = %q; want %q", got, want) } body, _ := ioutil.ReadAll(r.Body) - if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" { - t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" + if string(body) != want { + t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want) } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`) })) defer ts.Close() conf := newConf(ts.URL) diff --git a/google/google.go b/google/google.go index ca7d208..4b0b547 100644 --- a/google/google.go +++ b/google/google.go @@ -19,8 +19,9 @@ import ( // Endpoint is Google's OAuth 2.0 endpoint. var Endpoint = oauth2.Endpoint{ - AuthURL: "https://accounts.google.com/o/oauth2/auth", - TokenURL: "https://accounts.google.com/o/oauth2/token", + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://accounts.google.com/o/oauth2/token", + AuthStyle: oauth2.AuthStyleInParams, } // JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow. diff --git a/internal/token.go b/internal/token.go index a831b77..0f75a18 100644 --- a/internal/token.go +++ b/internal/token.go @@ -16,6 +16,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "golang.org/x/net/context/ctxhttp" @@ -90,102 +91,71 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -var brokenAuthHeaderProviders = []string{ - "https://accounts.google.com/", - "https://api.codeswholesale.com/oauth/token", - "https://api.dropbox.com/", - "https://api.dropboxapi.com/", - "https://api.instagram.com/", - "https://api.netatmo.net/", - "https://api.odnoklassniki.ru/", - "https://api.pushbullet.com/", - "https://api.soundcloud.com/", - "https://api.twitch.tv/", - "https://id.twitch.tv/", - "https://app.box.com/", - "https://api.box.com/", - "https://connect.stripe.com/", - "https://login.mailchimp.com/", - "https://login.microsoftonline.com/", - "https://login.salesforce.com/", - "https://login.windows.net", - "https://login.live.com/", - "https://login.live-int.com/", - "https://oauth.sandbox.trainingpeaks.com/", - "https://oauth.trainingpeaks.com/", - "https://oauth.vk.com/", - "https://openapi.baidu.com/", - "https://slack.com/", - "https://test-sandbox.auth.corp.google.com", - "https://test.salesforce.com/", - "https://user.gini.net/", - "https://www.douban.com/", - "https://www.googleapis.com/", - "https://www.linkedin.com/", - "https://www.strava.com/oauth/", - "https://www.wunderlist.com/oauth/", - "https://api.patreon.com/", - "https://sandbox.codeswholesale.com/oauth/token", - "https://api.sipgate.com/v1/authorization/oauth", - "https://api.medium.com/v1/tokens", - "https://log.finalsurge.com/oauth/token", - "https://multisport.todaysplan.com.au/rest/oauth/access_token", - "https://whats.todaysplan.com.au/rest/oauth/access_token", - "https://stackoverflow.com/oauth/access_token", - "https://account.health.nokia.com", - "https://accounts.zoho.com", - "https://gitter.im/login/oauth/token", - "https://openid-connect.onelogin.com/oidc", - "https://api.dailymotion.com/oauth/token", +// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. +// +// Deprecated: this function no longer does anything. Caller code that +// wants to avoid potential extra HTTP requests made during +// auto-probing of the provider's auth style should set +// Endpoint.AuthStyle. +func RegisterBrokenAuthHeaderProvider(tokenURL string) {} + +// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. +type AuthStyle int + +const ( + AuthStyleUnknown AuthStyle = 0 + AuthStyleInParams AuthStyle = 1 + AuthStyleInHeader AuthStyle = 2 +) + +// authStyleCache is the set of tokenURLs we've successfully used via +// RetrieveToken and which style auth we ended up using. +// It's called a cache, but it doesn't (yet?) shrink. It's expected that +// the set of OAuth2 servers a program contacts over time is fixed and +// small. +var authStyleCache struct { + sync.Mutex + m map[string]AuthStyle // keyed by tokenURL } -// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints. -var brokenAuthHeaderDomains = []string{ - ".auth0.com", - ".force.com", - ".myshopify.com", - ".okta.com", - ".oktapreview.com", +// ResetAuthCache resets the global authentication style cache used +// for AuthStyleUnknown token requests. +func ResetAuthCache() { + authStyleCache.Lock() + defer authStyleCache.Unlock() + authStyleCache.m = nil } -func RegisterBrokenAuthHeaderProvider(tokenURL string) { - brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL) +// lookupAuthStyle reports which auth style we last used with tokenURL +// when calling RetrieveToken and whether we have ever done so. +func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { + authStyleCache.Lock() + defer authStyleCache.Unlock() + style, ok = authStyleCache.m[tokenURL] + return } -// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL -// implements the OAuth2 spec correctly -// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. -// In summary: -// - Reddit only accepts client secret in the Authorization header -// - Dropbox accepts either it in URL param or Auth header, but not both. -// - Google only accepts URL param (not spec compliant?), not Auth header -// - Stripe only accepts client secret in Auth header with Bearer method, not Basic -func providerAuthHeaderWorks(tokenURL string) bool { - for _, s := range brokenAuthHeaderProviders { - if strings.HasPrefix(tokenURL, s) { - // Some sites fail to implement the OAuth2 spec fully. - return false - } +// setAuthStyle adds an entry to authStyleCache, documented above. +func setAuthStyle(tokenURL string, v AuthStyle) { + authStyleCache.Lock() + defer authStyleCache.Unlock() + if authStyleCache.m == nil { + authStyleCache.m = make(map[string]AuthStyle) } - - if u, err := url.Parse(tokenURL); err == nil { - for _, s := range brokenAuthHeaderDomains { - if strings.HasSuffix(u.Host, s) { - return false - } - } - } - - // Assume the provider implements the spec properly - // otherwise. We can add more exceptions as they're - // discovered. We will _not_ be adding configurable hooks - // to this package to let users select server bugs. - return true + authStyleCache.m[tokenURL] = v } -func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { - bustedAuth := !providerAuthHeaderWorks(tokenURL) - if bustedAuth { +// newTokenRequest returns a new *http.Request to retrieve a new token +// from tokenURL using the provided clientID, clientSecret, and POST +// body parameters. +// +// inParams is whether the clientID & clientSecret should be encoded +// as the POST body. An 'inParams' value of true means to send it in +// the POST body (along with any values in v); false means to send it +// in the Authorization header. +func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { + if authStyle == AuthStyleInParams { + v = cloneURLValues(v) if clientID != "" { v.Set("client_id", clientID) } @@ -198,15 +168,70 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if !bustedAuth { + if authStyle == AuthStyleInHeader { req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) } + return req, nil +} + +func cloneURLValues(v url.Values) url.Values { + v2 := make(url.Values, len(v)) + for k, vv := range v { + v2[k] = append([]string(nil), vv...) + } + return v2 +} + +func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { + needsAuthStyleProbe := authStyle == 0 + if needsAuthStyleProbe { + if style, ok := lookupAuthStyle(tokenURL); ok { + authStyle = style + needsAuthStyleProbe = false + } else { + authStyle = AuthStyleInHeader // the first way we'll try + } + } + req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + if err != nil { + return nil, err + } + token, err := doTokenRoundTrip(ctx, req) + if err != nil && needsAuthStyleProbe { + // If we get an error, assume the server wants the + // clientID & clientSecret in a different form. + // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. + // In summary: + // - Reddit only accepts client secret in the Authorization header + // - Dropbox accepts either it in URL param or Auth header, but not both. + // - Google only accepts URL param (not spec compliant?), not Auth header + // - Stripe only accepts client secret in Auth header with Bearer method, not Basic + // + // We used to maintain a big table in this code of all the sites and which way + // they went, but maintaining it didn't scale & got annoying. + // So just try both ways. + authStyle = AuthStyleInParams // the second way we'll try + req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + token, err = doTokenRoundTrip(ctx, req) + } + if needsAuthStyleProbe && err == nil { + setAuthStyle(tokenURL, authStyle) + } + // Don't overwrite `RefreshToken` with an empty value + // if this was a token refreshing request. + if token != nil && token.RefreshToken == "" { + token.RefreshToken = v.Get("refresh_token") + } + return token, err +} + +func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { r, err := ctxhttp.Do(ctx, ContextClient(ctx), req) if err != nil { return nil, err } - defer r.Body.Close() body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + r.Body.Close() if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -256,13 +281,8 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } - // Don't overwrite `RefreshToken` with an empty value - // if this was a token refreshing request. - if token.RefreshToken == "" { - token.RefreshToken = v.Get("refresh_token") - } if token.AccessToken == "" { - return token, errors.New("oauth2: server response missing access_token") + return nil, errors.New("oauth2: server response missing access_token") } return token, nil } diff --git a/internal/token_test.go b/internal/token_test.go index d1da8bb..d8373c2 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -6,7 +6,6 @@ package internal import ( "context" - "fmt" "io" "net/http" "net/http/httptest" @@ -14,17 +13,9 @@ import ( "testing" ) -func TestRegisterBrokenAuthHeaderProvider(t *testing.T) { - RegisterBrokenAuthHeaderProvider("https://aaa.com/") - tokenURL := "https://aaa.com/token" - if providerAuthHeaderWorks(tokenURL) { - t.Errorf("got %q as unbroken; want broken", tokenURL) - } -} - -func TestRetrieveTokenBustedNoSecret(t *testing.T) { +func TestRetrieveToken_InParams(t *testing.T) { + ResetAuthCache() const clientID = "client-id" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got, want := r.FormValue("client_id"), clientID; got != want { t.Errorf("client_id = %q; want %q", got, want) @@ -36,52 +27,14 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) { io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) })) defer ts.Close() - - RegisterBrokenAuthHeaderProvider(ts.URL) - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams) if err != nil { t.Errorf("RetrieveToken = %v; want no error", err) } } -func Test_providerAuthHeaderWorks(t *testing.T) { - for _, p := range brokenAuthHeaderProviders { - if providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } - p := fmt.Sprintf("%ssomesuffix", p) - if providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } - } - p := "https://api.not-in-the-list-example.com/" - if !providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } -} - -func TestProviderAuthHeaderWorksDomain(t *testing.T) { - tests := []struct { - tokenURL string - wantWorks bool - }{ - {"https://dev-12345.okta.com/token-url", false}, - {"https://dev-12345.oktapreview.com/token-url", false}, - {"https://dev-12345.okta.org/token-url", true}, - {"https://foo.bar.force.com/token-url", false}, - {"https://foo.force.com/token-url", false}, - {"https://force.com/token-url", true}, - } - - for _, test := range tests { - got := providerAuthHeaderWorks(test.tokenURL) - if got != test.wantWorks { - t.Errorf("providerAuthHeaderWorks(%q) = %v; want %v", test.tokenURL, got, test.wantWorks) - } - } -} - func TestRetrieveTokenWithContexts(t *testing.T) { + ResetAuthCache() const clientID = "client-id" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -90,7 +43,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { })) defer ts.Close() - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown) if err != nil { t.Errorf("RetrieveToken (with background context) = %v; want no error", err) } @@ -103,7 +56,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}) + _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown) close(retrieved) if err == nil { t.Errorf("RetrieveToken (with cancelled context) = nil; want error") diff --git a/linkedin/linkedin.go b/linkedin/linkedin.go index 62d1de8..d397277 100644 --- a/linkedin/linkedin.go +++ b/linkedin/linkedin.go @@ -11,6 +11,7 @@ import ( // Endpoint is LinkedIn's OAuth 2.0 endpoint. var Endpoint = oauth2.Endpoint{ - AuthURL: "https://www.linkedin.com/oauth/v2/authorization", - TokenURL: "https://www.linkedin.com/oauth/v2/accessToken", + AuthURL: "https://www.linkedin.com/oauth/v2/authorization", + TokenURL: "https://www.linkedin.com/oauth/v2/accessToken", + AuthStyle: oauth2.AuthStyleInParams, } diff --git a/oauth2.go b/oauth2.go index 3de6331..ec6ee00 100644 --- a/oauth2.go +++ b/oauth2.go @@ -26,17 +26,13 @@ import ( // Deprecated: Use context.Background() or context.TODO() instead. var NoContext = context.TODO() -// RegisterBrokenAuthHeaderProvider registers an OAuth2 server -// identified by the tokenURL prefix as an OAuth2 implementation -// which doesn't support the HTTP Basic authentication -// scheme to authenticate with the authorization server. -// Once a server is registered, credentials (client_id and client_secret) -// will be passed as parameters in the request body rather than being present -// in the Authorization header. -// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. -func RegisterBrokenAuthHeaderProvider(tokenURL string) { - internal.RegisterBrokenAuthHeaderProvider(tokenURL) -} +// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. +// +// Deprecated: this function no longer does anything. Caller code that +// wants to avoid potential extra HTTP requests made during +// auto-probing of the provider's auth style should set +// Endpoint.AuthStyle. +func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. @@ -71,13 +67,38 @@ type TokenSource interface { Token() (*Token, error) } -// Endpoint contains the OAuth 2.0 provider's authorization and token +// Endpoint represents an OAuth 2.0 provider's authorization and token // endpoint URLs. type Endpoint struct { AuthURL string TokenURL string + + // AuthStyle optionally specifies how the endpoint wants the + // client ID & client secret sent. The zero value means to + // auto-detect. + AuthStyle AuthStyle } +// AuthStyle represents how requests for tokens are authenticated +// to the server. +type AuthStyle int + +const ( + // AuthStyleAutoDetect means to auto-detect which authentication + // style the provider wants by trying both ways and caching + // the successful way for the future. + AuthStyleAutoDetect AuthStyle = 0 + + // AuthStyleInParams sends the "client_id" and "client_secret" + // in the POST body as application/x-www-form-urlencoded parameters. + AuthStyleInParams AuthStyle = 1 + + // AuthStyleInHeader sends the client_id and client_password + // using HTTP Basic Authorization. This is an optional style + // described in the OAuth2 RFC 6749 section 2.3.1. + AuthStyleInHeader AuthStyle = 2 +) + var ( // AccessTypeOnline and AccessTypeOffline are options passed // to the Options.AuthCodeURL method. They modify the diff --git a/oauth2_test.go b/oauth2_test.go index 19aaf6b..a059b8b 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -8,12 +8,15 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "time" + + "golang.org/x/oauth2/internal" ) type mockTransport struct { @@ -93,22 +96,22 @@ func TestURLUnsafeClientConfig(t *testing.T) { func TestExchangeRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() != "/token" { - t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) + t.Errorf("Unexpected exchange request URL %q", r.URL) } headerAuth := r.Header.Get("Authorization") - if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want { + t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want) } headerContentType := r.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + t.Errorf("Unexpected Content-Type header %q", headerContentType) } body, err := ioutil.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { - t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + t.Errorf("Unexpected exchange payload; got %q", body) } w.Header().Set("Content-Type", "application/x-www-form-urlencoded") w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) @@ -343,11 +346,12 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { } func TestExchangeRequest_NonBasicAuth(t *testing.T) { + internal.ResetAuthCache() tr := &mockTransport{ rt: func(r *http.Request) (w *http.Response, err error) { headerAuth := r.Header.Get("Authorization") if headerAuth != "" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + t.Errorf("Unexpected authorization header %q", headerAuth) } return nil, errors.New("no response") }, @@ -356,8 +360,9 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) { conf := &Config{ ClientID: "CLIENT_ID", Endpoint: Endpoint{ - AuthURL: "https://accounts.google.com/auth", - TokenURL: "https://accounts.google.com/token", + AuthURL: "https://accounts.google.com/auth", + TokenURL: "https://accounts.google.com/token", + AuthStyle: AuthStyleInParams, }, } @@ -413,21 +418,24 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { + internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return } if r.URL.String() != "/token" { - t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + t.Errorf("Unexpected token refresh request URL %q", r.URL) } headerContentType := r.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + t.Errorf("Unexpected Content-Type header %q", headerContentType) } body, _ := ioutil.ReadAll(r.Body) if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { - t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + t.Errorf("Unexpected refresh token payload %q", body) } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`) })) defer ts.Close() conf := newConf(ts.URL) @@ -478,7 +486,7 @@ func TestTokenRetrieveError(t *testing.T) { } _, ok := err.(*RetrieveError) if !ok { - t.Fatalf("got %T error, expected *RetrieveError", err) + t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) } // Test error string for backwards compatibility expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) diff --git a/token.go b/token.go index ee4be54..8227203 100644 --- a/token.go +++ b/token.go @@ -154,7 +154,7 @@ func tokenFromInternal(t *internal.Token) *Token { // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // with an error.. func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { - tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*RetrieveError)(rErr)