oauth2, oauth2/google: add, use ReuseTokenSource

Token caching is now done whenever you make a Client, and
ReuseTokenSource is exported from the oauth2 package and used by the
Google TokenSources (Compute and App Engine).

Token.Expired is now Token.Valid, and works on nil receivers.

Some other wording cleanups in the process.

All tests pass. App Engine should pass, but is untested.

Change-Id: Ibe1d2599ac3ccfe9b399b1672f74bb24cfc8d311
Reviewed-on: https://go-review.googlesource.com/2195
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick 2014-12-30 13:25:01 -08:00
parent e5909d4679
commit a379e41d44
10 changed files with 97 additions and 68 deletions

View File

@ -50,7 +50,6 @@ func ExampleConfig() {
}
func ExampleJWTConfig() {
var initialToken *oauth2.Token // nil means no initial token
conf := &oauth2.JWTConfig{
Email: "xxx@developer.com",
// The contents of your RSA private key or your PEM file
@ -67,6 +66,6 @@ func ExampleJWTConfig() {
}
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
client := conf.Client(oauth2.NoContext, initialToken)
client := conf.Client(oauth2.NoContext)
client.Get("...")
}

View File

@ -69,7 +69,7 @@ func ExampleJWTConfigFromJSON() {
// Initiate an http.Client. The following GET request will be
// authorized and authenticated on the behalf of
// your service account.
client := conf.Client(oauth2.NoContext, nil)
client := conf.Client(oauth2.NoContext)
client.Get("...")
}
@ -101,7 +101,7 @@ func Example_serviceAccount() {
}
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
client := conf.Client(oauth2.NoContext, nil)
client := conf.Client(oauth2.NoContext)
client.Get("...")
}

View File

