diff --git a/client_appengine.go b/client_appengine.go index 10aaf91..4a554cb 100644 --- a/client_appengine.go +++ b/client_appengine.go @@ -12,11 +12,12 @@ import ( "net/http" "golang.org/x/net/context" + "golang.org/x/oauth2/internal" "google.golang.org/appengine/urlfetch" ) func init() { - registerContextClientFunc(contextClientAppEngine) + internal.RegisterContextClientFunc(contextClientAppEngine) } func contextClientAppEngine(ctx context.Context) (*http.Client, error) { diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go new file mode 100644 index 0000000..baebced --- /dev/null +++ b/clientcredentials/clientcredentials.go @@ -0,0 +1,112 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package clientcredentials implements the OAuth2.0 "client credentials" token flow, +// also known as the "two-legged OAuth 2.0". +// +// This should be used when the client is acting on its own behalf or when the client +// is the resource owner. It may also be used when requesting access to protected +// resources based on an authorization previously arranged with the authorization +// server. +// +// See http://tools.ietf.org/html/draft-ietf-oauth-v2-31#section-4.4 +package clientcredentials // import "golang.org/x/oauth2/clientcredentials" + +import ( + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + "golang.org/x/oauth2/internal" +) + +// tokenFromInternal maps an *internal.Token struct into +// an *oauth2.Token struct. +func tokenFromInternal(t *internal.Token) *oauth2.Token { + if t == nil { + return nil + } + tk := &oauth2.Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + Expiry: t.Expiry, + } + return tk.WithExtra(t.Raw) +} + +// retrieveToken takes a *Config and uses that to retrieve an *internal.Token. +// This token is then mapped from *internal.Token into an *oauth2.Token which is +// returned along with an error. +func retrieveToken(ctx context.Context, c *Config, v url.Values) (*oauth2.Token, error) { + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.TokenURL, v) + if err != nil { + return nil, err + } + return tokenFromInternal(tk), nil +} + +// Client Credentials Config describes a 2-legged OAuth2 flow, with both the +// client application information and the server's endpoint URLs. +type Config struct { + // ClientID is the application's ID. + ClientID string + + // ClientSecret is the application's secret. + ClientSecret string + + // TokenURL is the resource server's token endpoint + // URL. This is a constant specific to each server. + TokenURL string + + // Scope specifies optional requested permissions. + Scopes []string +} + +// Token uses client credentials to retreive a token. +// The HTTP client to use is derived from the context. +// If nil, http.DefaultClient is used. +func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { + return retrieveToken(ctx, c, url.Values{ + "grant_type": {"client_credentials"}, + "scope": internal.CondVal(strings.Join(c.Scopes, " ")), + }) +} + +// Client returns an HTTP client using the provided token. +// The token will auto-refresh as necessary. The underlying +// HTTP transport will be obtained using the provided context. +// The returned client and its Transport should not be modified. +func (c *Config) Client(ctx context.Context) *http.Client { + return oauth2.NewClient(ctx, c.TokenSource(ctx)) +} + +// TokenSource returns a TokenSource that returns t until t expires, +// automatically refreshing it as necessary using the provided context and the +// client ID and client secret. +// +// Most users will use Config.Client instead. +func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { + source := &tokenSource{ + ctx: ctx, + conf: c, + } + return oauth2.ReuseTokenSource(nil, source) +} + +type tokenSource struct { + ctx context.Context + conf *Config +} + +// Token refreshes the token by using a new client credentials request. +// tokens received this way do not include a refresh token +func (c *tokenSource) Token() (*oauth2.Token, error) { + return retrieveToken(c.ctx, c.conf, url.Values{ + "grant_type": {"client_credentials"}, + "scope": internal.CondVal(strings.Join(c.conf.Scopes, " ")), + }) +} diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go new file mode 100644 index 0000000..ab319e0 --- /dev/null +++ b/clientcredentials/clientcredentials_test.go @@ -0,0 +1,96 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package clientcredentials + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/oauth2" +) + +func newConf(url string) *Config { + return &Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + Scopes: []string{"scope1", "scope2"}, + TokenURL: url + "/token", + } +} + +type mockTransport struct { + rt func(req *http.Request) (resp *http.Response, err error) +} + +func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return t.rt(req) +} + +func TestTokenRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token") + } + headerAuth := r.Header.Get("Authorization") + if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { + t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + } + if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type header = %q; want %q", got, want) + } + body, err := ioutil.ReadAll(r.Body) + if err != nil { + r.Body.Close() + } + if err != nil { + t.Errorf("failed reading request body: %s.", err) + } + if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { + t.Errorf("payload = %q; want %q", string(body), "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2") + } + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + })) + defer ts.Close() + conf := newConf(ts.URL) + tok, err := conf.Token(oauth2.NoContext) + if err != nil { + t.Error(err) + } + if !tok.Valid() { + t.Fatalf("token invalid. got: %#v", tok) + } + if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") + } + if tok.TokenType != "bearer" { + t.Errorf("token type = %q; want %q", tok.TokenType, "bearer") + } +} + +func TestTokenRefreshRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() == "/somethingelse" { + return + } + if r.URL.String() != "/token" { + t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, _ := ioutil.ReadAll(r.Body) + if string(body) != "client_id=CLIENT_ID&grant_type=client_credentials&scope=scope1+scope2" { + t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + } + })) + defer ts.Close() + conf := newConf(ts.URL) + c := conf.Client(oauth2.NoContext) + c.Get(ts.URL + "/somethingelse") +} diff --git a/internal/oauth2.go b/internal/oauth2.go index 37571a1..dc8ebfc 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -67,3 +67,10 @@ func ParseINI(ini io.Reader) (map[string]map[string]string, error) { } return result, nil } + +func CondVal(v string) []string { + if v == "" { + return nil + } + return []string{v} +} diff --git a/internal/token.go b/internal/token.go new file mode 100644 index 0000000..8bb467e --- /dev/null +++ b/internal/token.go @@ -0,0 +1,210 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package internal contains support packages for oauth2 package. +package internal + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "mime" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "golang.org/x/net/context" +) + +// Token represents the crendentials used to authorize +// the requests to access protected resources on the OAuth 2.0 +// provider's backend. +// +// This type is a mirror of oauth2.Token and exists to break +// an otherwise-circular dependency. Other internal packages +// should convert this Token into an oauth2.Token before use. +type Token struct { + // AccessToken is the token that authorizes and authenticates + // the requests. + AccessToken string + + // TokenType is the type of token. + // The Type method returns either this or "Bearer", the default. + TokenType string + + // RefreshToken is a token that's used by the application + // (as opposed to the user) to refresh the access token + // if it expires. + RefreshToken string + + // Expiry is the optional expiration time of the access token. + // + // If zero, TokenSource implementations will reuse the same + // token forever and RefreshToken or equivalent + // mechanisms for that TokenSource will not be used. + Expiry time.Time + + // Raw optionally contains extra metadata from the server + // when updating a token. + Raw interface{} +} + +// tokenJSON is the struct representing the HTTP response from OAuth2 +// providers returning a token in JSON form. +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number + Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in +} + +func (e *tokenJSON) expiry() (t time.Time) { + if v := e.ExpiresIn; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + if v := e.Expires; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + return +} + +type expirationTime int32 + +func (e *expirationTime) UnmarshalJSON(b []byte) error { + var n json.Number + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + i, err := n.Int64() + if err != nil { + return err + } + *e = expirationTime(i) + return nil +} + +var brokenAuthHeaderProviders = []string{ + "https://accounts.google.com/", + "https://www.googleapis.com/", + "https://github.com/", + "https://api.instagram.com/", + "https://www.douban.com/", + "https://api.dropbox.com/", + "https://api.soundcloud.com/", + "https://www.linkedin.com/", + "https://api.twitch.tv/", + "https://oauth.vk.com/", + "https://api.odnoklassniki.ru/", + "https://connect.stripe.com/", + "https://api.pushbullet.com/", + "https://oauth.sandbox.trainingpeaks.com/", + "https://oauth.trainingpeaks.com/", + "https://www.strava.com/oauth/", +} + +// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL +// implements the OAuth2 spec correctly +// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. +// In summary: +// - Reddit only accepts client secret in the Authorization header +// - Dropbox accepts either it in URL param or Auth header, but not both. +// - Google only accepts URL param (not spec compliant?), not Auth header +// - Stripe only accepts client secret in Auth header with Bearer method, not Basic +func providerAuthHeaderWorks(tokenURL string) bool { + for _, s := range brokenAuthHeaderProviders { + if strings.HasPrefix(tokenURL, s) { + // Some sites fail to implement the OAuth2 spec fully. + return false + } + } + + // Assume the provider implements the spec properly + // otherwise. We can add more exceptions as they're + // discovered. We will _not_ be adding configurable hooks + // to this package to let users select server bugs. + return true +} + +func RetrieveToken(ctx context.Context, ClientID, ClientSecret, TokenURL string, v url.Values) (*Token, error) { + hc, err := ContextClient(ctx) + if err != nil { + return nil, err + } + v.Set("client_id", ClientID) + bustedAuth := !providerAuthHeaderWorks(TokenURL) + if bustedAuth && ClientSecret != "" { + v.Set("client_secret", ClientSecret) + } + req, err := http.NewRequest("POST", TokenURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if !bustedAuth { + req.SetBasicAuth(ClientID, ClientSecret) + } + r, err := hc.Do(req) + if err != nil { + return nil, err + } + defer r.Body.Close() + body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + if code := r.StatusCode; code < 200 || code > 299 { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) + } + + var token *Token + content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch content { + case "application/x-www-form-urlencoded", "text/plain": + vals, err := url.ParseQuery(string(body)) + if err != nil { + return nil, err + } + token = &Token{ + AccessToken: vals.Get("access_token"), + TokenType: vals.Get("token_type"), + RefreshToken: vals.Get("refresh_token"), + Raw: vals, + } + e := vals.Get("expires_in") + if e == "" { + // TODO(jbd): Facebook's OAuth2 implementation is broken and + // returns expires_in field in expires. Remove the fallback to expires, + // when Facebook fixes their implementation. + e = vals.Get("expires") + } + expires, _ := strconv.Atoi(e) + if expires != 0 { + token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) + } + default: + var tj tokenJSON + if err = json.Unmarshal(body, &tj); err != nil { + return nil, err + } + token = &Token{ + AccessToken: tj.AccessToken, + TokenType: tj.TokenType, + RefreshToken: tj.RefreshToken, + Expiry: tj.expiry(), + Raw: make(map[string]interface{}), + } + json.Unmarshal(body, &token.Raw) // no error checks for optional fields + } + // Don't overwrite `RefreshToken` with an empty value + // if this was a token refreshing request. + if token.RefreshToken == "" { + token.RefreshToken = v.Get("refresh_token") + } + return token, nil +} diff --git a/internal/token_test.go b/internal/token_test.go new file mode 100644 index 0000000..864f6fa --- /dev/null +++ b/internal/token_test.go @@ -0,0 +1,28 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package internal contains support packages for oauth2 package. +package internal + +import ( + "fmt" + "testing" +) + +func Test_providerAuthHeaderWorks(t *testing.T) { + for _, p := range brokenAuthHeaderProviders { + if providerAuthHeaderWorks(p) { + t.Errorf("URL: %s not found in list", p) + } + p := fmt.Sprintf("%ssomesuffix", p) + if providerAuthHeaderWorks(p) { + t.Errorf("URL: %s not found in list", p) + } + } + p := "https://api.not-in-the-list-example.com/" + if !providerAuthHeaderWorks(p) { + t.Errorf("URL: %s found in list", p) + } + +} diff --git a/internal/transport.go b/internal/transport.go new file mode 100644 index 0000000..521e7b4 --- /dev/null +++ b/internal/transport.go @@ -0,0 +1,67 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package internal contains support packages for oauth2 package. +package internal + +import ( + "net/http" + + "golang.org/x/net/context" +) + +// HTTPClient is the context key to use with golang.org/x/net/context's +// WithValue function to associate an *http.Client value with a context. +var HTTPClient ContextKey + +// ContextKey is just an empty struct. It exists so HTTPClient can be +// an immutable public variable with a unique type. It's immutable +// because nobody else can create a ContextKey, being unexported. +type ContextKey struct{} + +// ContextClientFunc is a func which tries to return an *http.Client +// given a Context value. If it returns an error, the search stops +// with that error. If it returns (nil, nil), the search continues +// down the list of registered funcs. +type ContextClientFunc func(context.Context) (*http.Client, error) + +var contextClientFuncs []ContextClientFunc + +func RegisterContextClientFunc(fn ContextClientFunc) { + contextClientFuncs = append(contextClientFuncs, fn) +} + +func ContextClient(ctx context.Context) (*http.Client, error) { + for _, fn := range contextClientFuncs { + c, err := fn(ctx) + if err != nil { + return nil, err + } + if c != nil { + return c, nil + } + } + if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { + return hc, nil + } + return http.DefaultClient, nil +} + +func ContextTransport(ctx context.Context) http.RoundTripper { + hc, err := ContextClient(ctx) + // This is a rare error case (somebody using nil on App Engine). + if err != nil { + return ErrorTransport{err} + } + return hc.Transport +} + +// ErrorTransport returns the specified error on RoundTrip. +// This RoundTripper should be used in rare error cases where +// error handling can be postponed to response handling time. +type ErrorTransport struct{ Err error } + +func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) { + return nil, t.Err +} diff --git a/oauth2.go b/oauth2.go index 98b3c2c..b879351 100644 --- a/oauth2.go +++ b/oauth2.go @@ -9,20 +9,14 @@ package oauth2 // import "golang.org/x/oauth2" import ( "bytes" - "encoding/json" "errors" - "fmt" - "io" - "io/ioutil" - "mime" "net/http" "net/url" - "strconv" "strings" "sync" - "time" "golang.org/x/net/context" + "golang.org/x/oauth2/internal" ) // NoContext is the default context you should supply if not using @@ -119,9 +113,9 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { v := url.Values{ "response_type": {"code"}, "client_id": {c.ClientID}, - "redirect_uri": condVal(c.RedirectURL), - "scope": condVal(strings.Join(c.Scopes, " ")), - "state": condVal(state), + "redirect_uri": internal.CondVal(c.RedirectURL), + "scope": internal.CondVal(strings.Join(c.Scopes, " ")), + "state": internal.CondVal(state), } for _, opt := range opts { opt.setValue(v) @@ -151,7 +145,7 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor "grant_type": {"password"}, "username": {username}, "password": {password}, - "scope": condVal(strings.Join(c.Scopes, " ")), + "scope": internal.CondVal(strings.Join(c.Scopes, " ")), }) } @@ -169,51 +163,11 @@ func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { return retrieveToken(ctx, c, url.Values{ "grant_type": {"authorization_code"}, "code": {code}, - "redirect_uri": condVal(c.RedirectURL), - "scope": condVal(strings.Join(c.Scopes, " ")), + "redirect_uri": internal.CondVal(c.RedirectURL), + "scope": internal.CondVal(strings.Join(c.Scopes, " ")), }) } -// contextClientFunc is a func which tries to return an *http.Client -// given a Context value. If it returns an error, the search stops -// with that error. If it returns (nil, nil), the search continues -// down the list of registered funcs. -type contextClientFunc func(context.Context) (*http.Client, error) - -var contextClientFuncs []contextClientFunc - -func registerContextClientFunc(fn contextClientFunc) { - contextClientFuncs = append(contextClientFuncs, fn) -} - -func contextClient(ctx context.Context) (*http.Client, error) { - for _, fn := range contextClientFuncs { - c, err := fn(ctx) - if err != nil { - return nil, err - } - if c != nil { - return c, nil - } - } - if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { - return hc, nil - } - return http.DefaultClient, nil -} - -func contextTransport(ctx context.Context) http.RoundTripper { - hc, err := contextClient(ctx) - if err != nil { - // This is a rare error case (somebody using nil on App Engine), - // so I'd rather not everybody do an error check on this Client - // method. They can get the error that they're doing it wrong - // later, at client.Get/PostForm time. - return errorTransport{err} - } - return hc.Transport -} - // Client returns an HTTP client using the provided token. // The token will auto-refresh as necessary. The underlying // HTTP transport will be obtained using the provided context. @@ -299,177 +253,9 @@ func (s *reuseTokenSource) Token() (*Token, error) { return t, nil } -func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { - hc, err := contextClient(ctx) - if err != nil { - return nil, err - } - v.Set("client_id", c.ClientID) - bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL) - if bustedAuth && c.ClientSecret != "" { - v.Set("client_secret", c.ClientSecret) - } - req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode())) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if !bustedAuth { - req.SetBasicAuth(c.ClientID, c.ClientSecret) - } - r, err := hc.Do(req) - if err != nil { - return nil, err - } - defer r.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - if code := r.StatusCode; code < 200 || code > 299 { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) - } - - var token *Token - content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) - switch content { - case "application/x-www-form-urlencoded", "text/plain": - vals, err := url.ParseQuery(string(body)) - if err != nil { - return nil, err - } - token = &Token{ - AccessToken: vals.Get("access_token"), - TokenType: vals.Get("token_type"), - RefreshToken: vals.Get("refresh_token"), - raw: vals, - } - e := vals.Get("expires_in") - if e == "" { - // TODO(jbd): Facebook's OAuth2 implementation is broken and - // returns expires_in field in expires. Remove the fallback to expires, - // when Facebook fixes their implementation. - e = vals.Get("expires") - } - expires, _ := strconv.Atoi(e) - if expires != 0 { - token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) - } - default: - var tj tokenJSON - if err = json.Unmarshal(body, &tj); err != nil { - return nil, err - } - token = &Token{ - AccessToken: tj.AccessToken, - TokenType: tj.TokenType, - RefreshToken: tj.RefreshToken, - Expiry: tj.expiry(), - raw: make(map[string]interface{}), - } - json.Unmarshal(body, &token.raw) // no error checks for optional fields - } - // Don't overwrite `RefreshToken` with an empty value - // if this was a token refreshing request. - if token.RefreshToken == "" { - token.RefreshToken = v.Get("refresh_token") - } - return token, nil -} - -// tokenJSON is the struct representing the HTTP response from OAuth2 -// providers returning a token in JSON form. -type tokenJSON struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - RefreshToken string `json:"refresh_token"` - ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number - Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in -} - -func (e *tokenJSON) expiry() (t time.Time) { - if v := e.ExpiresIn; v != 0 { - return time.Now().Add(time.Duration(v) * time.Second) - } - if v := e.Expires; v != 0 { - return time.Now().Add(time.Duration(v) * time.Second) - } - return -} - -type expirationTime int32 - -func (e *expirationTime) UnmarshalJSON(b []byte) error { - var n json.Number - err := json.Unmarshal(b, &n) - if err != nil { - return err - } - i, err := n.Int64() - if err != nil { - return err - } - *e = expirationTime(i) - return nil -} - -func condVal(v string) []string { - if v == "" { - return nil - } - return []string{v} -} - -var brokenAuthHeaderProviders = []string{ - "https://accounts.google.com/", - "https://www.googleapis.com/", - "https://github.com/", - "https://api.instagram.com/", - "https://www.douban.com/", - "https://api.dropbox.com/", - "https://api.soundcloud.com/", - "https://www.linkedin.com/", - "https://api.twitch.tv/", - "https://oauth.vk.com/", - "https://api.odnoklassniki.ru/", - "https://connect.stripe.com/", - "https://api.pushbullet.com/", - "https://oauth.sandbox.trainingpeaks.com/", - "https://oauth.trainingpeaks.com/", - "https://www.strava.com/oauth/", -} - -// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL -// implements the OAuth2 spec correctly -// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. -// In summary: -// - Reddit only accepts client secret in the Authorization header -// - Dropbox accepts either it in URL param or Auth header, but not both. -// - Google only accepts URL param (not spec compliant?), not Auth header -// - Stripe only accepts client secret in Auth header with Bearer method, not Basic -func providerAuthHeaderWorks(tokenURL string) bool { - for _, s := range brokenAuthHeaderProviders { - if strings.HasPrefix(tokenURL, s) { - // Some sites fail to implement the OAuth2 spec fully. - return false - } - } - - // Assume the provider implements the spec properly - // otherwise. We can add more exceptions as they're - // discovered. We will _not_ be adding configurable hooks - // to this package to let users select server bugs. - return true -} - // HTTPClient is the context key to use with golang.org/x/net/context's // WithValue function to associate an *http.Client value with a context. -var HTTPClient contextKey - -// contextKey is just an empty struct. It exists so HTTPClient can be -// an immutable public variable with a unique type. It's immutable -// because nobody else can create a contextKey, being unexported. -type contextKey struct{} +var HTTPClient internal.ContextKey // NewClient creates an *http.Client from a Context and TokenSource. // The returned client is not valid beyond the lifetime of the context. @@ -479,15 +265,15 @@ type contextKey struct{} // packages. func NewClient(ctx context.Context, src TokenSource) *http.Client { if src == nil { - c, err := contextClient(ctx) + c, err := internal.ContextClient(ctx) if err != nil { - return &http.Client{Transport: errorTransport{err}} + return &http.Client{Transport: internal.ErrorTransport{err}} } return c } return &http.Client{ Transport: &Transport{ - Base: contextTransport(ctx), + Base: internal.ContextTransport(ctx), Source: ReuseTokenSource(nil, src), }, } diff --git a/oauth2_test.go b/oauth2_test.go index 7b7e3c1..2f7d731 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -420,20 +420,3 @@ func TestConfigClientWithToken(t *testing.T) { t.Error(err) } } - -func Test_providerAuthHeaderWorks(t *testing.T) { - for _, p := range brokenAuthHeaderProviders { - if providerAuthHeaderWorks(p) { - t.Errorf("URL: %s not found in list", p) - } - p := fmt.Sprintf("%ssomesuffix", p) - if providerAuthHeaderWorks(p) { - t.Errorf("URL: %s not found in list", p) - } - } - p := "https://api.not-in-the-list-example.com/" - if !providerAuthHeaderWorks(p) { - t.Errorf("URL: %s found in list", p) - } - -} diff --git a/token.go b/token.go index 9852ceb..252cfc7 100644 --- a/token.go +++ b/token.go @@ -8,6 +8,9 @@ import ( "net/http" "net/url" "time" + + "golang.org/x/net/context" + "golang.org/x/oauth2/internal" ) // expiryDelta determines how earlier a token should be considered @@ -102,3 +105,29 @@ func (t *Token) expired() bool { func (t *Token) Valid() bool { return t != nil && t.AccessToken != "" && !t.expired() } + +// tokenFromInternal maps an *internal.Token struct into +// a *Token struct. +func tokenFromInternal(t *internal.Token) *Token { + if t == nil { + return nil + } + return &Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + Expiry: t.Expiry, + raw: t.Raw, + } +} + +// retrieveToken takes a *Config and uses that to retrieve an *internal.Token. +// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along +// with an error.. +func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) + if err != nil { + return nil, err + } + return tokenFromInternal(tk), nil +} diff --git a/transport.go b/transport.go index 10339a0..90db088 100644 --- a/transport.go +++ b/transport.go @@ -130,9 +130,3 @@ func (r *onEOFReader) runFunc() { r.fn = nil } } - -type errorTransport struct{ err error } - -func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) { - return nil, t.err -}