forked from Mirrors/oauth2
oauth2: auto-detect auth style by default, add Endpoint.AuthStyle
Instead of maintaining a global map of which OAuth2 servers do which auth style and/or requiring the user to tell us, just try both ways and remember which way worked. But if users want to tell us in the Endpoint, this CL also add Endpoint.AuthStyle. Fixes golang/oauth2#111 Fixes golang/oauth2#365 Fixes golang/oauth2#362 Fixes golang/oauth2#357 Fixes golang/oauth2#353 Fixes golang/oauth2#345 Fixes golang/oauth2#326 Fixes golang/oauth2#352 Fixes golang/oauth2#268 Fixes https://go-review.googlesource.com/c/oauth2/+/58510 (... and surely many more ...) Change-Id: I7b4d98ba1900ee2d3e11e629316b0bf867f7d237 Reviewed-on: https://go-review.googlesource.com/c/157820 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Ross Light <light@google.com>
This commit is contained in:
parent
99b60b757e
commit
80673b4a4b
|
@ -42,6 +42,11 @@ type Config struct {
|
||||||
|
|
||||||
// EndpointParams specifies additional parameters for requests to the token endpoint.
|
// EndpointParams specifies additional parameters for requests to the token endpoint.
|
||||||
EndpointParams url.Values
|
EndpointParams url.Values
|
||||||
|
|
||||||
|
// AuthStyle optionally specifies how the endpoint wants the
|
||||||
|
// client ID & client secret sent. The zero value means to
|
||||||
|
// auto-detect.
|
||||||
|
AuthStyle oauth2.AuthStyle
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token uses client credentials to retrieve a token.
|
// Token uses client credentials to retrieve a token.
|
||||||
|
@ -97,7 +102,8 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
|
||||||
}
|
}
|
||||||
v[k] = p
|
v[k] = p
|
||||||
}
|
}
|
||||||
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v)
|
|
||||||
|
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if rErr, ok := err.(*internal.RetrieveError); ok {
|
if rErr, ok := err.(*internal.RetrieveError); ok {
|
||||||
return nil, (*oauth2.RetrieveError)(rErr)
|
return nil, (*oauth2.RetrieveError)(rErr)
|
||||||
|
|
|
@ -6,11 +6,14 @@ package clientcredentials
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newConf(serverURL string) *Config {
|
func newConf(serverURL string) *Config {
|
||||||
|
@ -111,21 +114,25 @@ func TestTokenRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenRefreshRequest(t *testing.T) {
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
|
internal.ResetAuthCache()
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() == "/somethingelse" {
|
if r.URL.String() == "/somethingelse" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.URL.String() != "/token" {
|
if r.URL.String() != "/token" {
|
||||||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
t.Errorf("Unexpected token refresh request URL: %q", r.URL)
|
||||||
}
|
}
|
||||||
headerContentType := r.Header.Get("Content-Type")
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Content-Type = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" {
|
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
|
||||||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
if string(body) != want {
|
||||||
|
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)
|
||||||
}
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newConf(ts.URL)
|
conf := newConf(ts.URL)
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
var Endpoint = oauth2.Endpoint{
|
var Endpoint = oauth2.Endpoint{
|
||||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
TokenURL: "https://accounts.google.com/o/oauth2/token",
|
TokenURL: "https://accounts.google.com/o/oauth2/token",
|
||||||
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
|
// JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow.
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context/ctxhttp"
|
"golang.org/x/net/context/ctxhttp"
|
||||||
|
@ -90,102 +91,71 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var brokenAuthHeaderProviders = []string{
|
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
||||||
"https://accounts.google.com/",
|
//
|
||||||
"https://api.codeswholesale.com/oauth/token",
|
// Deprecated: this function no longer does anything. Caller code that
|
||||||
"https://api.dropbox.com/",
|
// wants to avoid potential extra HTTP requests made during
|
||||||
"https://api.dropboxapi.com/",
|
// auto-probing of the provider's auth style should set
|
||||||
"https://api.instagram.com/",
|
// Endpoint.AuthStyle.
|
||||||
"https://api.netatmo.net/",
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
||||||
"https://api.odnoklassniki.ru/",
|
|
||||||
"https://api.pushbullet.com/",
|
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
||||||
"https://api.soundcloud.com/",
|
type AuthStyle int
|
||||||
"https://api.twitch.tv/",
|
|
||||||
"https://id.twitch.tv/",
|
const (
|
||||||
"https://app.box.com/",
|
AuthStyleUnknown AuthStyle = 0
|
||||||
"https://api.box.com/",
|
AuthStyleInParams AuthStyle = 1
|
||||||
"https://connect.stripe.com/",
|
AuthStyleInHeader AuthStyle = 2
|
||||||
"https://login.mailchimp.com/",
|
)
|
||||||
"https://login.microsoftonline.com/",
|
|
||||||
"https://login.salesforce.com/",
|
// authStyleCache is the set of tokenURLs we've successfully used via
|
||||||
"https://login.windows.net",
|
// RetrieveToken and which style auth we ended up using.
|
||||||
"https://login.live.com/",
|
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||||
"https://login.live-int.com/",
|
// the set of OAuth2 servers a program contacts over time is fixed and
|
||||||
"https://oauth.sandbox.trainingpeaks.com/",
|
// small.
|
||||||
"https://oauth.trainingpeaks.com/",
|
var authStyleCache struct {
|
||||||
"https://oauth.vk.com/",
|
sync.Mutex
|
||||||
"https://openapi.baidu.com/",
|
m map[string]AuthStyle // keyed by tokenURL
|
||||||
"https://slack.com/",
|
|
||||||
"https://test-sandbox.auth.corp.google.com",
|
|
||||||
"https://test.salesforce.com/",
|
|
||||||
"https://user.gini.net/",
|
|
||||||
"https://www.douban.com/",
|
|
||||||
"https://www.googleapis.com/",
|
|
||||||
"https://www.linkedin.com/",
|
|
||||||
"https://www.strava.com/oauth/",
|
|
||||||
"https://www.wunderlist.com/oauth/",
|
|
||||||
"https://api.patreon.com/",
|
|
||||||
"https://sandbox.codeswholesale.com/oauth/token",
|
|
||||||
"https://api.sipgate.com/v1/authorization/oauth",
|
|
||||||
"https://api.medium.com/v1/tokens",
|
|
||||||
"https://log.finalsurge.com/oauth/token",
|
|
||||||
"https://multisport.todaysplan.com.au/rest/oauth/access_token",
|
|
||||||
"https://whats.todaysplan.com.au/rest/oauth/access_token",
|
|
||||||
"https://stackoverflow.com/oauth/access_token",
|
|
||||||
"https://account.health.nokia.com",
|
|
||||||
"https://accounts.zoho.com",
|
|
||||||
"https://gitter.im/login/oauth/token",
|
|
||||||
"https://openid-connect.onelogin.com/oidc",
|
|
||||||
"https://api.dailymotion.com/oauth/token",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
|
// ResetAuthCache resets the global authentication style cache used
|
||||||
var brokenAuthHeaderDomains = []string{
|
// for AuthStyleUnknown token requests.
|
||||||
".auth0.com",
|
func ResetAuthCache() {
|
||||||
".force.com",
|
authStyleCache.Lock()
|
||||||
".myshopify.com",
|
defer authStyleCache.Unlock()
|
||||||
".okta.com",
|
authStyleCache.m = nil
|
||||||
".oktapreview.com",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||||
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
|
// when calling RetrieveToken and whether we have ever done so.
|
||||||
|
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
||||||
|
authStyleCache.Lock()
|
||||||
|
defer authStyleCache.Unlock()
|
||||||
|
style, ok = authStyleCache.m[tokenURL]
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||||
// implements the OAuth2 spec correctly
|
func setAuthStyle(tokenURL string, v AuthStyle) {
|
||||||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
authStyleCache.Lock()
|
||||||
// In summary:
|
defer authStyleCache.Unlock()
|
||||||
// - Reddit only accepts client secret in the Authorization header
|
if authStyleCache.m == nil {
|
||||||
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
authStyleCache.m = make(map[string]AuthStyle)
|
||||||
// - Google only accepts URL param (not spec compliant?), not Auth header
|
|
||||||
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
|
||||||
func providerAuthHeaderWorks(tokenURL string) bool {
|
|
||||||
for _, s := range brokenAuthHeaderProviders {
|
|
||||||
if strings.HasPrefix(tokenURL, s) {
|
|
||||||
// Some sites fail to implement the OAuth2 spec fully.
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
authStyleCache.m[tokenURL] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
if u, err := url.Parse(tokenURL); err == nil {
|
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||||
for _, s := range brokenAuthHeaderDomains {
|
// from tokenURL using the provided clientID, clientSecret, and POST
|
||||||
if strings.HasSuffix(u.Host, s) {
|
// body parameters.
|
||||||
return false
|
//
|
||||||
}
|
// inParams is whether the clientID & clientSecret should be encoded
|
||||||
}
|
// as the POST body. An 'inParams' value of true means to send it in
|
||||||
}
|
// the POST body (along with any values in v); false means to send it
|
||||||
|
// in the Authorization header.
|
||||||
// Assume the provider implements the spec properly
|
func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) {
|
||||||
// otherwise. We can add more exceptions as they're
|
if authStyle == AuthStyleInParams {
|
||||||
// discovered. We will _not_ be adding configurable hooks
|
v = cloneURLValues(v)
|
||||||
// to this package to let users select server bugs.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
|
|
||||||
bustedAuth := !providerAuthHeaderWorks(tokenURL)
|
|
||||||
if bustedAuth {
|
|
||||||
if clientID != "" {
|
if clientID != "" {
|
||||||
v.Set("client_id", clientID)
|
v.Set("client_id", clientID)
|
||||||
}
|
}
|
||||||
|
@ -198,15 +168,70 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
if !bustedAuth {
|
if authStyle == AuthStyleInHeader {
|
||||||
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
|
req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
|
||||||
}
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneURLValues(v url.Values) url.Values {
|
||||||
|
v2 := make(url.Values, len(v))
|
||||||
|
for k, vv := range v {
|
||||||
|
v2[k] = append([]string(nil), vv...)
|
||||||
|
}
|
||||||
|
return v2
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
|
||||||
|
needsAuthStyleProbe := authStyle == 0
|
||||||
|
if needsAuthStyleProbe {
|
||||||
|
if style, ok := lookupAuthStyle(tokenURL); ok {
|
||||||
|
authStyle = style
|
||||||
|
needsAuthStyleProbe = false
|
||||||
|
} else {
|
||||||
|
authStyle = AuthStyleInHeader // the first way we'll try
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token, err := doTokenRoundTrip(ctx, req)
|
||||||
|
if err != nil && needsAuthStyleProbe {
|
||||||
|
// If we get an error, assume the server wants the
|
||||||
|
// clientID & clientSecret in a different form.
|
||||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||||
|
// In summary:
|
||||||
|
// - Reddit only accepts client secret in the Authorization header
|
||||||
|
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||||
|
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||||
|
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
||||||
|
//
|
||||||
|
// We used to maintain a big table in this code of all the sites and which way
|
||||||
|
// they went, but maintaining it didn't scale & got annoying.
|
||||||
|
// So just try both ways.
|
||||||
|
authStyle = AuthStyleInParams // the second way we'll try
|
||||||
|
req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
|
||||||
|
token, err = doTokenRoundTrip(ctx, req)
|
||||||
|
}
|
||||||
|
if needsAuthStyleProbe && err == nil {
|
||||||
|
setAuthStyle(tokenURL, authStyle)
|
||||||
|
}
|
||||||
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
|
// if this was a token refreshing request.
|
||||||
|
if token != nil && token.RefreshToken == "" {
|
||||||
|
token.RefreshToken = v.Get("refresh_token")
|
||||||
|
}
|
||||||
|
return token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
|
||||||
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
|
r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||||
|
r.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -256,13 +281,8 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||||
}
|
}
|
||||||
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
||||||
}
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
|
||||||
// if this was a token refreshing request.
|
|
||||||
if token.RefreshToken == "" {
|
|
||||||
token.RefreshToken = v.Get("refresh_token")
|
|
||||||
}
|
|
||||||
if token.AccessToken == "" {
|
if token.AccessToken == "" {
|
||||||
return token, errors.New("oauth2: server response missing access_token")
|
return nil, errors.New("oauth2: server response missing access_token")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -14,17 +13,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRegisterBrokenAuthHeaderProvider(t *testing.T) {
|
func TestRetrieveToken_InParams(t *testing.T) {
|
||||||
RegisterBrokenAuthHeaderProvider("https://aaa.com/")
|
ResetAuthCache()
|
||||||
tokenURL := "https://aaa.com/token"
|
|
||||||
if providerAuthHeaderWorks(tokenURL) {
|
|
||||||
t.Errorf("got %q as unbroken; want broken", tokenURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRetrieveTokenBustedNoSecret(t *testing.T) {
|
|
||||||
const clientID = "client-id"
|
const clientID = "client-id"
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if got, want := r.FormValue("client_id"), clientID; got != want {
|
if got, want := r.FormValue("client_id"), clientID; got != want {
|
||||||
t.Errorf("client_id = %q; want %q", got, want)
|
t.Errorf("client_id = %q; want %q", got, want)
|
||||||
|
@ -36,52 +27,14 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) {
|
||||||
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
|
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
|
||||||
RegisterBrokenAuthHeaderProvider(ts.URL)
|
|
||||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RetrieveToken = %v; want no error", err)
|
t.Errorf("RetrieveToken = %v; want no error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_providerAuthHeaderWorks(t *testing.T) {
|
|
||||||
for _, p := range brokenAuthHeaderProviders {
|
|
||||||
if providerAuthHeaderWorks(p) {
|
|
||||||
t.Errorf("got %q as unbroken; want broken", p)
|
|
||||||
}
|
|
||||||
p := fmt.Sprintf("%ssomesuffix", p)
|
|
||||||
if providerAuthHeaderWorks(p) {
|
|
||||||
t.Errorf("got %q as unbroken; want broken", p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p := "https://api.not-in-the-list-example.com/"
|
|
||||||
if !providerAuthHeaderWorks(p) {
|
|
||||||
t.Errorf("got %q as unbroken; want broken", p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProviderAuthHeaderWorksDomain(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
tokenURL string
|
|
||||||
wantWorks bool
|
|
||||||
}{
|
|
||||||
{"https://dev-12345.okta.com/token-url", false},
|
|
||||||
{"https://dev-12345.oktapreview.com/token-url", false},
|
|
||||||
{"https://dev-12345.okta.org/token-url", true},
|
|
||||||
{"https://foo.bar.force.com/token-url", false},
|
|
||||||
{"https://foo.force.com/token-url", false},
|
|
||||||
{"https://force.com/token-url", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
got := providerAuthHeaderWorks(test.tokenURL)
|
|
||||||
if got != test.wantWorks {
|
|
||||||
t.Errorf("providerAuthHeaderWorks(%q) = %v; want %v", test.tokenURL, got, test.wantWorks)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRetrieveTokenWithContexts(t *testing.T) {
|
func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||||
|
ResetAuthCache()
|
||||||
const clientID = "client-id"
|
const clientID = "client-id"
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -90,7 +43,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{})
|
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
|
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
|
||||||
}
|
}
|
||||||
|
@ -103,7 +56,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{})
|
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
|
||||||
close(retrieved)
|
close(retrieved)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
|
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")
|
||||||
|
|
|
@ -13,4 +13,5 @@ import (
|
||||||
var Endpoint = oauth2.Endpoint{
|
var Endpoint = oauth2.Endpoint{
|
||||||
AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
|
AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
|
||||||
TokenURL: "https://www.linkedin.com/oauth/v2/accessToken",
|
TokenURL: "https://www.linkedin.com/oauth/v2/accessToken",
|
||||||
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
}
|
}
|
||||||
|
|
45
oauth2.go
45
oauth2.go
|
@ -26,17 +26,13 @@ import (
|
||||||
// Deprecated: Use context.Background() or context.TODO() instead.
|
// Deprecated: Use context.Background() or context.TODO() instead.
|
||||||
var NoContext = context.TODO()
|
var NoContext = context.TODO()
|
||||||
|
|
||||||
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
|
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
||||||
// identified by the tokenURL prefix as an OAuth2 implementation
|
//
|
||||||
// which doesn't support the HTTP Basic authentication
|
// Deprecated: this function no longer does anything. Caller code that
|
||||||
// scheme to authenticate with the authorization server.
|
// wants to avoid potential extra HTTP requests made during
|
||||||
// Once a server is registered, credentials (client_id and client_secret)
|
// auto-probing of the provider's auth style should set
|
||||||
// will be passed as parameters in the request body rather than being present
|
// Endpoint.AuthStyle.
|
||||||
// in the Authorization header.
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
||||||
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
|
||||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
|
||||||
internal.RegisterBrokenAuthHeaderProvider(tokenURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config describes a typical 3-legged OAuth2 flow, with both the
|
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||||
// client application information and the server's endpoint URLs.
|
// client application information and the server's endpoint URLs.
|
||||||
|
@ -71,13 +67,38 @@ type TokenSource interface {
|
||||||
Token() (*Token, error)
|
Token() (*Token, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
// Endpoint represents an OAuth 2.0 provider's authorization and token
|
||||||
// endpoint URLs.
|
// endpoint URLs.
|
||||||
type Endpoint struct {
|
type Endpoint struct {
|
||||||
AuthURL string
|
AuthURL string
|
||||||
TokenURL string
|
TokenURL string
|
||||||
|
|
||||||
|
// AuthStyle optionally specifies how the endpoint wants the
|
||||||
|
// client ID & client secret sent. The zero value means to
|
||||||
|
// auto-detect.
|
||||||
|
AuthStyle AuthStyle
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthStyle represents how requests for tokens are authenticated
|
||||||
|
// to the server.
|
||||||
|
type AuthStyle int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// AuthStyleAutoDetect means to auto-detect which authentication
|
||||||
|
// style the provider wants by trying both ways and caching
|
||||||
|
// the successful way for the future.
|
||||||
|
AuthStyleAutoDetect AuthStyle = 0
|
||||||
|
|
||||||
|
// AuthStyleInParams sends the "client_id" and "client_secret"
|
||||||
|
// in the POST body as application/x-www-form-urlencoded parameters.
|
||||||
|
AuthStyleInParams AuthStyle = 1
|
||||||
|
|
||||||
|
// AuthStyleInHeader sends the client_id and client_password
|
||||||
|
// using HTTP Basic Authorization. This is an optional style
|
||||||
|
// described in the OAuth2 RFC 6749 section 2.3.1.
|
||||||
|
AuthStyleInHeader AuthStyle = 2
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// AccessTypeOnline and AccessTypeOffline are options passed
|
// AccessTypeOnline and AccessTypeOffline are options passed
|
||||||
// to the Options.AuthCodeURL method. They modify the
|
// to the Options.AuthCodeURL method. They modify the
|
||||||
|
|
|
@ -8,12 +8,15 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockTransport struct {
|
type mockTransport struct {
|
||||||
|
@ -93,22 +96,22 @@ func TestURLUnsafeClientConfig(t *testing.T) {
|
||||||
func TestExchangeRequest(t *testing.T) {
|
func TestExchangeRequest(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() != "/token" {
|
if r.URL.String() != "/token" {
|
||||||
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
|
t.Errorf("Unexpected exchange request URL %q", r.URL)
|
||||||
}
|
}
|
||||||
headerAuth := r.Header.Get("Authorization")
|
headerAuth := r.Header.Get("Authorization")
|
||||||
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
|
if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want {
|
||||||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
|
t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want)
|
||||||
}
|
}
|
||||||
headerContentType := r.Header.Get("Content-Type")
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %s.", err)
|
t.Errorf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
|
if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
|
||||||
t.Errorf("Unexpected exchange payload, %v is found.", string(body))
|
t.Errorf("Unexpected exchange payload; got %q", body)
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
|
||||||
|
@ -343,11 +346,12 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
|
internal.ResetAuthCache()
|
||||||
tr := &mockTransport{
|
tr := &mockTransport{
|
||||||
rt: func(r *http.Request) (w *http.Response, err error) {
|
rt: func(r *http.Request) (w *http.Response, err error) {
|
||||||
headerAuth := r.Header.Get("Authorization")
|
headerAuth := r.Header.Get("Authorization")
|
||||||
if headerAuth != "" {
|
if headerAuth != "" {
|
||||||
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
|
t.Errorf("Unexpected authorization header %q", headerAuth)
|
||||||
}
|
}
|
||||||
return nil, errors.New("no response")
|
return nil, errors.New("no response")
|
||||||
},
|
},
|
||||||
|
@ -358,6 +362,7 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
Endpoint: Endpoint{
|
Endpoint: Endpoint{
|
||||||
AuthURL: "https://accounts.google.com/auth",
|
AuthURL: "https://accounts.google.com/auth",
|
||||||
TokenURL: "https://accounts.google.com/token",
|
TokenURL: "https://accounts.google.com/token",
|
||||||
|
AuthStyle: AuthStyleInParams,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,21 +418,24 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenRefreshRequest(t *testing.T) {
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
|
internal.ResetAuthCache()
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() == "/somethingelse" {
|
if r.URL.String() == "/somethingelse" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.URL.String() != "/token" {
|
if r.URL.String() != "/token" {
|
||||||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
t.Errorf("Unexpected token refresh request URL %q", r.URL)
|
||||||
}
|
}
|
||||||
headerContentType := r.Header.Get("Content-Type")
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
||||||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
t.Errorf("Unexpected refresh token payload %q", body)
|
||||||
}
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newConf(ts.URL)
|
conf := newConf(ts.URL)
|
||||||
|
@ -478,7 +486,7 @@ func TestTokenRetrieveError(t *testing.T) {
|
||||||
}
|
}
|
||||||
_, ok := err.(*RetrieveError)
|
_, ok := err.(*RetrieveError)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("got %T error, expected *RetrieveError", err)
|
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
|
||||||
}
|
}
|
||||||
// Test error string for backwards compatibility
|
// Test error string for backwards compatibility
|
||||||
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
|
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
|
||||||
|
|
2
token.go
2
token.go
|
@ -154,7 +154,7 @@ func tokenFromInternal(t *internal.Token) *Token {
|
||||||
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
||||||
// with an error..
|
// with an error..
|
||||||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
||||||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if rErr, ok := err.(*internal.RetrieveError); ok {
|
if rErr, ok := err.(*internal.RetrieveError); ok {
|
||||||
return nil, (*RetrieveError)(rErr)
|
return nil, (*RetrieveError)(rErr)
|
||||||
|
|
Loading…
Reference in New Issue