oauth2: remove oauth2.Context type, simplify App Engine token code

You can now use the "google.golang.org/appengine" packages on both
Managed VMs and App Engine Classic(TM). The newer packages use the
context.Context instead of appengine.Context, so we no longer need the
oauth2.Context type.

Some clients will require code changes, replacing oauth2.Context or
appengine.Context with context.Context (imported from
the repository "golang.org/x/net/context").

Users of classic App Engine must switch to using the new
"google.golang.org/appengine" packages in order to use the oauth2
package.

Fixes #89

Change-Id: Ibaff3117117f9f7c5d1b3048a6e4086f62c18c3b
Reviewed-on: https://go-review.googlesource.com/6075
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Andrew Gerrand 2015-02-26 16:53:51 +11:00
parent a0fac97f6e
commit 96e89befdc
9 changed files with 104 additions and 193 deletions

View File

@ -2,38 +2,23 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build appengine,!appenginevm // +build appengine appenginevm
// App Engine hooks. // App Engine hooks.
package oauth2 package oauth2
import ( import (
"log"
"net/http" "net/http"
"sync"
"appengine" "golang.org/x/net/context"
"appengine/urlfetch" "google.golang.org/appengine/urlfetch"
) )
var warnOnce sync.Once
func init() { func init() {
registerContextClientFunc(contextClientAppEngine) registerContextClientFunc(contextClientAppEngine)
} }
func contextClientAppEngine(ctx Context) (*http.Client, error) { func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
if actx, ok := ctx.(appengine.Context); ok { return urlfetch.Client(ctx), nil
return urlfetch.Client(actx), nil
}
// The user did it wrong. We'll log once (and hope they see it
// in dev_appserver), but stil return (nil, nil) in case some
// other contextClientFunc hook finds a way to proceed.
warnOnce.Do(gaeDoingItWrongHelp)
return nil, nil
}
func gaeDoingItWrongHelp() {
log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.")
} }

View File

@ -2,36 +2,82 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build appengine
package google package google
import ( import (
"sort"
"strings"
"sync"
"time" "time"
"appengine" "golang.org/x/net/context"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// Set at init time by appengine_hook.go. If nil, we're not on App Engine.
var appengineTokenFunc func(c context.Context, scopes ...string) (token string, expiry time.Time, err error)
// AppEngineTokenSource returns a token source that fetches tokens // AppEngineTokenSource returns a token source that fetches tokens
// issued to the current App Engine application's service account. // issued to the current App Engine application's service account.
// If you are implementing a 3-legged OAuth 2.0 flow on App Engine // If you are implementing a 3-legged OAuth 2.0 flow on App Engine
// that involves user accounts, see oauth2.Config instead. // that involves user accounts, see oauth2.Config instead.
// //
// The provided context must have come from appengine.NewContext. // The provided context must have come from appengine.NewContext.
func AppEngineTokenSource(ctx oauth2.Context, scope ...string) oauth2.TokenSource { func AppEngineTokenSource(ctx context.Context, scope ...string) oauth2.TokenSource {
if appengineTokenFunc == nil {
panic("google: AppEngineTokenSource can only be used on App Engine.")
}
scopes := append([]string{}, scope...)
sort.Strings(scopes)
return &appEngineTokenSource{ return &appEngineTokenSource{
ctx: ctx, ctx: ctx,
scopes: scope, scopes: scopes,
fetcherFunc: aeFetcherFunc, key: strings.Join(scopes, " "),
} }
} }
var aeFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) { // aeTokens helps the fetched tokens to be reused until their expiration.
c, ok := ctx.(appengine.Context) var (
if !ok { aeTokensMu sync.Mutex
return "", time.Time{}, errInvalidContext aeTokens = make(map[string]*tokenLock) // key is space-separated scopes
} )
return appengine.AccessToken(c, scope...)
type tokenLock struct {
mu sync.Mutex // guards t; held while fetching or updating t
t *oauth2.Token
}
type appEngineTokenSource struct {
ctx context.Context
scopes []string
key string // to aeTokens map; space-separated scopes
}
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
if appengineTokenFunc == nil {
panic("google: AppEngineTokenSource can only be used on App Engine.")
}
aeTokensMu.Lock()
tok, ok := aeTokens[ts.key]
if !ok {
tok = &tokenLock{}
aeTokens[ts.key] = tok
}
aeTokensMu.Unlock()
tok.mu.Lock()
defer tok.mu.Unlock()
if tok.t.Valid() {
return tok.t, nil
}
access, exp, err := appengineTokenFunc(ts.ctx, ts.scopes...)
if err != nil {
return nil, err
}
tok.t = &oauth2.Token{
AccessToken: access,
Expiry: exp,
}
return tok.t, nil
} }

