oauth2, oauth2/google: add, use ReuseTokenSource

Token caching is now done whenever you make a Client, and
ReuseTokenSource is exported from the oauth2 package and used by the
Google TokenSources (Compute and App Engine).

Token.Expired is now Token.Valid, and works on nil receivers.

Some other wording cleanups in the process.

All tests pass. App Engine should pass, but is untested.

Change-Id: Ibe1d2599ac3ccfe9b399b1672f74bb24cfc8d311
Reviewed-on: https://go-review.googlesource.com/2195
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick 2014-12-30 13:25:01 -08:00
parent e5909d4679
commit a379e41d44
10 changed files with 97 additions and 68 deletions

View File

@ -50,7 +50,6 @@ func ExampleConfig() {
} }
func ExampleJWTConfig() { func ExampleJWTConfig() {
var initialToken *oauth2.Token // nil means no initial token
conf := &oauth2.JWTConfig{ conf := &oauth2.JWTConfig{
Email: "xxx@developer.com", Email: "xxx@developer.com",
// The contents of your RSA private key or your PEM file // The contents of your RSA private key or your PEM file
@ -67,6 +66,6 @@ func ExampleJWTConfig() {
} }
// Initiate an http.Client, the following GET request will be // Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com. // authorized and authenticated on the behalf of user@example.com.
client := conf.Client(oauth2.NoContext, initialToken) client := conf.Client(oauth2.NoContext)
client.Get("...") client.Get("...")
} }

View File

@ -69,7 +69,7 @@ func ExampleJWTConfigFromJSON() {
// Initiate an http.Client. The following GET request will be // Initiate an http.Client. The following GET request will be
// authorized and authenticated on the behalf of // authorized and authenticated on the behalf of
// your service account. // your service account.
client := conf.Client(oauth2.NoContext, nil) client := conf.Client(oauth2.NoContext)
client.Get("...") client.Get("...")
} }
@ -101,7 +101,7 @@ func Example_serviceAccount() {
} }
// Initiate an http.Client, the following GET request will be // Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com. // authorized and authenticated on the behalf of user@example.com.
client := conf.Client(oauth2.NoContext, nil) client := conf.Client(oauth2.NoContext)
client.Get("...") client.Get("...")
} }

View File

@ -15,7 +15,6 @@ package google // import "golang.org/x/oauth2/google"
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -24,6 +23,9 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// TODO(bradfitz,jbd): import "google.golang.org/cloud/compute/metadata" instead of
// the metaClient and metadata.google.internal stuff below.
// Endpoint is Google's OAuth 2.0 endpoint. // Endpoint is Google's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{ var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth", AuthURL: "https://accounts.google.com/o/oauth2/auth",
@ -66,7 +68,7 @@ type metaTokenRespBody struct {
// Further information about retrieving access tokens from the GCE metadata // Further information about retrieving access tokens from the GCE metadata
// server can be found at https://cloud.google.com/compute/docs/authentication. // server can be found at https://cloud.google.com/compute/docs/authentication.
func ComputeTokenSource(account string) oauth2.TokenSource { func ComputeTokenSource(account string) oauth2.TokenSource {
return &computeSource{account: account} return oauth2.ReuseTokenSource(nil, &computeSource{account: account})
} }
type computeSource struct { type computeSource struct {

View File

@ -29,13 +29,16 @@ type tokenLock struct {
} }
type appEngineTokenSource struct { type appEngineTokenSource struct {
ctx oauth2.Context ctx oauth2.Context
scopes []string
key string // guarded by package-level mutex, aeTokensMu
// fetcherFunc makes the actual RPC to fetch a new access token with an expiry time. // fetcherFunc makes the actual RPC to fetch a new access
// Provider of this function is responsible to assert that the given context is valid. // token with an expiry time. Provider of this function is
fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error) // responsible to assert that the given context is valid.
fetcherFunc func(ctx oauth2.Context, scope ...string) (accessToken string, expiry time.Time, err error)
// scopes and key are guarded by the package-level mutex aeTokensMu
scopes []string
key string
} }
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
@ -53,7 +56,7 @@ func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
tok.mu.Lock() tok.mu.Lock()
defer tok.mu.Unlock() defer tok.mu.Unlock()
if tok.t != nil && !tok.t.Expired() { if tok.t.Valid() {
return tok.t, nil return tok.t, nil
} }
access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...) access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...)

22
jwt.go
View File