@ -15,7 +15,6 @@ package google // import "golang.org/x/oauth2/google"
import (
"encoding/json"
"fmt"
"net"
"net/http"
@ -24,6 +23,9 @@ import (
"golang.org/x/oauth2"
)
// TODO(bradfitz,jbd): import "google.golang.org/cloud/compute/metadata" instead of
// the metaClient and metadata.google.internal stuff below.
// Endpoint is Google's OAuth 2.0 endpoint.
var Endpoint = oauth2.Endpoint{
AuthURL: "https://accounts.google.com/o/oauth2/auth",
@ -66,7 +68,7 @@ type metaTokenRespBody struct {
// Further information about retrieving access tokens from the GCE metadata
// server can be found at https://cloud.google.com/compute/docs/authentication.
func ComputeTokenSource(account string) oauth2.TokenSource {
return &computeSource{account: account}
return oauth2.ReuseTokenSource(nil, &computeSource{account: account})
}
type computeSource struct {

View File

@ -29,13 +29,16 @@ type tokenLock struct {
}
type appEngineTokenSource struct {
ctx oauth2.Context
scopes []string
key string // guarded by package-level mutex, aeTokensMu
ctx oauth2.Context
// fetcherFunc makes the actual RPC to fetch a new access token with an expiry time.
// Provider of this function is responsible to assert that the given context is valid.
fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error)
// fetcherFunc makes the actual RPC to fetch a new access
// token with an expiry time. Provider of this function is
// responsible to assert that the given context is valid.
fetcherFunc func(ctx oauth2.Context, scope ...string) (accessToken string, expiry time.Time, err error)
// scopes and key are guarded by the package-level mutex aeTokensMu
scopes []string
key string
}
func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
@ -53,7 +56,7 @@ func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) {
tok.mu.Lock()
defer tok.mu.Unlock()
if tok.t != nil && !tok.t.Expired() {
if tok.t.Valid() {
return tok.t, nil
}
access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...)

22
jwt.go
View File

@ -52,33 +52,21 @@ type JWTConfig struct {
// TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context.
//
// The returned TokenSource only does JWT requests when necessary but
// otherwise returns the same token repeatedly until it expires.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
return &newWhenNeededSource{
t: initialToken,
new: jwtSource{ctx, c},
}
func (c *JWTConfig) TokenSource(ctx Context) TokenSource {
return ReuseTokenSource(nil, jwtSource{ctx, c})
}
// Client returns an HTTP client wrapping the context's
// HTTP transport and adding Authorization headers with tokens
// obtained from c.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
//
// The returned client and its Transport should not be modified.
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
return NewClient(ctx, c.TokenSource(ctx, initialToken))
func (c *JWTConfig) Client(ctx Context) *http.Client {
return NewClient(ctx, c.TokenSource(ctx))
}
// jwtSource is a source that always does a signed JWT request for a token.
// It should typically be wrapped with a newWhenNeededSource.
// It should typically be wrapped with a reuseTokenSource.
type jwtSource struct {
ctx Context
conf *JWTConfig

View File

@ -55,12 +55,12 @@ func TestJWTFetch_JSONResponse(t *testing.T) {
PrivateKey: dummyPrivateKey,
TokenURL: ts.URL,
}
tok, err := conf.TokenSource(NoContext, nil).Token()
tok, err := conf.TokenSource(NoContext).Token()
if err != nil {
t.Fatal(err)
}
if tok.Expired() {
t.Errorf("Token shouldn't be expired")
if !tok.Valid() {
t.Errorf("Token invalid")
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v", tok.AccessToken)
@ -89,19 +89,25 @@ func TestJWTFetch_BadResponse(t *testing.T) {
PrivateKey: dummyPrivateKey,
TokenURL: ts.URL,
}
tok, err := conf.TokenSource(NoContext, nil).Token()
tok, err := conf.TokenSource(NoContext).Token()
if err != nil {
t.Fatal(err)
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
if tok == nil {
t.Fatalf("token is nil")
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
if tok.Valid() {
t.Errorf("token is valid. want invalid.")
}
if tok.AccessToken != "" {
t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken)
}
if want := "bearer"; tok.TokenType != want {
t.Errorf("TokenType = %q; want %q", tok.TokenType, want)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
if want := "user"; scope != want {
t.Errorf("token scope = %q; want %q", scope, want)
}
}
@ -116,7 +122,7 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
PrivateKey: dummyPrivateKey,
TokenURL: ts.URL,
}
tok, err := conf.TokenSource(NoContext, nil).Token()
tok, err := conf.TokenSource(NoContext).Token()
if err == nil {
t.Error("got a token; expected error")
if tok.AccessToken != "" {

View File

@ -26,7 +26,6 @@ import (
)
// Context can be an golang.org/x/net.Context, or an App Engine Context.
// In the future these will be unified.
// If you don't care and aren't running on App Engine, you may use NoContext.
type Context interface{}
@ -36,7 +35,7 @@ type Context interface{}
var NoContext Context = nil
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's URLs.
// client application information and the server's endpoint URLs.
type Config struct {
// ClientID is the application's ID.
ClientID string
@ -45,9 +44,9 @@ type Config struct {
ClientSecret string
// Endpoint contains the resource server's token endpoint
// URLs. These are supplied by the server and are often
// available via site-specific packages (for example,
// google.Endpoint or github.Endpoint)
// URLs. These are constants specific to each server and are
// often available via site-specific packages, such as
// google.Endpoint or github.Endpoint.
Endpoint Endpoint
// RedirectURL is the URL to redirect users going through
@ -61,6 +60,7 @@ type Config struct {
// A TokenSource is anything that can return a token.
type TokenSource interface {
// Token returns a token or an error.
// Token must be safe for concurrent use by multiple goroutines.
Token() (*Token, error)
}
@ -208,7 +208,7 @@ func (c *Config) Client(ctx Context, t *Token) *http.Client {
//
// Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
nwn := &newWhenNeededSource{t: t}
nwn := &reuseTokenSource{t: t}
nwn.new = tokenRefresher{
ctx: ctx,
conf: c,
@ -239,13 +239,13 @@ func (tf tokenRefresher) Token() (*Token, error) {
})
}
// newWhenNeededSource is a TokenSource that holds a single token in memory
// reuseTokenSource is a TokenSource that holds a single token in memory
// and validates its expiry before each call to retrieve it with
// Token. If it's expired, it will be auto-refreshed using the
// new TokenSource.
//
// The first call to TokenRefresher must be SetToken.
type newWhenNeededSource struct {
type reuseTokenSource struct {
new TokenSource // called when t is expired.
mu sync.Mutex // guards t
@ -255,10 +255,10 @@ type newWhenNeededSource struct {
// Token returns the current token if it's still valid, else will
// refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *newWhenNeededSource) Token() (*Token, error) {
func (s *reuseTokenSource) Token() (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t != nil && !s.t.Expired() {
if s.t.Valid() {
return s.t, nil
}
t, err := s.new.Token()
@ -410,12 +410,41 @@ var HTTPClient contextKey
type contextKey struct{}
// NewClient creates an *http.Client from a Context and TokenSource.
// The client's lifetime does not extend beyond the lifetime of the context.
// The returned client is not valid beyond the lifetime of the context.
func NewClient(ctx Context, src TokenSource) *http.Client {
return &http.Client{
Transport: &Transport{
Base: contextTransport(ctx),
Source: src,
Source: ReuseTokenSource(nil, src),
},
}
}
// ReuseTokenSource returns a TokenSource which repeatedly returns the
// same token as long as it's valid, starting with t.
// When its cached token is invalid, a new token is obtained from src.
//
// ReuseTokenSource is typically used to reuse tokens from a cache
// (such as a file on disk) between runs of a program, rather than
// obtaining new tokens unnecessarily.
//
// The initial token t may be nil, in which case the TokenSource is
// wrapped in a caching version if it isn't one already. This also
// means it's always safe to wrap ReuseTokenSource around any other
// TokenSource without adverse effects.
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
// Just build the equivalent one.
if rt, ok := src.(*reuseTokenSource); ok {
if t == nil {
// Just use it directly.
return rt
}
src = rt.new
}
return &reuseTokenSource{
t: t,
new: src,
}
}

View File

@ -99,8 +99,8 @@ func TestExchangeRequest(t *testing.T) {
if err != nil {
t.Error(err)
}
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
if !tok.Valid() {
t.Fatalf("Token invalid. Got: %#v", tok)
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
@ -143,8 +143,8 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
if err != nil {
t.Error(err)
}
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
if !tok.Valid() {
t.Fatalf("Token invalid. Got: %#v", tok)
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)

View File

@ -74,14 +74,16 @@ func (t *Token) Extra(key string) string {
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
// expired reports whether the token is expired.
// t must be non-nil.
func (t *Token) expired() bool {
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
func (t *Token) Valid() bool {
return t != nil && t.AccessToken != "" && !t.expired()
}

View File

@ -32,10 +32,10 @@ func TestTransportTokenSource(t *testing.T) {
client.Get(server.URL)
}
func TestExpiredWithNoAccessToken(t *testing.T) {
func TestTokenValidNoAccessToken(t *testing.T) {
token := &Token{}
if !token.Expired() {
t.Errorf("Token should be expired if no access token is provided")
if token.Valid() {
t.Errorf("Token should not be valid with no access token")
}
}
@ -43,8 +43,8 @@ func TestExpiredWithExpiry(t *testing.T) {
token := &Token{
Expiry: time.Now().Add(-5 * time.Hour),
}
if !token.Expired() {
t.Errorf("Token should be expired if no access token is provided")
if token.Valid() {
t.Errorf("Token should not be valid if it expired in the past")
}
}