13
google/appengine_hook.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine appenginevm
package google
import "google.golang.org/appengine"
func init() {
appengineTokenFunc = appengine.AccessToken
}

View File

@ -1,19 +0,0 @@
// Copyright 2015 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine,!appenginevm
package google
import "golang.org/x/oauth2"
// AppEngineTokenSource returns a token source that fetches tokens
// issued to the current App Engine application's service account.
// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
// that involves user accounts, see oauth2.Config instead.
//
// You are required to provide a valid appengine.Context as context.
func AppEngineTokenSource(ctx oauth2.Context, scope ...string) oauth2.TokenSource {
panic("You should only use an AppEngineTokenSource in an App Engine application.")
}

View File

@ -1,37 +0,0 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm
package google
import (
"time"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"google.golang.org/appengine"
)
// AppEngineTokenSource returns a token source that fetches tokens
// issued to the current App Engine application's service account.
// If you are implementing a 3-legged OAuth 2.0 flow on App Engine
// that involves user accounts, see oauth2.Config instead.
//
// The provided context must have come from appengine.NewContext.
func AppEngineTokenSource(ctx oauth2.Context, scope ...string) oauth2.TokenSource {
return &appEngineTokenSource{
ctx: ctx,
scopes: scope,
fetcherFunc: aeVMFetcherFunc,
}
}
var aeVMFetcherFunc = func(ctx oauth2.Context, scope ...string) (string, time.Time, error) {
c, ok := ctx.(context.Context)
if !ok {
return "", time.Time{}, errInvalidContext
}
return appengine.AccessToken(c, scope...)
}

View File