@ -52,33 +52,21 @@ type JWTConfig struct {
// TokenSource returns a JWT TokenSource using the configuration // TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context. // in c and the HTTP client from the provided context.
// func (c *JWTConfig) TokenSource(ctx Context) TokenSource {
// The returned TokenSource only does JWT requests when necessary but return ReuseTokenSource(nil, jwtSource{ctx, c})
// otherwise returns the same token repeatedly until it expires.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
return &newWhenNeededSource{
t: initialToken,
new: jwtSource{ctx, c},
}
} }
// Client returns an HTTP client wrapping the context's // Client returns an HTTP client wrapping the context's
// HTTP transport and adding Authorization headers with tokens // HTTP transport and adding Authorization headers with tokens
// obtained from c. // obtained from c.
// //
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
//
// The returned client and its Transport should not be modified. // The returned client and its Transport should not be modified.
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client { func (c *JWTConfig) Client(ctx Context) *http.Client {
return NewClient(ctx, c.TokenSource(ctx, initialToken)) return NewClient(ctx, c.TokenSource(ctx))
} }
// jwtSource is a source that always does a signed JWT request for a token. // jwtSource is a source that always does a signed JWT request for a token.
// It should typically be wrapped with a newWhenNeededSource. // It should typically be wrapped with a reuseTokenSource.
type jwtSource struct { type jwtSource struct {
ctx Context ctx Context
conf *JWTConfig conf *JWTConfig

View File

@ -55,12 +55,12 @@ func TestJWTFetch_JSONResponse(t *testing.T) {
PrivateKey: dummyPrivateKey, PrivateKey: dummyPrivateKey,
TokenURL: ts.URL, TokenURL: ts.URL,
} }
tok, err := conf.TokenSource(NoContext, nil).Token() tok, err := conf.TokenSource(NoContext).Token()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tok.Expired() { if !tok.Valid() {
t.Errorf("Token shouldn't be expired") t.Errorf("Token invalid")
} }
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v", tok.AccessToken) t.Errorf("Unexpected access token, %#v", tok.AccessToken)
@ -89,19 +89,25 @@ func TestJWTFetch_BadResponse(t *testing.T) {
PrivateKey: dummyPrivateKey, PrivateKey: dummyPrivateKey,
TokenURL: ts.URL, TokenURL: ts.URL,
} }
tok, err := conf.TokenSource(NoContext, nil).Token() tok, err := conf.TokenSource(NoContext).Token()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tok.AccessToken != "" { if tok == nil {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) t.Fatalf("token is nil")
} }
if tok.TokenType != "bearer" { if tok.Valid() {
t.Errorf("Unexpected token type, %#v.", tok.TokenType) t.Errorf("token is valid. want invalid.")
}
if tok.AccessToken != "" {
t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken)
}
if want := "bearer"; tok.TokenType != want {
t.Errorf("TokenType = %q; want %q", tok.TokenType, want)
} }
scope := tok.Extra("scope") scope := tok.Extra("scope")
if scope != "user" { if want := "user"; scope != want {
t.Errorf("Unexpected value for scope: %v", scope) t.Errorf("token scope = %q; want %q", scope, want)
} }
} }
@ -116,7 +122,7 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
PrivateKey: dummyPrivateKey, PrivateKey: dummyPrivateKey,
TokenURL: ts.URL, TokenURL: ts.URL,
} }
tok, err := conf.TokenSource(NoContext, nil).Token() tok, err := conf.TokenSource(NoContext).Token()
if err == nil { if err == nil {
t.Error("got a token; expected error") t.Error("got a token; expected error")
if tok.AccessToken != "" { if tok.AccessToken != "" {

View File

@ -26,7 +26,6 @@ import (
) )
// Context can be an golang.org/x/net.Context, or an App Engine Context. // Context can be an golang.org/x/net.Context, or an App Engine Context.
// In the future these will be unified.
// If you don't care and aren't running on App Engine, you may use NoContext. // If you don't care and aren't running on App Engine, you may use NoContext.
type Context interface{} type Context interface{}
@ -36,7 +35,7 @@ type Context interface{}
var NoContext Context = nil var NoContext Context = nil
// Config describes a typical 3-legged OAuth2 flow, with both the // Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's URLs. // client application information and the server's endpoint URLs.
type Config struct { type Config struct {
// ClientID is the application's ID. // ClientID is the application's ID.
ClientID string ClientID string
@ -45,9 +44,9 @@ type Config struct {
ClientSecret string ClientSecret string
// Endpoint contains the resource server's token endpoint // Endpoint contains the resource server's token endpoint
// URLs. These are supplied by the server and are often // URLs. These are constants specific to each server and are
// available via site-specific packages (for example, // often available via site-specific packages, such as
// google.Endpoint or github.Endpoint) // google.Endpoint or github.Endpoint.
Endpoint Endpoint Endpoint Endpoint
// RedirectURL is the URL to redirect users going through // RedirectURL is the URL to redirect users going through
@ -61,6 +60,7 @@ type Config struct {
// A TokenSource is anything that can return a token. // A TokenSource is anything that can return a token.
type TokenSource interface { type TokenSource interface {
// Token returns a token or an error. // Token returns a token or an error.
// Token must be safe for concurrent use by multiple goroutines.
Token() (*Token, error) Token() (*Token, error)
} }
@ -208,7 +208,7 @@ func (c *Config) Client(ctx Context, t *Token) *http.Client {
// //
// Most users will use Config.Client instead. // Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource { func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
nwn := &newWhenNeededSource{t: t} nwn := &reuseTokenSource{t: t}
nwn.new = tokenRefresher{ nwn.new = tokenRefresher{
ctx: ctx, ctx: ctx,
conf: c, conf: c,
@ -239,13 +239,13 @@ func (tf tokenRefresher) Token() (*Token, error) {
}) })
} }
// newWhenNeededSource is a TokenSource that holds a single token in memory // reuseTokenSource is a TokenSource that holds a single token in memory
// and validates its expiry before each call to retrieve it with // and validates its expiry before each call to retrieve it with
// Token. If it's expired, it will be auto-refreshed using the // Token. If it's expired, it will be auto-refreshed using the
// new TokenSource. // new TokenSource.
// //
// The first call to TokenRefresher must be SetToken. // The first call to TokenRefresher must be SetToken.
type newWhenNeededSource struct { type reuseTokenSource struct {
new TokenSource // called when t is expired. new TokenSource // called when t is expired.
mu sync.Mutex // guards t mu sync.Mutex // guards t
@ -255,10 +255,10 @@ type newWhenNeededSource struct {
// Token returns the current token if it's still valid, else will // Token returns the current token if it's still valid, else will
// refresh the current token (using r.Context for HTTP client // refresh the current token (using r.Context for HTTP client
// information) and return the new one. // information) and return the new one.
func (s *newWhenNeededSource) Token() (*Token, error) { func (s *reuseTokenSource) Token() (*Token, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.t != nil && !s.t.Expired() { if s.t.Valid() {
return s.t, nil return s.t, nil
} }
t, err := s.new.Token() t, err := s.new.Token()
@ -410,12 +410,41 @@ var HTTPClient contextKey
type contextKey struct{} type contextKey struct{}
// NewClient creates an *http.Client from a Context and TokenSource. // NewClient creates an *http.Client from a Context and TokenSource.
// The client's lifetime does not extend beyond the lifetime of the context. // The returned client is not valid beyond the lifetime of the context.
func NewClient(ctx Context, src TokenSource) *http.Client { func NewClient(ctx Context, src TokenSource) *http.Client {
return &http.Client{ return &http.Client{
Transport: &Transport{ Transport: &Transport{
Base: contextTransport(ctx), Base: contextTransport(ctx),
Source: src, Source: ReuseTokenSource(nil, src),
}, },
} }
} }
// ReuseTokenSource returns a TokenSource which repeatedly returns the
// same token as long as it's valid, starting with t.
// When its cached token is invalid, a new token is obtained from src.
//
// ReuseTokenSource is typically used to reuse tokens from a cache
// (such as a file on disk) between runs of a program, rather than
// obtaining new tokens unnecessarily.
//
// The initial token t may be nil, in which case the TokenSource is
// wrapped in a caching version if it isn't one already. This also
// means it's always safe to wrap ReuseTokenSource around any other
// TokenSource without adverse effects.
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
// Just build the equivalent one.
if rt, ok := src.(*reuseTokenSource); ok {
if t == nil {
// Just use it directly.
return rt
}
src = rt.new
}
return &reuseTokenSource{
t: t,
new: src,
}
}

View File

@ -99,8 +99,8 @@ func TestExchangeRequest(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if tok.Expired() { if !tok.Valid() {
t.Errorf("Token shouldn't be expired.") t.Fatalf("Token invalid. Got: %#v", tok)
} }
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
@ -143,8 +143,8 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if tok.Expired() { if !tok.Valid() {
t.Errorf("Token shouldn't be expired.") t.Fatalf("Token invalid. Got: %#v", tok)
} }
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) t.Errorf("Unexpected access token, %#v.", tok.AccessToken)

View File

@ -74,14 +74,16 @@ func (t *Token) Extra(key string) string {
return "" return ""
} }
// Expired returns true if there is no access token or the // expired reports whether the token is expired.
// access token is expired. // t must be non-nil.
func (t *Token) Expired() bool { func (t *Token) expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() { if t.Expiry.IsZero() {
return false return false
} }
return t.Expiry.Before(time.Now()) return t.Expiry.Before(time.Now())
} }
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
func (t *Token) Valid() bool {
return t != nil && t.AccessToken != "" && !t.expired()
}

View File

@ -32,10 +32,10 @@ func TestTransportTokenSource(t *testing.T) {
client.Get(server.URL) client.Get(server.URL)
} }
func TestExpiredWithNoAccessToken(t *testing.T) { func TestTokenValidNoAccessToken(t *testing.T) {
token := &Token{} token := &Token{}
if !token.Expired() { if token.Valid() {
t.Errorf("Token should be expired if no access token is provided") t.Errorf("Token should not be valid with no access token")
} }
} }
@ -43,8 +43,8 @@ func TestExpiredWithExpiry(t *testing.T) {
token := &Token{ token := &Token{
Expiry: time.Now().Add(-5 * time.Hour), Expiry: time.Now().Add(-5 * time.Hour),
} }
if !token.Expired() { if token.Valid() {
t.Errorf("Token should be expired if no access token is provided") t.Errorf("Token should not be valid if it expired in the past")
} }
} }