forked from Mirrors/oauth2
Introduce an option function type
- Reduce the duplicate code by merging the flows and determining the flow type by looking at the provided options. - Options as a function type allows us to validate an individual an option in its scope and makes it easier to compose the built-in options with the third-party ones.
This commit is contained in:
parent
49f4824137
commit
0cf6f9b144
|
@ -1,3 +1,7 @@
|
||||||
|
// Copyright 2014 The oauth2 Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package oauth2_test
|
package oauth2_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -13,33 +17,34 @@ import (
|
||||||
// Related to https://codereview.appspot.com/107320046
|
// Related to https://codereview.appspot.com/107320046
|
||||||
func TestA(t *testing.T) {}
|
func TestA(t *testing.T) {}
|
||||||
|
|
||||||
func Example_config() {
|
func Example_regular() {
|
||||||
conf, err := oauth2.NewConfig(&oauth2.Options{
|
f, err := oauth2.New(
|
||||||
ClientID: "YOUR_CLIENT_ID",
|
oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
|
||||||
ClientSecret: "YOUR_CLIENT_SECRET",
|
oauth2.RedirectURL("YOUR_REDIRECT_URL"),
|
||||||
RedirectURL: "YOUR_REDIRECT_URL",
|
oauth2.Scope("SCOPE1", "SCOPE2"),
|
||||||
Scopes: []string{"SCOPE1", "SCOPE2"},
|
oauth2.Endpoint(
|
||||||
},
|
|
||||||
"https://provider.com/o/oauth2/auth",
|
"https://provider.com/o/oauth2/auth",
|
||||||
"https://provider.com/o/oauth2/token")
|
"https://provider.com/o/oauth2/token",
|
||||||
|
),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect user to consent page to ask for permission
|
// Redirect user to consent page to ask for permission
|
||||||
// for the scopes specified above.
|
// for the scopes specified above.
|
||||||
url := conf.AuthCodeURL("state", "online", "auto")
|
url := f.AuthCodeURL("state", "online", "auto")
|
||||||
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
||||||
|
|
||||||
// Use the authorization code that is pushed to the redirect URL.
|
// Use the authorization code that is pushed to the redirect URL.
|
||||||
// NewTransportWithCode will do the handshake to retrieve
|
// NewTransportWithCode will do the handshake to retrieve
|
||||||
// an access token and initiate a Transport that is
|
// an access token and initiate a Transport that is
|
||||||
// authorized and authenticated by the retrieved token.
|
// authorized and authenticated by the retrieved token.
|
||||||
var authorizationCode string
|
var code string
|
||||||
if _, err = fmt.Scan(&authorizationCode); err != nil {
|
if _, err = fmt.Scan(&code); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
t, err := conf.NewTransportWithCode(authorizationCode)
|
t, err := f.NewTransportFromCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -50,9 +55,8 @@ func Example_config() {
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_jWTConfig() {
|
func Example_jWT() {
|
||||||
conf, err := oauth2.NewJWTConfig(&oauth2.JWTOptions{
|
f, err := oauth2.New(
|
||||||
Email: "xxx@developer.gserviceaccount.com",
|
|
||||||
// The contents of your RSA private key or your PEM file
|
// The contents of your RSA private key or your PEM file
|
||||||
// that contains a private key.
|
// that contains a private key.
|
||||||
// If you have a p12 file instead, you
|
// If you have a p12 file instead, you
|
||||||
|
@ -61,23 +65,23 @@ func Example_jWTConfig() {
|
||||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||||
//
|
//
|
||||||
// It only supports PEM containers with no passphrase.
|
// It only supports PEM containers with no passphrase.
|
||||||
PrivateKey: []byte("PRIVATE KEY CONTENTS"),
|
oauth2.JWTClient(
|
||||||
Scopes: []string{"SCOPE1", "SCOPE2"},
|
"xxx@developer.gserviceaccount.com",
|
||||||
},
|
[]byte("-----BEGIN RSA PRIVATE KEY-----...")),
|
||||||
"https://provider.com/o/oauth2/token")
|
oauth2.Scope("SCOPE1", "SCOPE2"),
|
||||||
|
oauth2.JWTEndpoint("https://provider.com/o/oauth2/token"),
|
||||||
|
// If you would like to impersonate a user, you can
|
||||||
|
// create a transport with a subject. The following GET
|
||||||
|
// request will be made on the behalf of user@example.com.
|
||||||
|
// Subject is optional.
|
||||||
|
oauth2.Subject("user@example.com"),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initiate an http.Client, the following GET request will be
|
// Initiate an http.Client, the following GET request will be
|
||||||
// authorized and authenticated on the behalf of
|
// authorized and authenticated on the behalf of user@example.com.
|
||||||
// xxx@developer.gserviceaccount.com.
|
client := http.Client{Transport: f.NewTransport()}
|
||||||
client := http.Client{Transport: conf.NewTransport()}
|
|
||||||
client.Get("...")
|
|
||||||
|
|
||||||
// If you would like to impersonate a user, you can
|
|
||||||
// create a transport with a subject. The following GET
|
|
||||||
// request will be made on the behalf of user@example.com.
|
|
||||||
client = http.Client{Transport: conf.NewTransportWithUser("user@example.com")}
|
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
package google
|
package google
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -19,6 +18,82 @@ import (
|
||||||
"appengine/urlfetch"
|
"appengine/urlfetch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
|
||||||
|
memcacheGob memcacher = &aeMemcache{}
|
||||||
|
|
||||||
|
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
|
||||||
|
accessTokenFunc = appengine.AccessToken
|
||||||
|
|
||||||
|
// mu protects multiple threads from attempting to fetch a token at the same time.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
|
||||||
|
tokens map[string]*oauth2.Token
|
||||||
|
)
|
||||||
|
|
||||||
|
// safetyMargin is used to avoid clock-skew problems.
|
||||||
|
// 5 minutes is conservative because tokens are valid for 60 minutes.
|
||||||
|
const safetyMargin = 5 * time.Minute
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tokens = make(map[string]*oauth2.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppEngineContext requires an App Engine request context.
|
||||||
|
func AppEngineContext(ctx appengine.Context) oauth2.Option {
|
||||||
|
return func(opts *oauth2.Options) error {
|
||||||
|
opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
|
||||||
|
opts.Transport = &urlfetch.Transport{Context: ctx}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchToken fetches a new access token for the provided scopes.
|
||||||
|
// Tokens are cached locally and also with Memcache so that the app can scale
|
||||||
|
// without hitting quota limits by calling appengine.AccessToken too frequently.
|
||||||
|
func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
|
||||||
|
return func(existing *oauth2.Token) (*oauth2.Token, error) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
key := ":" + strings.Join(opts.Scopes, "_")
|
||||||
|
now := time.Now().Add(safetyMargin)
|
||||||
|
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
delete(tokens, key)
|
||||||
|
|
||||||
|
// Attempt to get token from Memcache
|
||||||
|
tok := new(oauth2.Token)
|
||||||
|
_, err := memcacheGob.Get(ctx, key, tok)
|
||||||
|
if err == nil && !tok.Expiry.Before(now) {
|
||||||
|
tokens[key] = tok // Save token locally
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := &oauth2.Token{
|
||||||
|
AccessToken: token,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
tokens[key] = t
|
||||||
|
// Also back up token in Memcache
|
||||||
|
if err = memcacheGob.Set(ctx, &memcache.Item{
|
||||||
|
Key: key,
|
||||||
|
Value: []byte{},
|
||||||
|
Object: *t,
|
||||||
|
Expiration: expiry.Sub(now),
|
||||||
|
}); err != nil {
|
||||||
|
ctx.Errorf("unexpected memcache.Set error: %v", err)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
|
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
|
||||||
type aeMemcache struct{}
|
type aeMemcache struct{}
|
||||||
|
|
||||||
|
@ -34,99 +109,3 @@ type memcacher interface {
|
||||||
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
|
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
|
||||||
Set(c appengine.Context, item *memcache.Item) error
|
Set(c appengine.Context, item *memcache.Item) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
|
|
||||||
var memcacheGob memcacher = &aeMemcache{}
|
|
||||||
|
|
||||||
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
|
|
||||||
var accessTokenFunc = appengine.AccessToken
|
|
||||||
|
|
||||||
// safetyMargin is used to avoid clock-skew problems.
|
|
||||||
// 5 minutes is conservative because tokens are valid for 60 minutes.
|
|
||||||
const safetyMargin = 5 * time.Minute
|
|
||||||
|
|
||||||
// mu protects multiple threads from attempting to fetch a token at the same time.
|
|
||||||
var mu sync.Mutex
|
|
||||||
|
|
||||||
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
|
|
||||||
var tokens map[string]*oauth2.Token
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
tokens = make(map[string]*oauth2.Token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AppEngineConfig represents a configuration for an
|
|
||||||
// App Engine application's Google service account.
|
|
||||||
type AppEngineConfig struct {
|
|
||||||
// Transport is the http.RoundTripper to be used
|
|
||||||
// to construct new oauth2.Transport instances from
|
|
||||||
// this configuration.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
context appengine.Context
|
|
||||||
scopes []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAppEngineConfig creates a new AppEngineConfig for the
|
|
||||||
// provided auth scopes.
|
|
||||||
func NewAppEngineConfig(context appengine.Context, scopes ...string) *AppEngineConfig {
|
|
||||||
return &AppEngineConfig{
|
|
||||||
context: context,
|
|
||||||
scopes: scopes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTransport returns a transport that authorizes
|
|
||||||
// the requests with the application's service account.
|
|
||||||
func (c *AppEngineConfig) NewTransport() *oauth2.Transport {
|
|
||||||
return oauth2.NewTransport(c.transport(), c, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchToken fetches a new access token for the provided scopes.
|
|
||||||
// Tokens are cached locally and also with Memcache so that the app can scale
|
|
||||||
// without hitting quota limits by calling appengine.AccessToken too frequently.
|
|
||||||
func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
key := ":" + strings.Join(c.scopes, "_")
|
|
||||||
now := time.Now().Add(safetyMargin)
|
|
||||||
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
delete(tokens, key)
|
|
||||||
|
|
||||||
// Attempt to get token from Memcache
|
|
||||||
tok := new(oauth2.Token)
|
|
||||||
_, err := memcacheGob.Get(c.context, key, tok)
|
|
||||||
if err == nil && !tok.Expiry.Before(now) {
|
|
||||||
tokens[key] = tok // Save token locally
|
|
||||||
return tok, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
token, expiry, err := accessTokenFunc(c.context, c.scopes...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := &oauth2.Token{
|
|
||||||
AccessToken: token,
|
|
||||||
Expiry: expiry,
|
|
||||||
}
|
|
||||||
tokens[key] = t
|
|
||||||
// Also back up token in Memcache
|
|
||||||
if err = memcacheGob.Set(c.context, &memcache.Item{
|
|
||||||
Key: key,
|
|
||||||
Value: []byte{},
|
|
||||||
Object: *t,
|
|
||||||
Expiration: expiry.Sub(now),
|
|
||||||
}); err != nil {
|
|
||||||
c.context.Errorf("unexpected memcache.Set error: %v", err)
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AppEngineConfig) transport() http.RoundTripper {
|
|
||||||
if c.Transport != nil {
|
|
||||||
return c.Transport
|
|
||||||
}
|
|
||||||
return &urlfetch.Transport{Context: c.context}
|
|
||||||
}
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ package google
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -72,11 +73,16 @@ func TestFetchTokenLocalCacheMiss(t *testing.T) {
|
||||||
memcacheGob = m
|
memcacheGob = m
|
||||||
accessTokenCount = 0
|
accessTokenCount = 0
|
||||||
delete(tokens, testScopeKey) // clear local cache
|
delete(tokens, testScopeKey) // clear local cache
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -102,8 +108,16 @@ func TestFetchTokenLocalCacheHit(t *testing.T) {
|
||||||
AccessToken: "mytoken",
|
AccessToken: "mytoken",
|
||||||
Expiry: time.Now().Add(1 * time.Hour),
|
Expiry: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get("server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Errorf("unable to FetchToken: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -139,11 +153,16 @@ func TestFetchTokenMemcacheHit(t *testing.T) {
|
||||||
Expiration: 1 * time.Hour,
|
Expiration: 1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
|
||||||
_, err := config.FetchToken(nil)
|
f, err := oauth2.New(
|
||||||
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -180,11 +199,15 @@ func TestFetchTokenLocalCacheExpired(t *testing.T) {
|
||||||
Expiration: 1 * time.Hour,
|
Expiration: 1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -217,11 +240,15 @@ func TestFetchTokenMemcacheExpired(t *testing.T) {
|
||||||
Expiration: -1 * time.Hour,
|
Expiration: -1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
package google
|
package google
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -17,6 +16,81 @@ import (
|
||||||
"google.golang.org/appengine/memcache"
|
"google.golang.org/appengine/memcache"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
|
||||||
|
memcacheGob memcacher = &aeMemcache{}
|
||||||
|
|
||||||
|
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
|
||||||
|
accessTokenFunc = appengine.AccessToken
|
||||||
|
|
||||||
|
// mu protects multiple threads from attempting to fetch a token at the same time.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
|
||||||
|
tokens map[string]*oauth2.Token
|
||||||
|
)
|
||||||
|
|
||||||
|
// safetyMargin is used to avoid clock-skew problems.
|
||||||
|
// 5 minutes is conservative because tokens are valid for 60 minutes.
|
||||||
|
const safetyMargin = 5 * time.Minute
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
tokens = make(map[string]*oauth2.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppEngineContext requires an App Engine request context.
|
||||||
|
func AppEngineContext(ctx appengine.Context) oauth2.Option {
|
||||||
|
return func(opts *oauth2.Options) error {
|
||||||
|
opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchToken fetches a new access token for the provided scopes.
|
||||||
|
// Tokens are cached locally and also with Memcache so that the app can scale
|
||||||
|
// without hitting quota limits by calling appengine.AccessToken too frequently.
|
||||||
|
func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
|
||||||
|
return func(existing *oauth2.Token) (*oauth2.Token, error) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
key := ":" + strings.Join(opts.Scopes, "_")
|
||||||
|
now := time.Now().Add(safetyMargin)
|
||||||
|
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
delete(tokens, key)
|
||||||
|
|
||||||
|
// Attempt to get token from Memcache
|
||||||
|
tok := new(oauth2.Token)
|
||||||
|
_, err := memcacheGob.Get(ctx, key, tok)
|
||||||
|
if err == nil && !tok.Expiry.Before(now) {
|
||||||
|
tokens[key] = tok // Save token locally
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t := &oauth2.Token{
|
||||||
|
AccessToken: token,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
tokens[key] = t
|
||||||
|
// Also back up token in Memcache
|
||||||
|
if err = memcacheGob.Set(ctx, &memcache.Item{
|
||||||
|
Key: key,
|
||||||
|
Value: []byte{},
|
||||||
|
Object: *t,
|
||||||
|
Expiration: expiry.Sub(now),
|
||||||
|
}); err != nil {
|
||||||
|
ctx.Errorf("unexpected memcache.Set error: %v", err)
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
|
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
|
||||||
type aeMemcache struct{}
|
type aeMemcache struct{}
|
||||||
|
|
||||||
|
@ -32,99 +106,3 @@ type memcacher interface {
|
||||||
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
|
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
|
||||||
Set(c appengine.Context, item *memcache.Item) error
|
Set(c appengine.Context, item *memcache.Item) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
|
|
||||||
var memcacheGob memcacher = &aeMemcache{}
|
|
||||||
|
|
||||||
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
|
|
||||||
var accessTokenFunc = appengine.AccessToken
|
|
||||||
|
|
||||||
// safetyMargin is used to avoid clock-skew problems.
|
|
||||||
// 5 minutes is conservative because tokens are valid for 60 minutes.
|
|
||||||
const safetyMargin = 5 * time.Minute
|
|
||||||
|
|
||||||
// mu protects multiple threads from attempting to fetch a token at the same time.
|
|
||||||
var mu sync.Mutex
|
|
||||||
|
|
||||||
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
|
|
||||||
var tokens map[string]*oauth2.Token
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
tokens = make(map[string]*oauth2.Token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AppEngineConfig represents a configuration for an
|
|
||||||
// App Engine application's Google service account.
|
|
||||||
type AppEngineConfig struct {
|
|
||||||
// Transport is the http.RoundTripper to be used
|
|
||||||
// to construct new oauth2.Transport instances from
|
|
||||||
// this configuration.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
context appengine.Context
|
|
||||||
scopes []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAppEngineConfig creates a new AppEngineConfig for the
|
|
||||||
// provided auth scopes.
|
|
||||||
func NewAppEngineConfig(context appengine.Context, scopes ...string) *AppEngineConfig {
|
|
||||||
return &AppEngineConfig{
|
|
||||||
context: context,
|
|
||||||
scopes: scopes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTransport returns a transport that authorizes
|
|
||||||
// the requests with the application's service account.
|
|
||||||
func (c *AppEngineConfig) NewTransport() *oauth2.Transport {
|
|
||||||
return oauth2.NewTransport(c.transport(), c, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchToken fetches a new access token for the provided scopes.
|
|
||||||
// Tokens are cached locally and also with Memcache so that the app can scale
|
|
||||||
// without hitting quota limits by calling appengine.AccessToken too frequently.
|
|
||||||
func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
key := ":" + strings.Join(c.scopes, "_")
|
|
||||||
now := time.Now().Add(safetyMargin)
|
|
||||||
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
delete(tokens, key)
|
|
||||||
|
|
||||||
// Attempt to get token from Memcache
|
|
||||||
tok := new(oauth2.Token)
|
|
||||||
_, err := memcacheGob.Get(c.context, key, tok)
|
|
||||||
if err == nil && !tok.Expiry.Before(now) {
|
|
||||||
tokens[key] = tok // Save token locally
|
|
||||||
return tok, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
token, expiry, err := accessTokenFunc(c.context, c.scopes...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t := &oauth2.Token{
|
|
||||||
AccessToken: token,
|
|
||||||
Expiry: expiry,
|
|
||||||
}
|
|
||||||
tokens[key] = t
|
|
||||||
// Also back up token in Memcache
|
|
||||||
if err = memcacheGob.Set(c.context, &memcache.Item{
|
|
||||||
Key: key,
|
|
||||||
Value: []byte{},
|
|
||||||
Object: *t,
|
|
||||||
Expiration: expiry.Sub(now),
|
|
||||||
}); err != nil {
|
|
||||||
c.context.Errorf("unexpected memcache.Set error: %v", err)
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AppEngineConfig) transport() http.RoundTripper {
|
|
||||||
if c.Transport != nil {
|
|
||||||
return c.Transport
|
|
||||||
}
|
|
||||||
return http.DefaultTransport
|
|
||||||
}
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ package google
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -71,11 +72,16 @@ func TestFetchTokenLocalCacheMiss(t *testing.T) {
|
||||||
memcacheGob = m
|
memcacheGob = m
|
||||||
accessTokenCount = 0
|
accessTokenCount = 0
|
||||||
delete(tokens, testScopeKey) // clear local cache
|
delete(tokens, testScopeKey) // clear local cache
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -101,8 +107,16 @@ func TestFetchTokenLocalCacheHit(t *testing.T) {
|
||||||
AccessToken: "mytoken",
|
AccessToken: "mytoken",
|
||||||
Expiry: time.Now().Add(1 * time.Hour),
|
Expiry: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get("server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Errorf("unable to FetchToken: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -138,11 +152,16 @@ func TestFetchTokenMemcacheHit(t *testing.T) {
|
||||||
Expiration: 1 * time.Hour,
|
Expiration: 1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
|
||||||
_, err := config.FetchToken(nil)
|
f, err := oauth2.New(
|
||||||
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -179,11 +198,15 @@ func TestFetchTokenLocalCacheExpired(t *testing.T) {
|
||||||
Expiration: 1 * time.Hour,
|
Expiration: 1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
@ -216,11 +239,15 @@ func TestFetchTokenMemcacheExpired(t *testing.T) {
|
||||||
Expiration: -1 * time.Hour,
|
Expiration: -1 * time.Hour,
|
||||||
})
|
})
|
||||||
m.setCount = 0
|
m.setCount = 0
|
||||||
config := NewAppEngineConfig(nil, testScope)
|
f, err := oauth2.New(
|
||||||
_, err := config.FetchToken(nil)
|
AppEngineContext(nil),
|
||||||
|
oauth2.Scope(testScope),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to FetchToken: %v", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
c := http.Client{Transport: f.NewTransport()}
|
||||||
|
c.Get("server")
|
||||||
if w := 1; m.getCount != w {
|
if w := 1; m.getCount != w {
|
||||||
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,25 +24,25 @@ func TestA(t *testing.T) {}
|
||||||
func Example_webServer() {
|
func Example_webServer() {
|
||||||
// Your credentials should be obtained from the Google
|
// Your credentials should be obtained from the Google
|
||||||
// Developer Console (https://console.developers.google.com).
|
// Developer Console (https://console.developers.google.com).
|
||||||
config, err := google.NewConfig(&oauth2.Options{
|
f, err := oauth2.New(
|
||||||
ClientID: "YOUR_CLIENT_ID",
|
oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
|
||||||
ClientSecret: "YOUR_CLIENT_SECRET",
|
oauth2.RedirectURL("YOUR_REDIRECT_URL"),
|
||||||
RedirectURL: "YOUR_REDIRECT_URL",
|
oauth2.Scope(
|
||||||
Scopes: []string{
|
|
||||||
"https://www.googleapis.com/auth/bigquery",
|
"https://www.googleapis.com/auth/bigquery",
|
||||||
"https://www.googleapis.com/auth/blogger"},
|
"https://www.googleapis.com/auth/blogger",
|
||||||
})
|
),
|
||||||
|
google.Endpoint(),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect user to Google's consent page to ask for permission
|
// Redirect user to Google's consent page to ask for permission
|
||||||
// for the scopes specified above.
|
// for the scopes specified above.
|
||||||
url := config.AuthCodeURL("state", "online", "auto")
|
url := f.AuthCodeURL("state", "online", "auto")
|
||||||
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
||||||
|
|
||||||
// Handle the exchange code to initiate a transport
|
// Handle the exchange code to initiate a transport
|
||||||
t, err := config.NewTransportWithCode("exchange-code")
|
t, err := f.NewTransportFromCode("exchange-code")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -58,9 +58,12 @@ func Example_serviceAccountsJSON() {
|
||||||
// To create a service account client, click "Create new Client ID",
|
// To create a service account client, click "Create new Client ID",
|
||||||
// select "Service Account", and click "Create Client ID". A JSON
|
// select "Service Account", and click "Create Client ID". A JSON
|
||||||
// key file will then be downloaded to your computer.
|
// key file will then be downloaded to your computer.
|
||||||
config, err := google.NewServiceAccountJSONConfig(
|
f, err := oauth2.New(
|
||||||
"/path/to/your-project-key.json",
|
google.ServiceAccountJSONKey("/path/to/your-project-key.json"),
|
||||||
|
oauth2.Scope(
|
||||||
"https://www.googleapis.com/auth/bigquery",
|
"https://www.googleapis.com/auth/bigquery",
|
||||||
|
"https://www.googleapis.com/auth/blogger",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -68,63 +71,74 @@ func Example_serviceAccountsJSON() {
|
||||||
// Initiate an http.Client. The following GET request will be
|
// Initiate an http.Client. The following GET request will be
|
||||||
// authorized and authenticated on the behalf of
|
// authorized and authenticated on the behalf of
|
||||||
// your service account.
|
// your service account.
|
||||||
client := http.Client{Transport: config.NewTransport()}
|
client := http.Client{Transport: f.NewTransport()}
|
||||||
client.Get("...")
|
|
||||||
|
|
||||||
// If you would like to impersonate a user, you can
|
|
||||||
// create a transport with a subject. The following GET
|
|
||||||
// request will be made on the behalf of user@example.com.
|
|
||||||
client = http.Client{Transport: config.NewTransportWithUser("user@example.com")}
|
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_serviceAccounts() {
|
func Example_serviceAccounts() {
|
||||||
// Your credentials should be obtained from the Google
|
// Your credentials should be obtained from the Google
|
||||||
// Developer Console (https://console.developers.google.com).
|
// Developer Console (https://console.developers.google.com).
|
||||||
config, err := google.NewServiceAccountConfig(&oauth2.JWTOptions{
|
f, err := oauth2.New(
|
||||||
Email: "xxx@developer.gserviceaccount.com",
|
|
||||||
// The contents of your RSA private key or your PEM file
|
// The contents of your RSA private key or your PEM file
|
||||||
// that contains a private key.
|
// that contains a private key.
|
||||||
// If you have a p12 file instead, you
|
// If you have a p12 file instead, you
|
||||||
// can use `openssl` to export the private key into a PEM file.
|
// can use `openssl` to export the private key into a pem file.
|
||||||
//
|
//
|
||||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||||
//
|
//
|
||||||
// Supports only PEM containers without a passphrase.
|
// It only supports PEM containers with no passphrase.
|
||||||
PrivateKey: []byte("PRIVATE KEY CONTENTS"),
|
oauth2.JWTClient(
|
||||||
Scopes: []string{
|
"xxx@developer.gserviceaccount.com",
|
||||||
|
[]byte("-----BEGIN RSA PRIVATE KEY-----...")),
|
||||||
|
oauth2.Scope(
|
||||||
"https://www.googleapis.com/auth/bigquery",
|
"https://www.googleapis.com/auth/bigquery",
|
||||||
},
|
"https://www.googleapis.com/auth/blogger",
|
||||||
})
|
),
|
||||||
|
google.JWTEndpoint(),
|
||||||
|
// If you would like to impersonate a user, you can
|
||||||
|
// create a transport with a subject. The following GET
|
||||||
|
// request will be made on the behalf of user@example.com.
|
||||||
|
// Subject is optional.
|
||||||
|
oauth2.Subject("user@example.com"),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initiate an http.Client, the following GET request will be
|
// Initiate an http.Client, the following GET request will be
|
||||||
// authorized and authenticated on the behalf of
|
// authorized and authenticated on the behalf of user@example.com.
|
||||||
// xxx@developer.gserviceaccount.com.
|
client := http.Client{Transport: f.NewTransport()}
|
||||||
client := http.Client{Transport: config.NewTransport()}
|
|
||||||
client.Get("...")
|
|
||||||
|
|
||||||
// If you would like to impersonate a user, you can
|
|
||||||
// create a transport with a subject. The following GET
|
|
||||||
// request will be made on the behalf of user@example.com.
|
|
||||||
client = http.Client{Transport: config.NewTransportWithUser("user@example.com")}
|
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_appEngine() {
|
func Example_appEngineVMs() {
|
||||||
c := appengine.NewContext(nil)
|
ctx := appengine.NewContext(nil)
|
||||||
config := google.NewAppEngineConfig(c, "https://www.googleapis.com/auth/bigquery")
|
f, err := oauth2.New(
|
||||||
|
google.AppEngineContext(ctx),
|
||||||
|
oauth2.Scope(
|
||||||
|
"https://www.googleapis.com/auth/bigquery",
|
||||||
|
"https://www.googleapis.com/auth/blogger",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
// The following client will be authorized by the App Engine
|
// The following client will be authorized by the App Engine
|
||||||
// app's service account for the provided scopes.
|
// app's service account for the provided scopes.
|
||||||
client := http.Client{Transport: config.NewTransport()}
|
client := http.Client{Transport: f.NewTransport()}
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
||||||
func Example_computeEngine() {
|
func Example_computeEngine() {
|
||||||
// If no other account is specified, "default" is used.
|
f, err := oauth2.New(
|
||||||
config := google.NewComputeEngineConfig("")
|
// Query Google Compute Engine's metadata server to retrieve
|
||||||
client := http.Client{Transport: config.NewTransport()}
|
// an access token for the provided account.
|
||||||
|
// If no account is specified, "default" is used.
|
||||||
|
google.ComputeEngineAccount(""),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
client := http.Client{Transport: f.NewTransport()}
|
||||||
client.Get("...")
|
client.Get("...")
|
||||||
}
|
}
|
||||||
|
|
133
google/google.go
133
google/google.go
|
@ -15,18 +15,20 @@ package google
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/oauth2"
|
"github.com/golang/oauth2"
|
||||||
|
"github.com/golang/oauth2/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
// Google endpoints.
|
uriGoogleAuth, _ = url.Parse("https://accounts.google.com/o/oauth2/auth")
|
||||||
uriGoogleAuth = "https://accounts.google.com/o/oauth2/auth"
|
uriGoogleToken, _ = url.Parse("https://accounts.google.com/o/oauth2/token")
|
||||||
uriGoogleToken = "https://accounts.google.com/o/oauth2/token"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type metaTokenRespBody struct {
|
type metaTokenRespBody struct {
|
||||||
|
@ -35,84 +37,79 @@ type metaTokenRespBody struct {
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComputeEngineConfig represents a OAuth 2.0 consumer client
|
// JWTEndpoint adds the endpoints required to complete the 2-legged service account flow.
|
||||||
// running on Google Compute Engine.
|
func JWTEndpoint() oauth2.Option {
|
||||||
type ComputeEngineConfig struct {
|
return func(opts *oauth2.Options) error {
|
||||||
// Client is the HTTP client to be used to retrieve
|
opts.AUD = uriGoogleToken
|
||||||
// tokens from the OAuth 2.0 provider.
|
return nil
|
||||||
Client *http.Client
|
}
|
||||||
|
|
||||||
// Transport is the round tripper to be used
|
|
||||||
// to construct new oauth2.Transport instances from
|
|
||||||
// this configuration.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
account string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig creates a new OAuth2 config that uses Google
|
// Endpoint adds the endpoints required to do the 3-legged Web server flow.
|
||||||
// endpoints.
|
func Endpoint() oauth2.Option {
|
||||||
func NewConfig(opts *oauth2.Options) (*oauth2.Config, error) {
|
return func(opts *oauth2.Options) error {
|
||||||
return oauth2.NewConfig(opts, uriGoogleAuth, uriGoogleToken)
|
opts.AuthURL = uriGoogleAuth
|
||||||
|
opts.TokenURL = uriGoogleToken
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServiceAccountConfig creates a new JWT config that can
|
// ComputeEngineAccount uses the specified account to retrieve an access
|
||||||
// fetch Bearer JWT tokens from Google endpoints.
|
// token from the Google Compute Engine's metadata server. If no user is
|
||||||
func NewServiceAccountConfig(opts *oauth2.JWTOptions) (*oauth2.JWTConfig, error) {
|
// provided, "default" is being used.
|
||||||
return oauth2.NewJWTConfig(opts, uriGoogleToken)
|
func ComputeEngineAccount(account string) oauth2.Option {
|
||||||
|
return func(opts *oauth2.Options) error {
|
||||||
|
if account == "" {
|
||||||
|
account = "default"
|
||||||
|
}
|
||||||
|
opts.TokenFetcherFunc = makeComputeFetcher(opts, account)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServiceAccountJSONConfig creates a new JWT config from a
|
// ServiceAccountJSONKey uses the provided Google Developers
|
||||||
// JSON key file downloaded from the Google Developers Console.
|
// JSON key file to authorize the user. See the "Credentials" page under
|
||||||
// See the "Credentials" page under "APIs & Auth" for your project
|
// "APIs & Auth" for your project at https://console.developers.google.com
|
||||||
// at https://console.developers.google.com.
|
// to download a JSON key file.
|
||||||
func NewServiceAccountJSONConfig(filename string, scopes ...string) (*oauth2.JWTConfig, error) {
|
func ServiceAccountJSONKey(filename string) oauth2.Option {
|
||||||
|
return func(opts *oauth2.Options) error {
|
||||||
b, err := ioutil.ReadFile(filename)
|
b, err := ioutil.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
var key struct {
|
var key struct {
|
||||||
Email string `json:"client_email"`
|
Email string `json:"client_email"`
|
||||||
PrivateKey string `json:"private_key"`
|
PrivateKey string `json:"private_key"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(b, &key); err != nil {
|
if err := json.Unmarshal(b, &key); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
opts := &oauth2.JWTOptions{
|
pk, err := internal.ParseKey([]byte(key.PrivateKey))
|
||||||
Email: key.Email,
|
if err != nil {
|
||||||
PrivateKey: []byte(key.PrivateKey),
|
return err
|
||||||
Scopes: scopes,
|
}
|
||||||
|
opts.Email = key.Email
|
||||||
|
opts.PrivateKey = pk
|
||||||
|
opts.AUD = uriGoogleToken
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return NewServiceAccountConfig(opts)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewComputeEngineConfig creates a new config that can fetch tokens
|
func makeComputeFetcher(opts *oauth2.Options, account string) func(*oauth2.Token) (*oauth2.Token, error) {
|
||||||
// from Google Compute Engine instance's metaserver. If no account is
|
return func(t *oauth2.Token) (*oauth2.Token, error) {
|
||||||
// provided, default is used.
|
|
||||||
func NewComputeEngineConfig(account string) *ComputeEngineConfig {
|
|
||||||
return &ComputeEngineConfig{account: account}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTransport creates an authorized transport.
|
|
||||||
func (c *ComputeEngineConfig) NewTransport() *oauth2.Transport {
|
|
||||||
return oauth2.NewTransport(c.transport(), c, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchToken retrieves a new access token via metadata server.
|
|
||||||
func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2.Token, err error) {
|
|
||||||
account := "default"
|
|
||||||
if c.account != "" {
|
|
||||||
account = c.account
|
|
||||||
}
|
|
||||||
u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token"
|
u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token"
|
||||||
req, err := http.NewRequest("GET", u, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Add("X-Google-Metadata-Request", "True")
|
req.Header.Add("X-Google-Metadata-Request", "True")
|
||||||
resp, err := c.client().Do(req)
|
c := &http.Client{}
|
||||||
|
if opts.Client != nil {
|
||||||
|
c = opts.Client
|
||||||
|
}
|
||||||
|
resp, err := c.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||||
|
@ -121,26 +118,12 @@ func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2.
|
||||||
var tokenResp metaTokenRespBody
|
var tokenResp metaTokenRespBody
|
||||||
err = json.NewDecoder(resp.Body).Decode(&tokenResp)
|
err = json.NewDecoder(resp.Body).Decode(&tokenResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
token = &oauth2.Token{
|
return &oauth2.Token{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
TokenType: tokenResp.TokenType,
|
TokenType: tokenResp.TokenType,
|
||||||
Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second),
|
Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ComputeEngineConfig) transport() http.RoundTripper {
|
|
||||||
if c.Transport != nil {
|
|
||||||
return c.Transport
|
|
||||||
}
|
|
||||||
return http.DefaultTransport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ComputeEngineConfig) client() *http.Client {
|
|
||||||
if c.Client != nil {
|
|
||||||
return c.Client
|
|
||||||
}
|
|
||||||
return http.DefaultClient
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
// Copyright 2014 The oauth2 Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseKey converts the binary contents of a private key file
|
||||||
|
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
||||||
|
// PEM container or not. If so, it extracts the the private key
|
||||||
|
// from PEM container before conversion. It only supports PEM
|
||||||
|
// containers with no passphrase.
|
||||||
|
func ParseKey(key []byte) (*rsa.PrivateKey, error) {
|
||||||
|
block, _ := pem.Decode(key)
|
||||||
|
if block != nil {
|
||||||
|
key = block.Bytes
|
||||||
|
}
|
||||||
|
parsedKey, err := x509.ParsePKCS8PrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsed, ok := parsedKey.(*rsa.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("oauth2: private key is invalid")
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
|
@ -71,15 +71,15 @@ func (c *ClaimSet) encode() (string, error) {
|
||||||
// Marshal private claim set and then append it to b.
|
// Marshal private claim set and then append it to b.
|
||||||
prv, err := json.Marshal(c.PrivateClaims)
|
prv, err := json.Marshal(c.PrivateClaims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("Invalid map of private claims %v", c.PrivateClaims)
|
return "", fmt.Errorf("jws: invalid map of private claims %v", c.PrivateClaims)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Concatenate public and private claim JSON objects.
|
// Concatenate public and private claim JSON objects.
|
||||||
if !bytes.HasSuffix(b, []byte{'}'}) {
|
if !bytes.HasSuffix(b, []byte{'}'}) {
|
||||||
return "", fmt.Errorf("Invalid JSON %s", b)
|
return "", fmt.Errorf("jws: invalid JSON %s", b)
|
||||||
}
|
}
|
||||||
if !bytes.HasPrefix(prv, []byte{'{'}) {
|
if !bytes.HasPrefix(prv, []byte{'{'}) {
|
||||||
return "", fmt.Errorf("Invalid JSON %s", prv)
|
return "", fmt.Errorf("jws: invalid JSON %s", prv)
|
||||||
}
|
}
|
||||||
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
|
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
|
||||||
b = append(b, prv[1:]...) // Append private claims.
|
b = append(b, prv[1:]...) // Append private claims.
|
||||||
|
@ -109,7 +109,7 @@ func Decode(payload string) (c *ClaimSet, err error) {
|
||||||
s := strings.Split(payload, ".")
|
s := strings.Split(payload, ".")
|
||||||
if len(s) < 2 {
|
if len(s) < 2 {
|
||||||
// TODO(jbd): Provide more context about the error.
|
// TODO(jbd): Provide more context about the error.
|
||||||
return nil, errors.New("invalid token received")
|
return nil, errors.New("jws: invalid token received")
|
||||||
}
|
}
|
||||||
decoded, err := base64Decode(s[1])
|
decoded, err := base64Decode(s[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
168
jwt.go
168
jwt.go
|
@ -5,11 +5,7 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -17,6 +13,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/oauth2/internal"
|
||||||
"github.com/golang/oauth2/jws"
|
"github.com/golang/oauth2/jws"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,104 +22,69 @@ var (
|
||||||
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWTOptions represents a OAuth2 client's crendentials to retrieve a
|
// JWTClient requires OAuth 2.0 JWT credentials.
|
||||||
// Bearer JWT token.
|
// Required for the 2-legged JWT flow.
|
||||||
type JWTOptions struct {
|
func JWTClient(email string, key []byte) Option {
|
||||||
// Email is the OAuth client identifier used when communicating with
|
return func(o *Options) error {
|
||||||
// the configured OAuth provider.
|
pk, err := internal.ParseKey(key)
|
||||||
Email string `json:"email"`
|
|
||||||
|
|
||||||
// PrivateKey contains the contents of an RSA private key or the
|
|
||||||
// contents of a PEM file that contains a private key. The provided
|
|
||||||
// private key is used to sign JWT payloads.
|
|
||||||
// PEM containers with a passphrase are not supported.
|
|
||||||
// Use the following command to convert a PKCS 12 file into a PEM.
|
|
||||||
//
|
|
||||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
|
||||||
//
|
|
||||||
PrivateKey []byte `json:"-"`
|
|
||||||
|
|
||||||
// Scopes identify the level of access being requested.
|
|
||||||
Scopes []string `json:"scopes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewJWTConfig creates a new configuration with the specified options
|
|
||||||
// and OAuth2 provider endpoint.
|
|
||||||
func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) {
|
|
||||||
audURL, err := url.Parse(aud)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
parsedKey, err := parseKey(opts.PrivateKey)
|
o.Email = email
|
||||||
|
o.PrivateKey = pk
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
|
||||||
|
func JWTEndpoint(aud string) Option {
|
||||||
|
return func(o *Options) error {
|
||||||
|
au, err := url.Parse(aud)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
|
}
|
||||||
|
o.AUD = au
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return &JWTConfig{
|
|
||||||
opts: opts,
|
|
||||||
aud: audURL,
|
|
||||||
key: parsedKey,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTConfig represents an OAuth 2.0 provider and client options to
|
// Subject requires a user to impersonate.
|
||||||
// provide authorized transports with a Bearer JWT token.
|
// Optional.
|
||||||
type JWTConfig struct {
|
func Subject(user string) Option {
|
||||||
// Client is the HTTP client to be used to retrieve
|
return func(o *Options) error {
|
||||||
// tokens from the OAuth 2.0 provider.
|
o.Subject = user
|
||||||
Client *http.Client
|
return nil
|
||||||
|
}
|
||||||
// Transport is the http.RoundTripper to be used
|
|
||||||
// to construct new oauth2.Transport instances from
|
|
||||||
// this configuration.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
opts *JWTOptions
|
|
||||||
aud *url.URL
|
|
||||||
key *rsa.PrivateKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransport creates a transport that is authorize with the
|
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||||
// parent JWT configuration.
|
return func(t *Token) (*Token, error) {
|
||||||
func (c *JWTConfig) NewTransport() *Transport {
|
if t == nil {
|
||||||
return NewTransport(c.transport(), c, &Token{})
|
t = &Token{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransportWithUser creates a transport that is authorized by
|
|
||||||
// the client and impersonates the specified user.
|
|
||||||
func (c *JWTConfig) NewTransportWithUser(user string) *Transport {
|
|
||||||
return NewTransport(c.transport(), c, &Token{Subject: user})
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchToken retrieves a new access token and updates the existing token
|
|
||||||
// with the newly fetched credentials.
|
|
||||||
func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) {
|
|
||||||
if existing == nil {
|
|
||||||
existing = &Token{}
|
|
||||||
}
|
|
||||||
|
|
||||||
claimSet := &jws.ClaimSet{
|
claimSet := &jws.ClaimSet{
|
||||||
Iss: c.opts.Email,
|
Iss: o.Email,
|
||||||
Scope: strings.Join(c.opts.Scopes, " "),
|
Scope: strings.Join(o.Scopes, " "),
|
||||||
Aud: c.aud.String(),
|
Aud: o.AUD.String(),
|
||||||
}
|
}
|
||||||
|
if o.Subject != "" {
|
||||||
if existing.Subject != "" {
|
claimSet.Sub = o.Subject
|
||||||
claimSet.Sub = existing.Subject
|
|
||||||
// prn is the old name of sub. Keep setting it
|
// prn is the old name of sub. Keep setting it
|
||||||
// to be compatible with legacy OAuth 2.0 providers.
|
// to be compatible with legacy OAuth 2.0 providers.
|
||||||
claimSet.Prn = existing.Subject
|
claimSet.Prn = o.Subject
|
||||||
}
|
}
|
||||||
|
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
|
||||||
payload, err := jws.Encode(defaultHeader, claimSet, c.key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Set("grant_type", defaultGrantType)
|
v.Set("grant_type", defaultGrantType)
|
||||||
v.Set("assertion", payload)
|
v.Set("assertion", payload)
|
||||||
|
c := o.Client
|
||||||
// Make a request with assertion to get a new token.
|
if c == nil {
|
||||||
resp, err := c.client().PostForm(c.aud.String(), v)
|
c = &http.Client{}
|
||||||
|
}
|
||||||
|
resp, err := c.PostForm(o.AUD.String(), v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -134,7 +96,6 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) {
|
||||||
if c := resp.StatusCode; c < 200 || c > 299 {
|
if c := resp.StatusCode; c < 200 || c > 299 {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
b := make(map[string]interface{})
|
b := make(map[string]interface{})
|
||||||
if err := json.Unmarshal(body, &b); err != nil {
|
if err := json.Unmarshal(body, &b); err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
|
@ -142,15 +103,13 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) {
|
||||||
token := &Token{}
|
token := &Token{}
|
||||||
token.AccessToken, _ = b["access_token"].(string)
|
token.AccessToken, _ = b["access_token"].(string)
|
||||||
token.TokenType, _ = b["token_type"].(string)
|
token.TokenType, _ = b["token_type"].(string)
|
||||||
token.Subject = existing.Subject
|
|
||||||
token.raw = b
|
token.raw = b
|
||||||
if e, ok := b["expires_in"].(int); ok {
|
if e, ok := b["expires_in"].(int); ok {
|
||||||
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
||||||
}
|
}
|
||||||
if idtoken, ok := b["id_token"].(string); ok {
|
if idtoken, ok := b["id_token"].(string); ok {
|
||||||
// decode returned id token to get expiry
|
// decode returned id token to get expiry
|
||||||
claimSet := &jws.ClaimSet{}
|
claimSet, err := jws.Decode(idtoken)
|
||||||
claimSet, err = jws.Decode(idtoken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -159,41 +118,4 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) {
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *JWTConfig) transport() http.RoundTripper {
|
|
||||||
if c.Transport != nil {
|
|
||||||
return c.Transport
|
|
||||||
}
|
|
||||||
return http.DefaultTransport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *JWTConfig) client() *http.Client {
|
|
||||||
if c.Client != nil {
|
|
||||||
return c.Client
|
|
||||||
}
|
|
||||||
return http.DefaultClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseKey converts the binary contents of a private key file
|
|
||||||
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
|
||||||
// PEM container or not. If so, it extracts the the private key
|
|
||||||
// from PEM container before conversion. It only supports PEM
|
|
||||||
// containers with no passphrase.
|
|
||||||
func parseKey(key []byte) (*rsa.PrivateKey, error) {
|
|
||||||
block, _ := pem.Decode(key)
|
|
||||||
if block != nil {
|
|
||||||
key = block.Bytes
|
|
||||||
}
|
|
||||||
parsedKey, err := x509.ParsePKCS8PrivateKey(key)
|
|
||||||
if err != nil {
|
|
||||||
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
parsed, ok := parsedKey.(*rsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("oauth2: private key is invalid")
|
|
||||||
}
|
|
||||||
return parsed, nil
|
|
||||||
}
|
}
|
||||||
|
|
54
jwt_test.go
54
jwt_test.go
|
@ -41,17 +41,25 @@ mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
|
||||||
func TestJWTFetch_JSONResponse(t *testing.T) {
|
func TestJWTFetch_JSONResponse(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{
|
||||||
|
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
|
||||||
|
"scope": "user",
|
||||||
|
"token_type": "bearer"
|
||||||
|
}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf, err := NewJWTConfig(&JWTOptions{
|
f, err := New(
|
||||||
Email: "aaa@xxx.com",
|
JWTClient("aaa@xxx.com", dummyPrivateKey),
|
||||||
PrivateKey: dummyPrivateKey,
|
JWTEndpoint(ts.URL),
|
||||||
}, ts.URL)
|
)
|
||||||
tok, err := conf.FetchToken(nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
|
||||||
|
c.Get(ts.URL)
|
||||||
|
tok := tr.Token()
|
||||||
if tok.Expired() {
|
if tok.Expired() {
|
||||||
t.Errorf("Token shouldn't be expired.")
|
t.Errorf("Token shouldn't be expired.")
|
||||||
}
|
}
|
||||||
|
@ -73,11 +81,17 @@ func TestJWTFetch_BadResponse(t *testing.T) {
|
||||||
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf, err := NewJWTConfig(&JWTOptions{
|
f, err := New(
|
||||||
Email: "aaa@xxx.com",
|
JWTClient("aaa@xxx.com", dummyPrivateKey),
|
||||||
PrivateKey: dummyPrivateKey,
|
JWTEndpoint(ts.URL),
|
||||||
}, ts.URL)
|
)
|
||||||
tok, err := conf.FetchToken(nil)
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get(ts.URL)
|
||||||
|
tok := tr.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Errorf("Failed retrieving token: %s.", err)
|
||||||
}
|
}
|
||||||
|
@ -99,11 +113,17 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
|
||||||
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf, err := NewJWTConfig(&JWTOptions{
|
f, err := New(
|
||||||
Email: "aaa@xxx.com",
|
JWTClient("aaa@xxx.com", dummyPrivateKey),
|
||||||
PrivateKey: dummyPrivateKey,
|
JWTEndpoint(ts.URL),
|
||||||
}, ts.URL)
|
)
|
||||||
tok, err := conf.FetchToken(nil)
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
tr := f.NewTransport()
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get(ts.URL)
|
||||||
|
tok := tr.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Errorf("Failed retrieving token: %s.", err)
|
||||||
}
|
}
|
||||||
|
|
310
oauth2.go
310
oauth2.go
|
@ -8,85 +8,118 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"mime"
|
"mime"
|
||||||
|
"time"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenFetcher refreshes or fetches a new access token from the
|
// Option represents a function that applies some state to
|
||||||
// provider. It should return an error if it's not capable of
|
// an Options object.
|
||||||
// retrieving a token.
|
type Option func(*Options) error
|
||||||
type TokenFetcher interface {
|
|
||||||
// FetchToken retrieves a new access token for the provider.
|
// Client requires the OAuth 2.0 client credentials. You need to provide
|
||||||
// If the implementation doesn't know how to retrieve a new token,
|
// the client identifier and optionally the client secret that are
|
||||||
// it returns an error. The existing token may be nil.
|
// assigned to your application by the OAuth 2.0 provider.
|
||||||
FetchToken(existing *Token) (*Token, error)
|
func Client(id, secret string) Option {
|
||||||
|
return func(opts *Options) error {
|
||||||
|
opts.ClientID = id
|
||||||
|
opts.ClientSecret = secret
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options represents options to provide OAuth 2.0 client credentials
|
// RedirectURL requires the URL to which the user will be returned after
|
||||||
// and access level. A sample configuration:
|
|
||||||
type Options struct {
|
|
||||||
// ClientID is the OAuth client identifier used when communicating with
|
|
||||||
// the configured OAuth provider.
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
|
|
||||||
// ClientSecret is the OAuth client secret used when communicating with
|
|
||||||
// the configured OAuth provider.
|
|
||||||
ClientSecret string `json:"client_secret"`
|
|
||||||
|
|
||||||
// RedirectURL is the URL to which the user will be returned after
|
|
||||||
// granting (or denying) access.
|
// granting (or denying) access.
|
||||||
RedirectURL string `json:"redirect_url"`
|
func RedirectURL(url string) Option {
|
||||||
|
return func(opts *Options) error {
|
||||||
// Scopes optionally specifies a list of requested permission scopes.
|
opts.RedirectURL = url
|
||||||
Scopes []string `json:"scopes,omitempty"`
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig creates a generic OAuth 2.0 configuration that talks
|
// Scope requires a list of requested permission scopes.
|
||||||
// to an OAuth 2.0 provider specified with authURL and tokenURL.
|
// It is optinal to specify scopes.
|
||||||
func NewConfig(opts *Options, authURL, tokenURL string) (*Config, error) {
|
func Scope(scopes ...string) Option {
|
||||||
aURL, err := url.Parse(authURL)
|
return func(o *Options) error {
|
||||||
|
o.Scopes = scopes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
|
||||||
|
func Endpoint(authURL, tokenURL string) Option {
|
||||||
|
return func(o *Options) error {
|
||||||
|
au, err := url.Parse(authURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tu, err := url.Parse(tokenURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.TokenFetcherFunc = makeThreeLeggedFetcher(o)
|
||||||
|
o.AuthURL = au
|
||||||
|
o.TokenURL = tu
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient allows you to provide a custom http.Client to be
|
||||||
|
// used to retrieve tokens from the OAuth 2.0 provider.
|
||||||
|
func HTTPClient(c *http.Client) Option {
|
||||||
|
return func(o *Options) error {
|
||||||
|
o.Client = c
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTripper allows you to provide a custom http.RoundTripper
|
||||||
|
// to be used to construct new oauth2.Transport instances.
|
||||||
|
// If none is provided a default RoundTripper will be used.
|
||||||
|
func RoundTripper(tr http.RoundTripper) Option {
|
||||||
|
return func(o *Options) error {
|
||||||
|
o.Transport = tr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Flow struct {
|
||||||
|
opts Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// New initiates a new flow. It determines the type of the OAuth 2.0
|
||||||
|
// (2-legged, 3-legged or custom) by looking at the provided options.
|
||||||
|
// If the flow type cannot determined automatically, an error is returned.
|
||||||
|
func New(options ...Option) (*Flow, error) {
|
||||||
|
f := &Flow{}
|
||||||
|
for _, opt := range options {
|
||||||
|
if err := opt(&f.opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tURL, err := url.Parse(tokenURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
if opts.ClientID == "" {
|
switch {
|
||||||
return nil, errors.New("oauth2: missing client ID")
|
case f.opts.TokenFetcherFunc != nil:
|
||||||
|
return f, nil
|
||||||
|
case f.opts.AUD != nil:
|
||||||
|
// TODO(jbd): Assert required JWT params.
|
||||||
|
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
|
||||||
|
return f, nil
|
||||||
|
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
|
||||||
|
// TODO(jbd): Assert required OAuth2 params.
|
||||||
|
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
|
||||||
|
return f, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
|
||||||
}
|
}
|
||||||
return &Config{
|
|
||||||
opts: opts,
|
|
||||||
authURL: aURL,
|
|
||||||
tokenURL: tURL,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config represents the configuration of an OAuth 2.0 consumer client.
|
|
||||||
type Config struct {
|
|
||||||
// Client is the HTTP client to be used to retrieve
|
|
||||||
// tokens from the OAuth 2.0 provider.
|
|
||||||
Client *http.Client
|
|
||||||
|
|
||||||
// Transport is the http.RoundTripper to be used
|
|
||||||
// to construct new oauth2.Transport instances from
|
|
||||||
// this configuration.
|
|
||||||
Transport http.RoundTripper
|
|
||||||
|
|
||||||
opts *Options
|
|
||||||
// AuthURL is the URL the user will be directed to
|
|
||||||
// in order to grant access.
|
|
||||||
authURL *url.URL
|
|
||||||
// TokenURL is the URL used to retrieve OAuth tokens.
|
|
||||||
tokenURL *url.URL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||||
|
@ -112,13 +145,13 @@ type Config struct {
|
||||||
// granted consent and the code can only be exchanged for an
|
// granted consent and the code can only be exchanged for an
|
||||||
// access token. If set to "force" the user will always be prompted,
|
// access token. If set to "force" the user will always be prompted,
|
||||||
// and the code can be exchanged for a refresh token.
|
// and the code can be exchanged for a refresh token.
|
||||||
func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string) {
|
func (f *Flow) AuthCodeURL(state, accessType, prompt string) string {
|
||||||
u := *c.authURL
|
u := f.opts.AuthURL
|
||||||
v := url.Values{
|
v := url.Values{
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"client_id": {c.opts.ClientID},
|
"client_id": {f.opts.ClientID},
|
||||||
"redirect_uri": condVal(c.opts.RedirectURL),
|
"redirect_uri": condVal(f.opts.RedirectURL),
|
||||||
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
|
||||||
"state": condVal(state),
|
"state": condVal(state),
|
||||||
"access_type": condVal(accessType),
|
"access_type": condVal(accessType),
|
||||||
"approval_prompt": condVal(prompt),
|
"approval_prompt": condVal(prompt),
|
||||||
|
@ -132,65 +165,122 @@ func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string)
|
||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransport creates a new authorizable transport. It doesn't
|
// exchange exchanges the authorization code with the OAuth 2.0 provider
|
||||||
// initialize the new transport with a token, so after creation,
|
// to retrieve a new access token.
|
||||||
// you need to set a valid token (or an expired token with a valid
|
func (f *Flow) exchange(code string) (*Token, error) {
|
||||||
// refresh token) in order to be able to do authorized requests.
|
return retrieveToken(&f.opts, url.Values{
|
||||||
func (c *Config) NewTransport() *Transport {
|
"grant_type": {"authorization_code"},
|
||||||
return NewTransport(c.transport(), c, nil)
|
"code": {code},
|
||||||
|
"redirect_uri": condVal(f.opts.RedirectURL),
|
||||||
|
"scope": condVal(strings.Join(f.opts.Scopes, " ")),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransportWithCode exchanges the OAuth 2.0 authorization code with
|
// NewTransportFromCode exchanges the code to retrieve a new access token
|
||||||
// the provider to fetch a new access token (and refresh token). Once
|
// and returns an authorized and authenticated Transport.
|
||||||
// it successfully retrieves a new token, creates a new transport
|
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
|
||||||
// authorized with it.
|
token, err := f.exchange(code)
|
||||||
func (c *Config) NewTransportWithCode(code string) (*Transport, error) {
|
|
||||||
token, err := c.Exchange(code)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewTransport(c.transport(), c, token), nil
|
return f.NewTransportFromToken(token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchToken retrieves a new access token and updates the existing token
|
// NewTransportFromToken returns a new Transport that is authorized
|
||||||
// with the newly fetched credentials. If existing token doesn't
|
// and authenticated with the provided token.
|
||||||
// contain a refresh token, it returns an error.
|
func (f *Flow) NewTransportFromToken(t *Token) *Transport {
|
||||||
func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
tr := f.opts.Transport
|
||||||
if existing == nil || existing.RefreshToken == "" {
|
if tr == nil {
|
||||||
|
tr = http.DefaultTransport
|
||||||
|
}
|
||||||
|
return newTransport(tr, f.opts.TokenFetcherFunc, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTransport returns a Transport.
|
||||||
|
func (f *Flow) NewTransport() *Transport {
|
||||||
|
return f.NewTransportFromToken(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||||
|
return func(t *Token) (*Token, error) {
|
||||||
|
if t == nil || t.RefreshToken == "" {
|
||||||
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
||||||
}
|
}
|
||||||
return c.retrieveToken(url.Values{
|
return retrieveToken(o, url.Values{
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"refresh_token": {existing.RefreshToken},
|
"refresh_token": {t.RefreshToken},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange exchanges the authorization code with the OAuth 2.0 provider
|
|
||||||
// to retrieve a new access token.
|
|
||||||
func (c *Config) Exchange(code string) (*Token, error) {
|
|
||||||
return c.retrieveToken(url.Values{
|
|
||||||
"grant_type": {"authorization_code"},
|
|
||||||
"code": {code},
|
|
||||||
"redirect_uri": condVal(c.opts.RedirectURL),
|
|
||||||
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
// Options represents an object to keep the state of the OAuth 2.0 flow.
|
||||||
v.Set("client_id", c.opts.ClientID)
|
type Options struct {
|
||||||
bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String())
|
// ClientID is the OAuth client identifier used when communicating with
|
||||||
if bustedAuth && c.opts.ClientSecret != "" {
|
// the configured OAuth provider.
|
||||||
v.Set("client_secret", c.opts.ClientSecret)
|
ClientID string
|
||||||
|
|
||||||
|
// ClientSecret is the OAuth client secret used when communicating with
|
||||||
|
// the configured OAuth provider.
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// RedirectURL is the URL to which the user will be returned after
|
||||||
|
// granting (or denying) access.
|
||||||
|
RedirectURL string
|
||||||
|
|
||||||
|
// Email is the OAuth client identifier used when communicating with
|
||||||
|
// the configured OAuth provider.
|
||||||
|
Email string
|
||||||
|
|
||||||
|
// PrivateKey contains the contents of an RSA private key or the
|
||||||
|
// contents of a PEM file that contains a private key. The provided
|
||||||
|
// private key is used to sign JWT payloads.
|
||||||
|
// PEM containers with a passphrase are not supported.
|
||||||
|
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||||
|
//
|
||||||
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||||
|
//
|
||||||
|
PrivateKey *rsa.PrivateKey
|
||||||
|
|
||||||
|
// Scopes identify the level of access being requested.
|
||||||
|
Subject string
|
||||||
|
|
||||||
|
// Scopes optionally specifies a list of requested permission scopes.
|
||||||
|
Scopes []string
|
||||||
|
|
||||||
|
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
|
||||||
|
AuthURL *url.URL
|
||||||
|
|
||||||
|
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
|
||||||
|
TokenURL *url.URL
|
||||||
|
|
||||||
|
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
||||||
|
AUD *url.URL
|
||||||
|
|
||||||
|
TokenFetcherFunc func(t *Token) (*Token, error)
|
||||||
|
|
||||||
|
Transport http.RoundTripper
|
||||||
|
Client *http.Client
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode()))
|
|
||||||
|
func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
||||||
|
v.Set("client_id", o.ClientID)
|
||||||
|
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
|
||||||
|
if bustedAuth && o.ClientSecret != "" {
|
||||||
|
v.Set("client_secret", o.ClientSecret)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
if !bustedAuth && c.opts.ClientSecret != "" {
|
if !bustedAuth && o.ClientSecret != "" {
|
||||||
req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret)
|
req.SetBasicAuth(o.ClientID, o.ClientSecret)
|
||||||
}
|
}
|
||||||
r, err := c.client().Do(req)
|
c := o.Client
|
||||||
|
if c == nil {
|
||||||
|
c = &http.Client{}
|
||||||
|
}
|
||||||
|
r, err := c.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -199,7 +289,7 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
if c := r.StatusCode; c < 200 || c > 299 {
|
if code := r.StatusCode; code < 200 || code > 299 {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -255,20 +345,6 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) transport() http.RoundTripper {
|
|
||||||
if c.Transport != nil {
|
|
||||||
return c.Transport
|
|
||||||
}
|
|
||||||
return http.DefaultTransport
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) client() *http.Client {
|
|
||||||
if c.Client != nil {
|
|
||||||
return c.Client
|
|
||||||
}
|
|
||||||
return http.DefaultClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func condVal(v string) []string {
|
func condVal(v string) []string {
|
||||||
if v == "" {
|
if v == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|
124
oauth2_test.go
124
oauth2_test.go
|
@ -20,32 +20,30 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
|
||||||
return t.rt(req)
|
return t.rt(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestConf(url string) *Config {
|
func newTestFlow(url string) *Flow {
|
||||||
conf, _ := NewConfig(&Options{
|
f, _ := New(
|
||||||
ClientID: "CLIENT_ID",
|
Client("CLIENT_ID", "CLIENT_SECRET"),
|
||||||
ClientSecret: "CLIENT_SECRET",
|
RedirectURL("REDIRECT_URL"),
|
||||||
RedirectURL: "REDIRECT_URL",
|
Scope("scope1", "scope2"),
|
||||||
Scopes: []string{
|
Endpoint(url+"/auth", url+"/token"),
|
||||||
"scope1",
|
)
|
||||||
"scope2",
|
return f
|
||||||
},
|
|
||||||
}, url+"/auth", url+"/token")
|
|
||||||
return conf
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthCodeURL(t *testing.T) {
|
func TestAuthCodeURL(t *testing.T) {
|
||||||
conf := newTestConf("server")
|
f := newTestFlow("server")
|
||||||
url := conf.AuthCodeURL("foo", "offline", "force")
|
url := f.AuthCodeURL("foo", "offline", "force")
|
||||||
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
|
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
|
||||||
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
|
t.Errorf("Auth code URL doesn't match the expected, found: %v", url)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthCodeURL_Optional(t *testing.T) {
|
func TestAuthCodeURL_Optional(t *testing.T) {
|
||||||
conf, _ := NewConfig(&Options{
|
f, _ := New(
|
||||||
ClientID: "CLIENT_ID",
|
Client("CLIENT_ID", ""),
|
||||||
}, "auth-url", "token-url")
|
Endpoint("auth-url", "token-token"),
|
||||||
url := conf.AuthCodeURL("", "", "")
|
)
|
||||||
|
url := f.AuthCodeURL("", "", "")
|
||||||
if url != "auth-url?client_id=CLIENT_ID&response_type=code" {
|
if url != "auth-url?client_id=CLIENT_ID&response_type=code" {
|
||||||
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
|
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
|
||||||
}
|
}
|
||||||
|
@ -75,11 +73,12 @@ func TestExchangeRequest(t *testing.T) {
|
||||||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newTestConf(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tok, err := conf.Exchange("exchange-code")
|
tr, err := f.NewTransportFromCode("exchange-code")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tok := tr.Token()
|
||||||
if tok.Expired() {
|
if tok.Expired() {
|
||||||
t.Errorf("Token shouldn't be expired.")
|
t.Errorf("Token shouldn't be expired.")
|
||||||
}
|
}
|
||||||
|
@ -119,11 +118,12 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
|
||||||
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newTestConf(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tok, err := conf.Exchange("exchange-code")
|
tr, err := f.NewTransportFromCode("exchange-code")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tok := tr.Token()
|
||||||
if tok.Expired() {
|
if tok.Expired() {
|
||||||
t.Errorf("Token shouldn't be expired.")
|
t.Errorf("Token shouldn't be expired.")
|
||||||
}
|
}
|
||||||
|
@ -145,11 +145,12 @@ func TestExchangeRequest_BadResponse(t *testing.T) {
|
||||||
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newTestConf(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tok, err := conf.Exchange("exchange-code")
|
tr, err := f.NewTransportFromCode("exchange-code")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tok := tr.Token()
|
||||||
if tok.AccessToken != "" {
|
if tok.AccessToken != "" {
|
||||||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
||||||
}
|
}
|
||||||
|
@ -161,21 +162,18 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
|
||||||
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
|
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newTestConf(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tok, err := conf.Exchange("exchange-code")
|
tr, err := f.NewTransportFromCode("exchange-code")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed retrieving token: %s.", err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
tok := tr.Token()
|
||||||
if tok.AccessToken != "" {
|
if tok.AccessToken != "" {
|
||||||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
conf, _ := NewConfig(&Options{
|
|
||||||
ClientID: "CLIENT_ID",
|
|
||||||
}, "https://accounts.google.com/auth",
|
|
||||||
"https://accounts.google.com/token")
|
|
||||||
tr := &mockTransport{
|
tr := &mockTransport{
|
||||||
rt: func(r *http.Request) (w *http.Response, err error) {
|
rt: func(r *http.Request) (w *http.Response, err error) {
|
||||||
headerAuth := r.Header.Get("Authorization")
|
headerAuth := r.Header.Get("Authorization")
|
||||||
|
@ -185,12 +183,20 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
return nil, errors.New("no response")
|
return nil, errors.New("no response")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conf.Client = &http.Client{Transport: tr}
|
c := &http.Client{Transport: tr}
|
||||||
conf.Exchange("code")
|
f, _ := New(
|
||||||
|
Client("CLIENT_ID", ""),
|
||||||
|
Endpoint("https://accounts.google.com/auth", "https://accounts.google.com/token"),
|
||||||
|
HTTPClient(c),
|
||||||
|
)
|
||||||
|
f.NewTransportFromCode("code")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenRefreshRequest(t *testing.T) {
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.String() == "/somethingelse" {
|
||||||
|
return
|
||||||
|
}
|
||||||
if r.URL.String() != "/token" {
|
if r.URL.String() != "/token" {
|
||||||
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
||||||
}
|
}
|
||||||
|
@ -204,29 +210,35 @@ func TestTokenRefreshRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newTestConf(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
conf.FetchToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
tr := f.NewTransportFromToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
||||||
}
|
c := http.Client{Transport: tr}
|
||||||
|
c.Get(ts.URL + "/somethingelse")
|
||||||
func TestNewTransportWithCode(t *testing.T) {
|
|
||||||
exchanged := false
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.RequestURI() == "/token" {
|
|
||||||
exchanged = true
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
conf := newTestConf(ts.URL)
|
|
||||||
conf.NewTransportWithCode("exchange-code")
|
|
||||||
if !exchanged {
|
|
||||||
t.Errorf("NewTransportWithCode should have exchanged the code, but it didn't.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFetchWithNoRefreshToken(t *testing.T) {
|
func TestFetchWithNoRefreshToken(t *testing.T) {
|
||||||
fetcher := newTestConf("")
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := fetcher.FetchToken(&Token{})
|
if r.URL.String() == "/somethingelse" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.URL.String() != "/token" {
|
||||||
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
||||||
|
}
|
||||||
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
|
}
|
||||||
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
||||||
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
f := newTestFlow(ts.URL)
|
||||||
|
tr := f.NewTransportFromToken(&Token{})
|
||||||
|
c := http.Client{Transport: tr}
|
||||||
|
_, err := c.Get(ts.URL + "/somethingelse")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Fetch should return an error if no refresh token is set")
|
t.Errorf("Fetch should return an error if no refresh token is set")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
23
transport.go
23
transport.go
|
@ -31,9 +31,6 @@ type Token struct {
|
||||||
// The remaining lifetime of the access token.
|
// The remaining lifetime of the access token.
|
||||||
Expiry time.Time `json:"expiry,omitempty"`
|
Expiry time.Time `json:"expiry,omitempty"`
|
||||||
|
|
||||||
// Subject is the user to impersonate.
|
|
||||||
Subject string `json:"subject,omitempty"`
|
|
||||||
|
|
||||||
// raw optionally contains extra metadata from the server
|
// raw optionally contains extra metadata from the server
|
||||||
// when updating a token.
|
// when updating a token.
|
||||||
raw interface{}
|
raw interface{}
|
||||||
|
@ -69,8 +66,8 @@ func (t *Token) Expired() bool {
|
||||||
|
|
||||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
fetcher TokenFetcher
|
fetcher func(t *Token) (*Token, error)
|
||||||
origTransport http.RoundTripper
|
base http.RoundTripper
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
token *Token
|
token *Token
|
||||||
|
@ -79,8 +76,8 @@ type Transport struct {
|
||||||
// NewTransport creates a new Transport that uses the provided
|
// NewTransport creates a new Transport that uses the provided
|
||||||
// token fetcher as token retrieving strategy. It authenticates
|
// token fetcher as token retrieving strategy. It authenticates
|
||||||
// the requests and delegates origTransport to make the actual requests.
|
// the requests and delegates origTransport to make the actual requests.
|
||||||
func NewTransport(origTransport http.RoundTripper, fetcher TokenFetcher, token *Token) *Transport {
|
func newTransport(base http.RoundTripper, fn func(t *Token) (*Token, error), token *Token) *Transport {
|
||||||
return &Transport{origTransport: origTransport, fetcher: fetcher, token: token}
|
return &Transport{base: base, fetcher: fn, token: token}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip authorizes and authenticates the request with an
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
|
@ -93,7 +90,6 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
|
||||||
// Check if the token is refreshable.
|
// Check if the token is refreshable.
|
||||||
// If token is refreshable, don't return an error,
|
// If token is refreshable, don't return an error,
|
||||||
// rather refresh.
|
// rather refresh.
|
||||||
|
|
||||||
if err := t.RefreshToken(); err != nil {
|
if err := t.RefreshToken(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -108,11 +104,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
|
||||||
if typ == "" {
|
if typ == "" {
|
||||||
typ = defaultTokenType
|
typ = defaultTokenType
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Authorization", typ+" "+token.AccessToken)
|
req.Header.Set("Authorization", typ+" "+token.AccessToken)
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
// Make the HTTP request.
|
|
||||||
return t.origTransport.RoundTrip(req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token returns the token that authorizes and
|
// Token returns the token that authorizes and
|
||||||
|
@ -127,7 +120,6 @@ func (t *Transport) Token() *Token {
|
||||||
func (t *Transport) SetToken(v *Token) {
|
func (t *Transport) SetToken(v *Token) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
t.token = v
|
t.token = v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,14 +129,11 @@ func (t *Transport) SetToken(v *Token) {
|
||||||
func (t *Transport) RefreshToken() error {
|
func (t *Transport) RefreshToken() error {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
token, err := t.fetcher(t.token)
|
||||||
token, err := t.fetcher.FetchToken(t.token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.token = token
|
t.token = token
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,12 +9,18 @@ import (
|
||||||
|
|
||||||
type mockTokenFetcher struct{ token *Token }
|
type mockTokenFetcher struct{ token *Token }
|
||||||
|
|
||||||
|
func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) {
|
||||||
|
return func(*Token) (*Token, error) {
|
||||||
|
return f.token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) {
|
func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) {
|
||||||
return f.token, nil
|
return f.token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitialTokenRead(t *testing.T) {
|
func TestInitialTokenRead(t *testing.T) {
|
||||||
tr := NewTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
|
tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Authorization") != "Bearer abc" {
|
if r.Header.Get("Authorization") != "Bearer abc" {
|
||||||
t.Errorf("Transport doesn't set the Authorization header from the initial token")
|
t.Errorf("Transport doesn't set the Authorization header from the initial token")
|
||||||
|
@ -31,7 +37,7 @@ func TestTokenFetch(t *testing.T) {
|
||||||
AccessToken: "abc",
|
AccessToken: "abc",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
tr := NewTransport(http.DefaultTransport, fetcher, nil)
|
tr := newTransport(http.DefaultTransport, fetcher.Fn(), nil)
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Authorization") != "Bearer abc" {
|
if r.Header.Get("Authorization") != "Bearer abc" {
|
||||||
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
|
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
|
||||||
|
|
Loading…
Reference in New Issue