@ -16,6 +16,7 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/context"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
) )
@ -123,7 +124,7 @@ func NewSDKConfig(account string) (*SDKConfig, error) {
// underlying http.RoundTripper will be obtained using the provided // underlying http.RoundTripper will be obtained using the provided
// context. The returned client and its Transport should not be // context. The returned client and its Transport should not be
// modified. // modified.
func (c *SDKConfig) Client(ctx oauth2.Context) *http.Client { func (c *SDKConfig) Client(ctx context.Context) *http.Client {
return &http.Client{ return &http.Client{
Transport: &oauth2.Transport{ Transport: &oauth2.Transport{
Source: c.TokenSource(ctx), Source: c.TokenSource(ctx),
@ -136,7 +137,7 @@ func (c *SDKConfig) Client(ctx oauth2.Context) *http.Client {
// It will returns the current access token stored in the credentials, // It will returns the current access token stored in the credentials,
// and refresh it when it expires, but it won't update the credentials // and refresh it when it expires, but it won't update the credentials
// with the new access token. // with the new access token.
func (c *SDKConfig) TokenSource(ctx oauth2.Context) oauth2.TokenSource { func (c *SDKConfig) TokenSource(ctx context.Context) oauth2.TokenSource {
return c.conf.TokenSource(ctx, c.initialToken) return c.conf.TokenSource(ctx, c.initialToken)
} }

View File

@ -1,71 +0,0 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package google
import (
"errors"
"sort"
"strings"
"sync"
"time"
"golang.org/x/oauth2"
)
var (
aeTokensMu sync.Mutex // guards aeTokens and appEngineTokenSource.key
// aeTokens helps the fetched tokens to be reused until their expiration.
aeTokens = make(map[string]*tokenLock) // key is '\0'-separated scopes
)
var errInvalidContext = errors.New("oauth2: context must come from appengine.NewContext")
type tokenLock struct {
mu sync.Mutex // guards t; held while updating t
t *oauth2.Token
}
type appEngineTokenSource struct {
ctx oauth2.Context
// fetcherFunc makes the actual RPC to fetch a new access
// token with an expiry time. Provider of this function is
// responsible to assert that the given context is valid.
fetcherFunc func(ctx oauth2.Context, scope ...string) (accessToken string, expiry time.Time, err error)
// scopes and key are guarded by the package-level mutex aeTokensMu
scopes []string
key string
}
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
aeTokensMu.Lock()
if ts.key == "" {
sort.Sort(sort.StringSlice(ts.scopes))
ts.key = strings.Join(ts.scopes, string(0))
}
tok, ok := aeTokens[ts.key]
if !ok {
tok = &tokenLock{}
aeTokens[ts.key] = tok
}
aeTokensMu.Unlock()
tok.mu.Lock()
defer tok.mu.Unlock()
if tok.t.Valid() {
return tok.t, nil
}
access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...)
if err != nil {
return nil, err
}
tok.t = &oauth2.Token{
AccessToken: access,
Expiry: exp,
}
return tok.t, nil
}

View File

@ -18,6 +18,7 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/context"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws" "golang.org/x/oauth2/jws"
@ -57,7 +58,7 @@ type Config struct {
// TokenSource returns a JWT TokenSource using the configuration // TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context. // in c and the HTTP client from the provided context.
func (c *Config) TokenSource(ctx oauth2.Context) oauth2.TokenSource { func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c}) return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c})
} }
@ -66,14 +67,14 @@ func (c *Config) TokenSource(ctx oauth2.Context) oauth2.TokenSource {
// obtained from c. // obtained from c.
// //
// The returned client and its Transport should not be modified. // The returned client and its Transport should not be modified.
func (c *Config) Client(ctx oauth2.Context) *http.Client { func (c *Config) Client(ctx context.Context) *http.Client {
return oauth2.NewClient(ctx, c.TokenSource(ctx)) return oauth2.NewClient(ctx, c.TokenSource(ctx))
} }
// jwtSource is a source that always does a signed JWT request for a token. // jwtSource is a source that always does a signed JWT request for a token.
// It should typically be wrapped with a reuseTokenSource. // It should typically be wrapped with a reuseTokenSource.
type jwtSource struct { type jwtSource struct {
ctx oauth2.Context ctx context.Context
conf *Config conf *Config
} }

View File

@ -25,14 +25,9 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
// Context can be an golang.org/x/net/context.Context, or an App Engine Context. // NoContext is the default context you should supply if not using
// If you don't care and aren't running on App Engine, you may use NoContext. // your own context.Context (see https://golang.org/x/net/context).
type Context interface{} var NoContext = context.TODO()
// NoContext is the default context. If you're not running this code
// on App Engine or not using golang.org/x/net/context.Context to provide a custom
// HTTP client, you should use NoContext.
var NoContext Context = nil
// Config describes a typical 3-legged OAuth2 flow, with both the // Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs. // client application information and the server's endpoint URLs.
@ -143,9 +138,9 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
// and when other authorization grant types are not available." // and when other authorization grant types are not available."
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. // See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
// //
// The HTTP client to use is derived from the context. If nil, // The HTTP client to use is derived from the context.
// http.DefaultClient is used. See the Context type's documentation. // If nil, http.DefaultClient is used.
func (c *Config) PasswordCredentialsToken(ctx Context, username, password string) (*Token, error) { func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
return retrieveToken(ctx, c, url.Values{ return retrieveToken(ctx, c, url.Values{
"grant_type": {"password"}, "grant_type": {"password"},
"username": {username}, "username": {username},
@ -159,12 +154,12 @@ func (c *Config) PasswordCredentialsToken(ctx Context, username, password string
// It is used after a resource provider redirects the user back // It is used after a resource provider redirects the user back
// to the Redirect URI (the URL obtained from AuthCodeURL). // to the Redirect URI (the URL obtained from AuthCodeURL).
// //
// The HTTP client to use is derived from the context. If nil, // The HTTP client to use is derived from the context.
// http.DefaultClient is used. See the Context type's documentation. // If nil, http.DefaultClient is used.
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state").
func (c *Config) Exchange(ctx Context, code string) (*Token, error) { func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
return retrieveToken(ctx, c, url.Values{ return retrieveToken(ctx, c, url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"code": {code}, "code": {code},
@ -177,7 +172,7 @@ func (c *Config) Exchange(ctx Context, code string) (*Token, error) {
// given a Context value. If it returns an error, the search stops // given a Context value. If it returns an error, the search stops
// with that error. If it returns (nil, nil), the search continues // with that error. If it returns (nil, nil), the search continues
// down the list of registered funcs. // down the list of registered funcs.
type contextClientFunc func(Context) (*http.Client, error) type contextClientFunc func(context.Context) (*http.Client, error)
var contextClientFuncs []contextClientFunc var contextClientFuncs []contextClientFunc
@ -185,7 +180,7 @@ func registerContextClientFunc(fn contextClientFunc) {
contextClientFuncs = append(contextClientFuncs, fn) contextClientFuncs = append(contextClientFuncs, fn)
} }
func contextClient(ctx Context) (*http.Client, error) { func contextClient(ctx context.Context) (*http.Client, error) {
for _, fn := range contextClientFuncs { for _, fn := range contextClientFuncs {
c, err := fn(ctx) c, err := fn(ctx)
if err != nil { if err != nil {
@ -195,15 +190,13 @@ func contextClient(ctx Context) (*http.Client, error) {
return c, nil return c, nil
} }
} }
if xc, ok := ctx.(context.Context); ok { if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
if hc, ok := xc.Value(HTTPClient).(*http.Client); ok {
return hc, nil return hc, nil
} }
}
return http.DefaultClient, nil return http.DefaultClient, nil
} }
func contextTransport(ctx Context) http.RoundTripper { func contextTransport(ctx context.Context) http.RoundTripper {
hc, err := contextClient(ctx) hc, err := contextClient(ctx)
if err != nil { if err != nil {
// This is a rare error case (somebody using nil on App Engine), // This is a rare error case (somebody using nil on App Engine),
@ -219,16 +212,15 @@ func contextTransport(ctx Context) http.RoundTripper {
// The token will auto-refresh as necessary. The underlying // The token will auto-refresh as necessary. The underlying
// HTTP transport will be obtained using the provided context. // HTTP transport will be obtained using the provided context.
// The returned client and its Transport should not be modified. // The returned client and its Transport should not be modified.
func (c *Config) Client(ctx Context, t *Token) *http.Client { func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
return NewClient(ctx, c.TokenSource(ctx, t)) return NewClient(ctx, c.TokenSource(ctx, t))
} }
// TokenSource returns a TokenSource that returns t until t expires, // TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context. // automatically refreshing it as necessary using the provided context.
// See the the Context documentation.
// //
// Most users will use Config.Client instead. // Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource { func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
tkr := &tokenRefresher{ tkr := &tokenRefresher{
ctx: ctx, ctx: ctx,
conf: c, conf: c,
@ -245,7 +237,7 @@ func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
// HTTP requests to renew a token using a RefreshToken. // HTTP requests to renew a token using a RefreshToken.
type tokenRefresher struct { type tokenRefresher struct {
ctx Context // used to get HTTP requests ctx context.Context // used to get HTTP requests
conf *Config conf *Config
refreshToken string refreshToken string
} }
@ -301,7 +293,7 @@ func (s *reuseTokenSource) Token() (*Token, error) {
return t, nil return t, nil
} }
func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
hc, err := contextClient(ctx) hc, err := contextClient(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -451,7 +443,7 @@ type contextKey struct{}
// As a special case, if src is nil, a non-OAuth2 client is returned // As a special case, if src is nil, a non-OAuth2 client is returned
// using the provided context. This exists to support related OAuth2 // using the provided context. This exists to support related OAuth2
// packages. // packages.
func NewClient(ctx Context, src TokenSource) *http.Client { func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil { if src == nil {
c, err := contextClient(ctx) c, err := contextClient(ctx)
if err != nil { if err != nil {