diff --git a/example_test.go b/example_test.go index de8e9f8..52a3587 100644 --- a/example_test.go +++ b/example_test.go @@ -1,3 +1,7 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package oauth2_test import ( @@ -13,33 +17,34 @@ import ( // Related to https://codereview.appspot.com/107320046 func TestA(t *testing.T) {} -func Example_config() { - conf, err := oauth2.NewConfig(&oauth2.Options{ - ClientID: "YOUR_CLIENT_ID", - ClientSecret: "YOUR_CLIENT_SECRET", - RedirectURL: "YOUR_REDIRECT_URL", - Scopes: []string{"SCOPE1", "SCOPE2"}, - }, - "https://provider.com/o/oauth2/auth", - "https://provider.com/o/oauth2/token") +func Example_regular() { + f, err := oauth2.New( + oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"), + oauth2.RedirectURL("YOUR_REDIRECT_URL"), + oauth2.Scope("SCOPE1", "SCOPE2"), + oauth2.Endpoint( + "https://provider.com/o/oauth2/auth", + "https://provider.com/o/oauth2/token", + ), + ) if err != nil { log.Fatal(err) } // Redirect user to consent page to ask for permission // for the scopes specified above. - url := conf.AuthCodeURL("state", "online", "auto") + url := f.AuthCodeURL("state", "online", "auto") fmt.Printf("Visit the URL for the auth dialog: %v", url) // Use the authorization code that is pushed to the redirect URL. // NewTransportWithCode will do the handshake to retrieve // an access token and initiate a Transport that is // authorized and authenticated by the retrieved token. - var authorizationCode string - if _, err = fmt.Scan(&authorizationCode); err != nil { + var code string + if _, err = fmt.Scan(&code); err != nil { log.Fatal(err) } - t, err := conf.NewTransportWithCode(authorizationCode) + t, err := f.NewTransportFromCode(code) if err != nil { log.Fatal(err) } @@ -50,9 +55,8 @@ func Example_config() { client.Get("...") } -func Example_jWTConfig() { - conf, err := oauth2.NewJWTConfig(&oauth2.JWTOptions{ - Email: "xxx@developer.gserviceaccount.com", +func Example_jWT() { + f, err := oauth2.New( // The contents of your RSA private key or your PEM file // that contains a private key. // If you have a p12 file instead, you @@ -61,23 +65,23 @@ func Example_jWTConfig() { // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // // It only supports PEM containers with no passphrase. - PrivateKey: []byte("PRIVATE KEY CONTENTS"), - Scopes: []string{"SCOPE1", "SCOPE2"}, - }, - "https://provider.com/o/oauth2/token") + oauth2.JWTClient( + "xxx@developer.gserviceaccount.com", + []byte("-----BEGIN RSA PRIVATE KEY-----...")), + oauth2.Scope("SCOPE1", "SCOPE2"), + oauth2.JWTEndpoint("https://provider.com/o/oauth2/token"), + // If you would like to impersonate a user, you can + // create a transport with a subject. The following GET + // request will be made on the behalf of user@example.com. + // Subject is optional. + oauth2.Subject("user@example.com"), + ) if err != nil { log.Fatal(err) } // Initiate an http.Client, the following GET request will be - // authorized and authenticated on the behalf of - // xxx@developer.gserviceaccount.com. - client := http.Client{Transport: conf.NewTransport()} - client.Get("...") - - // If you would like to impersonate a user, you can - // create a transport with a subject. The following GET - // request will be made on the behalf of user@example.com. - client = http.Client{Transport: conf.NewTransportWithUser("user@example.com")} + // authorized and authenticated on the behalf of user@example.com. + client := http.Client{Transport: f.NewTransport()} client.Get("...") } diff --git a/google/appengine.go b/google/appengine.go index ead2cdc..461fb88 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -7,7 +7,6 @@ package google import ( - "net/http" "strings" "sync" "time" @@ -19,6 +18,82 @@ import ( "appengine/urlfetch" ) +var ( + // memcacheGob enables mocking of the memcache.Gob calls for unit testing. + memcacheGob memcacher = &aeMemcache{} + + // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. + accessTokenFunc = appengine.AccessToken + + // mu protects multiple threads from attempting to fetch a token at the same time. + mu sync.Mutex + + // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. + tokens map[string]*oauth2.Token +) + +// safetyMargin is used to avoid clock-skew problems. +// 5 minutes is conservative because tokens are valid for 60 minutes. +const safetyMargin = 5 * time.Minute + +func init() { + tokens = make(map[string]*oauth2.Token) +} + +// AppEngineContext requires an App Engine request context. +func AppEngineContext(ctx appengine.Context) oauth2.Option { + return func(opts *oauth2.Options) error { + opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts) + opts.Transport = &urlfetch.Transport{Context: ctx} + return nil + } +} + +// FetchToken fetches a new access token for the provided scopes. +// Tokens are cached locally and also with Memcache so that the app can scale +// without hitting quota limits by calling appengine.AccessToken too frequently. +func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) { + return func(existing *oauth2.Token) (*oauth2.Token, error) { + mu.Lock() + defer mu.Unlock() + + key := ":" + strings.Join(opts.Scopes, "_") + now := time.Now().Add(safetyMargin) + if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { + return t, nil + } + delete(tokens, key) + + // Attempt to get token from Memcache + tok := new(oauth2.Token) + _, err := memcacheGob.Get(ctx, key, tok) + if err == nil && !tok.Expiry.Before(now) { + tokens[key] = tok // Save token locally + return tok, nil + } + + token, expiry, err := accessTokenFunc(ctx, opts.Scopes...) + if err != nil { + return nil, err + } + t := &oauth2.Token{ + AccessToken: token, + Expiry: expiry, + } + tokens[key] = t + // Also back up token in Memcache + if err = memcacheGob.Set(ctx, &memcache.Item{ + Key: key, + Value: []byte{}, + Object: *t, + Expiration: expiry.Sub(now), + }); err != nil { + ctx.Errorf("unexpected memcache.Set error: %v", err) + } + return t, nil + } +} + // aeMemcache wraps the needed Memcache functionality to make it easy to mock type aeMemcache struct{} @@ -34,99 +109,3 @@ type memcacher interface { Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) Set(c appengine.Context, item *memcache.Item) error } - -// memcacheGob enables mocking of the memcache.Gob calls for unit testing. -var memcacheGob memcacher = &aeMemcache{} - -// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. -var accessTokenFunc = appengine.AccessToken - -// safetyMargin is used to avoid clock-skew problems. -// 5 minutes is conservative because tokens are valid for 60 minutes. -const safetyMargin = 5 * time.Minute - -// mu protects multiple threads from attempting to fetch a token at the same time. -var mu sync.Mutex - -// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. -var tokens map[string]*oauth2.Token - -func init() { - tokens = make(map[string]*oauth2.Token) -} - -// AppEngineConfig represents a configuration for an -// App Engine application's Google service account. -type AppEngineConfig struct { - // Transport is the http.RoundTripper to be used - // to construct new oauth2.Transport instances from - // this configuration. - Transport http.RoundTripper - - context appengine.Context - scopes []string -} - -// NewAppEngineConfig creates a new AppEngineConfig for the -// provided auth scopes. -func NewAppEngineConfig(context appengine.Context, scopes ...string) *AppEngineConfig { - return &AppEngineConfig{ - context: context, - scopes: scopes, - } -} - -// NewTransport returns a transport that authorizes -// the requests with the application's service account. -func (c *AppEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.transport(), c, nil) -} - -// FetchToken fetches a new access token for the provided scopes. -// Tokens are cached locally and also with Memcache so that the app can scale -// without hitting quota limits by calling appengine.AccessToken too frequently. -func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) { - mu.Lock() - defer mu.Unlock() - key := ":" + strings.Join(c.scopes, "_") - now := time.Now().Add(safetyMargin) - if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { - return t, nil - } - delete(tokens, key) - - // Attempt to get token from Memcache - tok := new(oauth2.Token) - _, err := memcacheGob.Get(c.context, key, tok) - if err == nil && !tok.Expiry.Before(now) { - tokens[key] = tok // Save token locally - return tok, nil - } - - token, expiry, err := accessTokenFunc(c.context, c.scopes...) - if err != nil { - return nil, err - } - t := &oauth2.Token{ - AccessToken: token, - Expiry: expiry, - } - tokens[key] = t - // Also back up token in Memcache - if err = memcacheGob.Set(c.context, &memcache.Item{ - Key: key, - Value: []byte{}, - Object: *t, - Expiration: expiry.Sub(now), - }); err != nil { - c.context.Errorf("unexpected memcache.Set error: %v", err) - } - return t, nil -} - -func (c *AppEngineConfig) transport() http.RoundTripper { - if c.Transport != nil { - return c.Transport - } - return &urlfetch.Transport{Context: c.context} -} diff --git a/google/appengine_test.go b/google/appengine_test.go index 9637493..5cd67f2 100644 --- a/google/appengine_test.go +++ b/google/appengine_test.go @@ -9,6 +9,7 @@ package google import ( "fmt" "log" + "net/http" "sync" "testing" "time" @@ -72,11 +73,16 @@ func TestFetchTokenLocalCacheMiss(t *testing.T) { memcacheGob = m accessTokenCount = 0 delete(tokens, testScopeKey) // clear local cache - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -102,8 +108,16 @@ func TestFetchTokenLocalCacheHit(t *testing.T) { AccessToken: "mytoken", Expiry: time.Now().Add(1 * time.Hour), } - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) + if err != nil { + t.Error(err) + } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get("server") if err != nil { t.Errorf("unable to FetchToken: %v", err) } @@ -139,11 +153,16 @@ func TestFetchTokenMemcacheHit(t *testing.T) { Expiration: 1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -180,11 +199,15 @@ func TestFetchTokenLocalCacheExpired(t *testing.T) { Expiration: 1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -217,11 +240,15 @@ func TestFetchTokenMemcacheExpired(t *testing.T) { Expiration: -1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } diff --git a/google/appenginevm.go b/google/appenginevm.go index ef88850..d71c333 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -7,7 +7,6 @@ package google import ( - "net/http" "strings" "sync" "time" @@ -17,6 +16,81 @@ import ( "google.golang.org/appengine/memcache" ) +var ( + // memcacheGob enables mocking of the memcache.Gob calls for unit testing. + memcacheGob memcacher = &aeMemcache{} + + // accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. + accessTokenFunc = appengine.AccessToken + + // mu protects multiple threads from attempting to fetch a token at the same time. + mu sync.Mutex + + // tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. + tokens map[string]*oauth2.Token +) + +// safetyMargin is used to avoid clock-skew problems. +// 5 minutes is conservative because tokens are valid for 60 minutes. +const safetyMargin = 5 * time.Minute + +func init() { + tokens = make(map[string]*oauth2.Token) +} + +// AppEngineContext requires an App Engine request context. +func AppEngineContext(ctx appengine.Context) oauth2.Option { + return func(opts *oauth2.Options) error { + opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts) + return nil + } +} + +// FetchToken fetches a new access token for the provided scopes. +// Tokens are cached locally and also with Memcache so that the app can scale +// without hitting quota limits by calling appengine.AccessToken too frequently. +func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) { + return func(existing *oauth2.Token) (*oauth2.Token, error) { + mu.Lock() + defer mu.Unlock() + + key := ":" + strings.Join(opts.Scopes, "_") + now := time.Now().Add(safetyMargin) + if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { + return t, nil + } + delete(tokens, key) + + // Attempt to get token from Memcache + tok := new(oauth2.Token) + _, err := memcacheGob.Get(ctx, key, tok) + if err == nil && !tok.Expiry.Before(now) { + tokens[key] = tok // Save token locally + return tok, nil + } + + token, expiry, err := accessTokenFunc(ctx, opts.Scopes...) + if err != nil { + return nil, err + } + t := &oauth2.Token{ + AccessToken: token, + Expiry: expiry, + } + tokens[key] = t + // Also back up token in Memcache + if err = memcacheGob.Set(ctx, &memcache.Item{ + Key: key, + Value: []byte{}, + Object: *t, + Expiration: expiry.Sub(now), + }); err != nil { + ctx.Errorf("unexpected memcache.Set error: %v", err) + } + return t, nil + } +} + // aeMemcache wraps the needed Memcache functionality to make it easy to mock type aeMemcache struct{} @@ -32,99 +106,3 @@ type memcacher interface { Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) Set(c appengine.Context, item *memcache.Item) error } - -// memcacheGob enables mocking of the memcache.Gob calls for unit testing. -var memcacheGob memcacher = &aeMemcache{} - -// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing. -var accessTokenFunc = appengine.AccessToken - -// safetyMargin is used to avoid clock-skew problems. -// 5 minutes is conservative because tokens are valid for 60 minutes. -const safetyMargin = 5 * time.Minute - -// mu protects multiple threads from attempting to fetch a token at the same time. -var mu sync.Mutex - -// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls. -var tokens map[string]*oauth2.Token - -func init() { - tokens = make(map[string]*oauth2.Token) -} - -// AppEngineConfig represents a configuration for an -// App Engine application's Google service account. -type AppEngineConfig struct { - // Transport is the http.RoundTripper to be used - // to construct new oauth2.Transport instances from - // this configuration. - Transport http.RoundTripper - - context appengine.Context - scopes []string -} - -// NewAppEngineConfig creates a new AppEngineConfig for the -// provided auth scopes. -func NewAppEngineConfig(context appengine.Context, scopes ...string) *AppEngineConfig { - return &AppEngineConfig{ - context: context, - scopes: scopes, - } -} - -// NewTransport returns a transport that authorizes -// the requests with the application's service account. -func (c *AppEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.transport(), c, nil) -} - -// FetchToken fetches a new access token for the provided scopes. -// Tokens are cached locally and also with Memcache so that the app can scale -// without hitting quota limits by calling appengine.AccessToken too frequently. -func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, error) { - mu.Lock() - defer mu.Unlock() - key := ":" + strings.Join(c.scopes, "_") - now := time.Now().Add(safetyMargin) - if t, ok := tokens[key]; ok && !t.Expiry.Before(now) { - return t, nil - } - delete(tokens, key) - - // Attempt to get token from Memcache - tok := new(oauth2.Token) - _, err := memcacheGob.Get(c.context, key, tok) - if err == nil && !tok.Expiry.Before(now) { - tokens[key] = tok // Save token locally - return tok, nil - } - - token, expiry, err := accessTokenFunc(c.context, c.scopes...) - if err != nil { - return nil, err - } - t := &oauth2.Token{ - AccessToken: token, - Expiry: expiry, - } - tokens[key] = t - // Also back up token in Memcache - if err = memcacheGob.Set(c.context, &memcache.Item{ - Key: key, - Value: []byte{}, - Object: *t, - Expiration: expiry.Sub(now), - }); err != nil { - c.context.Errorf("unexpected memcache.Set error: %v", err) - } - return t, nil -} - -func (c *AppEngineConfig) transport() http.RoundTripper { - if c.Transport != nil { - return c.Transport - } - return http.DefaultTransport -} diff --git a/google/appenginevm_test.go b/google/appenginevm_test.go index dc3ecf5..4af7651 100644 --- a/google/appenginevm_test.go +++ b/google/appenginevm_test.go @@ -9,6 +9,7 @@ package google import ( "fmt" "log" + "net/http" "sync" "testing" "time" @@ -71,11 +72,16 @@ func TestFetchTokenLocalCacheMiss(t *testing.T) { memcacheGob = m accessTokenCount = 0 delete(tokens, testScopeKey) // clear local cache - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -101,8 +107,16 @@ func TestFetchTokenLocalCacheHit(t *testing.T) { AccessToken: "mytoken", Expiry: time.Now().Add(1 * time.Hour), } - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) + if err != nil { + t.Error(err) + } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get("server") if err != nil { t.Errorf("unable to FetchToken: %v", err) } @@ -138,11 +152,16 @@ func TestFetchTokenMemcacheHit(t *testing.T) { Expiration: 1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -179,11 +198,15 @@ func TestFetchTokenLocalCacheExpired(t *testing.T) { Expiration: 1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } @@ -216,11 +239,15 @@ func TestFetchTokenMemcacheExpired(t *testing.T) { Expiration: -1 * time.Hour, }) m.setCount = 0 - config := NewAppEngineConfig(nil, testScope) - _, err := config.FetchToken(nil) + f, err := oauth2.New( + AppEngineContext(nil), + oauth2.Scope(testScope), + ) if err != nil { - t.Errorf("unable to FetchToken: %v", err) + t.Error(err) } + c := http.Client{Transport: f.NewTransport()} + c.Get("server") if w := 1; m.getCount != w { t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w) } diff --git a/google/example_test.go b/google/example_test.go index 5471d93..54a5fd0 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -24,25 +24,25 @@ func TestA(t *testing.T) {} func Example_webServer() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). - config, err := google.NewConfig(&oauth2.Options{ - ClientID: "YOUR_CLIENT_ID", - ClientSecret: "YOUR_CLIENT_SECRET", - RedirectURL: "YOUR_REDIRECT_URL", - Scopes: []string{ + f, err := oauth2.New( + oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"), + oauth2.RedirectURL("YOUR_REDIRECT_URL"), + oauth2.Scope( "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/blogger"}, - }) + "https://www.googleapis.com/auth/blogger", + ), + google.Endpoint(), + ) if err != nil { log.Fatal(err) } - // Redirect user to Google's consent page to ask for permission // for the scopes specified above. - url := config.AuthCodeURL("state", "online", "auto") + url := f.AuthCodeURL("state", "online", "auto") fmt.Printf("Visit the URL for the auth dialog: %v", url) // Handle the exchange code to initiate a transport - t, err := config.NewTransportWithCode("exchange-code") + t, err := f.NewTransportFromCode("exchange-code") if err != nil { log.Fatal(err) } @@ -58,9 +58,12 @@ func Example_serviceAccountsJSON() { // To create a service account client, click "Create new Client ID", // select "Service Account", and click "Create Client ID". A JSON // key file will then be downloaded to your computer. - config, err := google.NewServiceAccountJSONConfig( - "/path/to/your-project-key.json", - "https://www.googleapis.com/auth/bigquery", + f, err := oauth2.New( + google.ServiceAccountJSONKey("/path/to/your-project-key.json"), + oauth2.Scope( + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/blogger", + ), ) if err != nil { log.Fatal(err) @@ -68,63 +71,74 @@ func Example_serviceAccountsJSON() { // Initiate an http.Client. The following GET request will be // authorized and authenticated on the behalf of // your service account. - client := http.Client{Transport: config.NewTransport()} - client.Get("...") - - // If you would like to impersonate a user, you can - // create a transport with a subject. The following GET - // request will be made on the behalf of user@example.com. - client = http.Client{Transport: config.NewTransportWithUser("user@example.com")} + client := http.Client{Transport: f.NewTransport()} client.Get("...") } func Example_serviceAccounts() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). - config, err := google.NewServiceAccountConfig(&oauth2.JWTOptions{ - Email: "xxx@developer.gserviceaccount.com", + f, err := oauth2.New( // The contents of your RSA private key or your PEM file // that contains a private key. // If you have a p12 file instead, you - // can use `openssl` to export the private key into a PEM file. + // can use `openssl` to export the private key into a pem file. // // $ openssl pkcs12 -in key.p12 -out key.pem -nodes // - // Supports only PEM containers without a passphrase. - PrivateKey: []byte("PRIVATE KEY CONTENTS"), - Scopes: []string{ + // It only supports PEM containers with no passphrase. + oauth2.JWTClient( + "xxx@developer.gserviceaccount.com", + []byte("-----BEGIN RSA PRIVATE KEY-----...")), + oauth2.Scope( "https://www.googleapis.com/auth/bigquery", - }, - }) + "https://www.googleapis.com/auth/blogger", + ), + google.JWTEndpoint(), + // If you would like to impersonate a user, you can + // create a transport with a subject. The following GET + // request will be made on the behalf of user@example.com. + // Subject is optional. + oauth2.Subject("user@example.com"), + ) if err != nil { log.Fatal(err) } // Initiate an http.Client, the following GET request will be - // authorized and authenticated on the behalf of - // xxx@developer.gserviceaccount.com. - client := http.Client{Transport: config.NewTransport()} - client.Get("...") - - // If you would like to impersonate a user, you can - // create a transport with a subject. The following GET - // request will be made on the behalf of user@example.com. - client = http.Client{Transport: config.NewTransportWithUser("user@example.com")} + // authorized and authenticated on the behalf of user@example.com. + client := http.Client{Transport: f.NewTransport()} client.Get("...") } -func Example_appEngine() { - c := appengine.NewContext(nil) - config := google.NewAppEngineConfig(c, "https://www.googleapis.com/auth/bigquery") +func Example_appEngineVMs() { + ctx := appengine.NewContext(nil) + f, err := oauth2.New( + google.AppEngineContext(ctx), + oauth2.Scope( + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/blogger", + ), + ) + if err != nil { + log.Fatal(err) + } // The following client will be authorized by the App Engine // app's service account for the provided scopes. - client := http.Client{Transport: config.NewTransport()} + client := http.Client{Transport: f.NewTransport()} client.Get("...") } func Example_computeEngine() { - // If no other account is specified, "default" is used. - config := google.NewComputeEngineConfig("") - client := http.Client{Transport: config.NewTransport()} + f, err := oauth2.New( + // Query Google Compute Engine's metadata server to retrieve + // an access token for the provided account. + // If no account is specified, "default" is used. + google.ComputeEngineAccount(""), + ) + if err != nil { + log.Fatal(err) + } + client := http.Client{Transport: f.NewTransport()} client.Get("...") } diff --git a/google/google.go b/google/google.go index 2fcf8a6..69c9d22 100644 --- a/google/google.go +++ b/google/google.go @@ -15,18 +15,20 @@ package google import ( "encoding/json" + "fmt" "io/ioutil" "net/http" + "net/url" "time" "github.com/golang/oauth2" + "github.com/golang/oauth2/internal" ) -const ( - // Google endpoints. - uriGoogleAuth = "https://accounts.google.com/o/oauth2/auth" - uriGoogleToken = "https://accounts.google.com/o/oauth2/token" +var ( + uriGoogleAuth, _ = url.Parse("https://accounts.google.com/o/oauth2/auth") + uriGoogleToken, _ = url.Parse("https://accounts.google.com/o/oauth2/token") ) type metaTokenRespBody struct { @@ -35,112 +37,93 @@ type metaTokenRespBody struct { TokenType string `json:"token_type"` } -// ComputeEngineConfig represents a OAuth 2.0 consumer client -// running on Google Compute Engine. -type ComputeEngineConfig struct { - // Client is the HTTP client to be used to retrieve - // tokens from the OAuth 2.0 provider. - Client *http.Client - - // Transport is the round tripper to be used - // to construct new oauth2.Transport instances from - // this configuration. - Transport http.RoundTripper - - account string +// JWTEndpoint adds the endpoints required to complete the 2-legged service account flow. +func JWTEndpoint() oauth2.Option { + return func(opts *oauth2.Options) error { + opts.AUD = uriGoogleToken + return nil + } } -// NewConfig creates a new OAuth2 config that uses Google -// endpoints. -func NewConfig(opts *oauth2.Options) (*oauth2.Config, error) { - return oauth2.NewConfig(opts, uriGoogleAuth, uriGoogleToken) +// Endpoint adds the endpoints required to do the 3-legged Web server flow. +func Endpoint() oauth2.Option { + return func(opts *oauth2.Options) error { + opts.AuthURL = uriGoogleAuth + opts.TokenURL = uriGoogleToken + return nil + } } -// NewServiceAccountConfig creates a new JWT config that can -// fetch Bearer JWT tokens from Google endpoints. -func NewServiceAccountConfig(opts *oauth2.JWTOptions) (*oauth2.JWTConfig, error) { - return oauth2.NewJWTConfig(opts, uriGoogleToken) +// ComputeEngineAccount uses the specified account to retrieve an access +// token from the Google Compute Engine's metadata server. If no user is +// provided, "default" is being used. +func ComputeEngineAccount(account string) oauth2.Option { + return func(opts *oauth2.Options) error { + if account == "" { + account = "default" + } + opts.TokenFetcherFunc = makeComputeFetcher(opts, account) + return nil + } } -// NewServiceAccountJSONConfig creates a new JWT config from a -// JSON key file downloaded from the Google Developers Console. -// See the "Credentials" page under "APIs & Auth" for your project -// at https://console.developers.google.com. -func NewServiceAccountJSONConfig(filename string, scopes ...string) (*oauth2.JWTConfig, error) { - b, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err +// ServiceAccountJSONKey uses the provided Google Developers +// JSON key file to authorize the user. See the "Credentials" page under +// "APIs & Auth" for your project at https://console.developers.google.com +// to download a JSON key file. +func ServiceAccountJSONKey(filename string) oauth2.Option { + return func(opts *oauth2.Options) error { + b, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + var key struct { + Email string `json:"client_email"` + PrivateKey string `json:"private_key"` + } + if err := json.Unmarshal(b, &key); err != nil { + return err + } + pk, err := internal.ParseKey([]byte(key.PrivateKey)) + if err != nil { + return err + } + opts.Email = key.Email + opts.PrivateKey = pk + opts.AUD = uriGoogleToken + return nil } - var key struct { - Email string `json:"client_email"` - PrivateKey string `json:"private_key"` - } - if err := json.Unmarshal(b, &key); err != nil { - return nil, err - } - opts := &oauth2.JWTOptions{ - Email: key.Email, - PrivateKey: []byte(key.PrivateKey), - Scopes: scopes, - } - return NewServiceAccountConfig(opts) } -// NewComputeEngineConfig creates a new config that can fetch tokens -// from Google Compute Engine instance's metaserver. If no account is -// provided, default is used. -func NewComputeEngineConfig(account string) *ComputeEngineConfig { - return &ComputeEngineConfig{account: account} -} - -// NewTransport creates an authorized transport. -func (c *ComputeEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.transport(), c, nil) -} - -// FetchToken retrieves a new access token via metadata server. -func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2.Token, err error) { - account := "default" - if c.account != "" { - account = c.account - } - u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token" - req, err := http.NewRequest("GET", u, nil) - if err != nil { - return - } - req.Header.Add("X-Google-Metadata-Request", "True") - resp, err := c.client().Do(req) - if err != nil { - return - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode > 299 { - return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode) - } - var tokenResp metaTokenRespBody - err = json.NewDecoder(resp.Body).Decode(&tokenResp) - if err != nil { - return - } - token = &oauth2.Token{ - AccessToken: tokenResp.AccessToken, - TokenType: tokenResp.TokenType, - Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second), - } - return -} - -func (c *ComputeEngineConfig) transport() http.RoundTripper { - if c.Transport != nil { - return c.Transport - } - return http.DefaultTransport -} - -func (c *ComputeEngineConfig) client() *http.Client { - if c.Client != nil { - return c.Client - } - return http.DefaultClient +func makeComputeFetcher(opts *oauth2.Options, account string) func(*oauth2.Token) (*oauth2.Token, error) { + return func(t *oauth2.Token) (*oauth2.Token, error) { + u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token" + req, err := http.NewRequest("GET", u, nil) + if err != nil { + return nil, err + } + req.Header.Add("X-Google-Metadata-Request", "True") + c := &http.Client{} + if opts.Client != nil { + c = opts.Client + } + resp, err := c.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode) + } + var tokenResp metaTokenRespBody + err = json.NewDecoder(resp.Body).Decode(&tokenResp) + if err != nil { + return nil, err + } + return &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second), + }, nil + } } diff --git a/internal/oauth2.go b/internal/oauth2.go new file mode 100644 index 0000000..b91b662 --- /dev/null +++ b/internal/oauth2.go @@ -0,0 +1,36 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" +) + +// ParseKey converts the binary contents of a private key file +// to an *rsa.PrivateKey. It detects whether the private key is in a +// PEM container or not. If so, it extracts the the private key +// from PEM container before conversion. It only supports PEM +// containers with no passphrase. +func ParseKey(key []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(key) + if block != nil { + key = block.Bytes + } + parsedKey, err := x509.ParsePKCS8PrivateKey(key) + if err != nil { + parsedKey, err = x509.ParsePKCS1PrivateKey(key) + if err != nil { + return nil, err + } + } + parsed, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("oauth2: private key is invalid") + } + return parsed, nil +} diff --git a/jws/jws.go b/jws/jws.go index f6565db..b81e4f0 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -71,15 +71,15 @@ func (c *ClaimSet) encode() (string, error) { // Marshal private claim set and then append it to b. prv, err := json.Marshal(c.PrivateClaims) if err != nil { - return "", fmt.Errorf("Invalid map of private claims %v", c.PrivateClaims) + return "", fmt.Errorf("jws: invalid map of private claims %v", c.PrivateClaims) } // Concatenate public and private claim JSON objects. if !bytes.HasSuffix(b, []byte{'}'}) { - return "", fmt.Errorf("Invalid JSON %s", b) + return "", fmt.Errorf("jws: invalid JSON %s", b) } if !bytes.HasPrefix(prv, []byte{'{'}) { - return "", fmt.Errorf("Invalid JSON %s", prv) + return "", fmt.Errorf("jws: invalid JSON %s", prv) } b[len(b)-1] = ',' // Replace closing curly brace with a comma. b = append(b, prv[1:]...) // Append private claims. @@ -109,7 +109,7 @@ func Decode(payload string) (c *ClaimSet, err error) { s := strings.Split(payload, ".") if len(s) < 2 { // TODO(jbd): Provide more context about the error. - return nil, errors.New("invalid token received") + return nil, errors.New("jws: invalid token received") } decoded, err := base64Decode(s[1]) if err != nil { diff --git a/jwt.go b/jwt.go index 28161ca..4c49975 100644 --- a/jwt.go +++ b/jwt.go @@ -5,11 +5,7 @@ package oauth2 import ( - "crypto/rsa" - "crypto/x509" "encoding/json" - "encoding/pem" - "errors" "fmt" "io/ioutil" "net/http" @@ -17,6 +13,7 @@ import ( "strings" "time" + "github.com/golang/oauth2/internal" "github.com/golang/oauth2/jws" ) @@ -25,175 +22,100 @@ var ( defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} ) -// JWTOptions represents a OAuth2 client's crendentials to retrieve a -// Bearer JWT token. -type JWTOptions struct { - // Email is the OAuth client identifier used when communicating with - // the configured OAuth provider. - Email string `json:"email"` - - // PrivateKey contains the contents of an RSA private key or the - // contents of a PEM file that contains a private key. The provided - // private key is used to sign JWT payloads. - // PEM containers with a passphrase are not supported. - // Use the following command to convert a PKCS 12 file into a PEM. - // - // $ openssl pkcs12 -in key.p12 -out key.pem -nodes - // - PrivateKey []byte `json:"-"` - - // Scopes identify the level of access being requested. - Scopes []string `json:"scopes"` -} - -// NewJWTConfig creates a new configuration with the specified options -// and OAuth2 provider endpoint. -func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) { - audURL, err := url.Parse(aud) - if err != nil { - return nil, err - } - parsedKey, err := parseKey(opts.PrivateKey) - if err != nil { - return nil, err - } - return &JWTConfig{ - opts: opts, - aud: audURL, - key: parsedKey, - }, nil -} - -// JWTConfig represents an OAuth 2.0 provider and client options to -// provide authorized transports with a Bearer JWT token. -type JWTConfig struct { - // Client is the HTTP client to be used to retrieve - // tokens from the OAuth 2.0 provider. - Client *http.Client - - // Transport is the http.RoundTripper to be used - // to construct new oauth2.Transport instances from - // this configuration. - Transport http.RoundTripper - - opts *JWTOptions - aud *url.URL - key *rsa.PrivateKey -} - -// NewTransport creates a transport that is authorize with the -// parent JWT configuration. -func (c *JWTConfig) NewTransport() *Transport { - return NewTransport(c.transport(), c, &Token{}) -} - -// NewTransportWithUser creates a transport that is authorized by -// the client and impersonates the specified user. -func (c *JWTConfig) NewTransportWithUser(user string) *Transport { - return NewTransport(c.transport(), c, &Token{Subject: user}) -} - -// fetchToken retrieves a new access token and updates the existing token -// with the newly fetched credentials. -func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) { - if existing == nil { - existing = &Token{} - } - - claimSet := &jws.ClaimSet{ - Iss: c.opts.Email, - Scope: strings.Join(c.opts.Scopes, " "), - Aud: c.aud.String(), - } - - if existing.Subject != "" { - claimSet.Sub = existing.Subject - // prn is the old name of sub. Keep setting it - // to be compatible with legacy OAuth 2.0 providers. - claimSet.Prn = existing.Subject - } - - payload, err := jws.Encode(defaultHeader, claimSet, c.key) - if err != nil { - return nil, err - } - v := url.Values{} - v.Set("grant_type", defaultGrantType) - v.Set("assertion", payload) - - // Make a request with assertion to get a new token. - resp, err := c.client().PostForm(c.aud.String(), v) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - if c := resp.StatusCode; c < 200 || c > 299 { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) - } - - b := make(map[string]interface{}) - if err := json.Unmarshal(body, &b); err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - token := &Token{} - token.AccessToken, _ = b["access_token"].(string) - token.TokenType, _ = b["token_type"].(string) - token.Subject = existing.Subject - token.raw = b - if e, ok := b["expires_in"].(int); ok { - token.Expiry = time.Now().Add(time.Duration(e) * time.Second) - } - if idtoken, ok := b["id_token"].(string); ok { - // decode returned id token to get expiry - claimSet := &jws.ClaimSet{} - claimSet, err = jws.Decode(idtoken) +// JWTClient requires OAuth 2.0 JWT credentials. +// Required for the 2-legged JWT flow. +func JWTClient(email string, key []byte) Option { + return func(o *Options) error { + pk, err := internal.ParseKey(key) if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + return err } - token.Expiry = time.Unix(claimSet.Exp, 0) - return token, nil + o.Email = email + o.PrivateKey = pk + return nil } - return token, nil } -func (c *JWTConfig) transport() http.RoundTripper { - if c.Transport != nil { - return c.Transport +// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider. +func JWTEndpoint(aud string) Option { + return func(o *Options) error { + au, err := url.Parse(aud) + if err != nil { + return err + } + o.AUD = au + return nil } - return http.DefaultTransport } -func (c *JWTConfig) client() *http.Client { - if c.Client != nil { - return c.Client +// Subject requires a user to impersonate. +// Optional. +func Subject(user string) Option { + return func(o *Options) error { + o.Subject = user + return nil } - return http.DefaultClient } -// parseKey converts the binary contents of a private key file -// to an *rsa.PrivateKey. It detects whether the private key is in a -// PEM container or not. If so, it extracts the the private key -// from PEM container before conversion. It only supports PEM -// containers with no passphrase. -func parseKey(key []byte) (*rsa.PrivateKey, error) { - block, _ := pem.Decode(key) - if block != nil { - key = block.Bytes - } - parsedKey, err := x509.ParsePKCS8PrivateKey(key) - if err != nil { - parsedKey, err = x509.ParsePKCS1PrivateKey(key) +func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) { + return func(t *Token) (*Token, error) { + if t == nil { + t = &Token{} + } + claimSet := &jws.ClaimSet{ + Iss: o.Email, + Scope: strings.Join(o.Scopes, " "), + Aud: o.AUD.String(), + } + if o.Subject != "" { + claimSet.Sub = o.Subject + // prn is the old name of sub. Keep setting it + // to be compatible with legacy OAuth 2.0 providers. + claimSet.Prn = o.Subject + } + payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey) if err != nil { return nil, err } + v := url.Values{} + v.Set("grant_type", defaultGrantType) + v.Set("assertion", payload) + c := o.Client + if c == nil { + c = &http.Client{} + } + resp, err := c.PostForm(o.AUD.String(), v) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + if c := resp.StatusCode; c < 200 || c > 299 { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) + } + b := make(map[string]interface{}) + if err := json.Unmarshal(body, &b); err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + token := &Token{} + token.AccessToken, _ = b["access_token"].(string) + token.TokenType, _ = b["token_type"].(string) + token.raw = b + if e, ok := b["expires_in"].(int); ok { + token.Expiry = time.Now().Add(time.Duration(e) * time.Second) + } + if idtoken, ok := b["id_token"].(string); ok { + // decode returned id token to get expiry + claimSet, err := jws.Decode(idtoken) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + token.Expiry = time.Unix(claimSet.Exp, 0) + return token, nil + } + return token, nil } - parsed, ok := parsedKey.(*rsa.PrivateKey) - if !ok { - return nil, errors.New("oauth2: private key is invalid") - } - return parsed, nil } diff --git a/jwt_test.go b/jwt_test.go index 7d4711f..b51c702 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -41,17 +41,25 @@ mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ func TestJWTFetch_JSONResponse(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`)) + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer" + }`)) })) defer ts.Close() - conf, err := NewJWTConfig(&JWTOptions{ - Email: "aaa@xxx.com", - PrivateKey: dummyPrivateKey, - }, ts.URL) - tok, err := conf.FetchToken(nil) + f, err := New( + JWTClient("aaa@xxx.com", dummyPrivateKey), + JWTEndpoint(ts.URL), + ) if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Error(err) } + tr := f.NewTransport() + c := http.Client{Transport: tr} + + c.Get(ts.URL) + tok := tr.Token() if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -73,11 +81,17 @@ func TestJWTFetch_BadResponse(t *testing.T) { w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf, err := NewJWTConfig(&JWTOptions{ - Email: "aaa@xxx.com", - PrivateKey: dummyPrivateKey, - }, ts.URL) - tok, err := conf.FetchToken(nil) + f, err := New( + JWTClient("aaa@xxx.com", dummyPrivateKey), + JWTEndpoint(ts.URL), + ) + if err != nil { + t.Error(err) + } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get(ts.URL) + tok := tr.Token() if err != nil { t.Errorf("Failed retrieving token: %s.", err) } @@ -99,11 +113,17 @@ func TestJWTFetch_BadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf, err := NewJWTConfig(&JWTOptions{ - Email: "aaa@xxx.com", - PrivateKey: dummyPrivateKey, - }, ts.URL) - tok, err := conf.FetchToken(nil) + f, err := New( + JWTClient("aaa@xxx.com", dummyPrivateKey), + JWTEndpoint(ts.URL), + ) + if err != nil { + t.Error(err) + } + tr := f.NewTransport() + c := http.Client{Transport: tr} + c.Get(ts.URL) + tok := tr.Token() if err != nil { t.Errorf("Failed retrieving token: %s.", err) } diff --git a/oauth2.go b/oauth2.go index ac3072f..24ddc3a 100644 --- a/oauth2.go +++ b/oauth2.go @@ -8,85 +8,118 @@ package oauth2 import ( + "crypto/rsa" "encoding/json" "errors" "fmt" "io/ioutil" "mime" + "time" + "net/http" "net/url" "strconv" "strings" - "time" ) -// TokenFetcher refreshes or fetches a new access token from the -// provider. It should return an error if it's not capable of -// retrieving a token. -type TokenFetcher interface { - // FetchToken retrieves a new access token for the provider. - // If the implementation doesn't know how to retrieve a new token, - // it returns an error. The existing token may be nil. - FetchToken(existing *Token) (*Token, error) +// Option represents a function that applies some state to +// an Options object. +type Option func(*Options) error + +// Client requires the OAuth 2.0 client credentials. You need to provide +// the client identifier and optionally the client secret that are +// assigned to your application by the OAuth 2.0 provider. +func Client(id, secret string) Option { + return func(opts *Options) error { + opts.ClientID = id + opts.ClientSecret = secret + return nil + } } -// Options represents options to provide OAuth 2.0 client credentials -// and access level. A sample configuration: -type Options struct { - // ClientID is the OAuth client identifier used when communicating with - // the configured OAuth provider. - ClientID string `json:"client_id"` - - // ClientSecret is the OAuth client secret used when communicating with - // the configured OAuth provider. - ClientSecret string `json:"client_secret"` - - // RedirectURL is the URL to which the user will be returned after - // granting (or denying) access. - RedirectURL string `json:"redirect_url"` - - // Scopes optionally specifies a list of requested permission scopes. - Scopes []string `json:"scopes,omitempty"` +// RedirectURL requires the URL to which the user will be returned after +// granting (or denying) access. +func RedirectURL(url string) Option { + return func(opts *Options) error { + opts.RedirectURL = url + return nil + } } -// NewConfig creates a generic OAuth 2.0 configuration that talks -// to an OAuth 2.0 provider specified with authURL and tokenURL. -func NewConfig(opts *Options, authURL, tokenURL string) (*Config, error) { - aURL, err := url.Parse(authURL) - if err != nil { - return nil, err +// Scope requires a list of requested permission scopes. +// It is optinal to specify scopes. +func Scope(scopes ...string) Option { + return func(o *Options) error { + o.Scopes = scopes + return nil } - tURL, err := url.Parse(tokenURL) - if err != nil { - return nil, err - } - if opts.ClientID == "" { - return nil, errors.New("oauth2: missing client ID") - } - return &Config{ - opts: opts, - authURL: aURL, - tokenURL: tURL, - }, nil } -// Config represents the configuration of an OAuth 2.0 consumer client. -type Config struct { - // Client is the HTTP client to be used to retrieve - // tokens from the OAuth 2.0 provider. - Client *http.Client +// Endpoint requires OAuth 2.0 provider's authorization and token endpoints. +func Endpoint(authURL, tokenURL string) Option { + return func(o *Options) error { + au, err := url.Parse(authURL) + if err != nil { + return err + } + tu, err := url.Parse(tokenURL) + if err != nil { + return err + } + o.TokenFetcherFunc = makeThreeLeggedFetcher(o) + o.AuthURL = au + o.TokenURL = tu + return nil + } +} - // Transport is the http.RoundTripper to be used - // to construct new oauth2.Transport instances from - // this configuration. - Transport http.RoundTripper +// HTTPClient allows you to provide a custom http.Client to be +// used to retrieve tokens from the OAuth 2.0 provider. +func HTTPClient(c *http.Client) Option { + return func(o *Options) error { + o.Client = c + return nil + } +} - opts *Options - // AuthURL is the URL the user will be directed to - // in order to grant access. - authURL *url.URL - // TokenURL is the URL used to retrieve OAuth tokens. - tokenURL *url.URL +// RoundTripper allows you to provide a custom http.RoundTripper +// to be used to construct new oauth2.Transport instances. +// If none is provided a default RoundTripper will be used. +func RoundTripper(tr http.RoundTripper) Option { + return func(o *Options) error { + o.Transport = tr + return nil + } +} + +type Flow struct { + opts Options +} + +// New initiates a new flow. It determines the type of the OAuth 2.0 +// (2-legged, 3-legged or custom) by looking at the provided options. +// If the flow type cannot determined automatically, an error is returned. +func New(options ...Option) (*Flow, error) { + f := &Flow{} + for _, opt := range options { + if err := opt(&f.opts); err != nil { + return nil, err + } + } + switch { + case f.opts.TokenFetcherFunc != nil: + return f, nil + case f.opts.AUD != nil: + // TODO(jbd): Assert required JWT params. + f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts) + return f, nil + case f.opts.AuthURL != nil && f.opts.TokenURL != nil: + // TODO(jbd): Assert required OAuth2 params. + f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts) + return f, nil + default: + return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens") + } } // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page @@ -112,13 +145,13 @@ type Config struct { // granted consent and the code can only be exchanged for an // access token. If set to "force" the user will always be prompted, // and the code can be exchanged for a refresh token. -func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string) { - u := *c.authURL +func (f *Flow) AuthCodeURL(state, accessType, prompt string) string { + u := f.opts.AuthURL v := url.Values{ "response_type": {"code"}, - "client_id": {c.opts.ClientID}, - "redirect_uri": condVal(c.opts.RedirectURL), - "scope": condVal(strings.Join(c.opts.Scopes, " ")), + "client_id": {f.opts.ClientID}, + "redirect_uri": condVal(f.opts.RedirectURL), + "scope": condVal(strings.Join(f.opts.Scopes, " ")), "state": condVal(state), "access_type": condVal(accessType), "approval_prompt": condVal(prompt), @@ -132,65 +165,122 @@ func (c *Config) AuthCodeURL(state, accessType, prompt string) (authURL string) return u.String() } -// NewTransport creates a new authorizable transport. It doesn't -// initialize the new transport with a token, so after creation, -// you need to set a valid token (or an expired token with a valid -// refresh token) in order to be able to do authorized requests. -func (c *Config) NewTransport() *Transport { - return NewTransport(c.transport(), c, nil) +// exchange exchanges the authorization code with the OAuth 2.0 provider +// to retrieve a new access token. +func (f *Flow) exchange(code string) (*Token, error) { + return retrieveToken(&f.opts, url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": condVal(f.opts.RedirectURL), + "scope": condVal(strings.Join(f.opts.Scopes, " ")), + }) } -// NewTransportWithCode exchanges the OAuth 2.0 authorization code with -// the provider to fetch a new access token (and refresh token). Once -// it successfully retrieves a new token, creates a new transport -// authorized with it. -func (c *Config) NewTransportWithCode(code string) (*Transport, error) { - token, err := c.Exchange(code) +// NewTransportFromCode exchanges the code to retrieve a new access token +// and returns an authorized and authenticated Transport. +func (f *Flow) NewTransportFromCode(code string) (*Transport, error) { + token, err := f.exchange(code) if err != nil { return nil, err } - return NewTransport(c.transport(), c, token), nil + return f.NewTransportFromToken(token), nil } -// FetchToken retrieves a new access token and updates the existing token -// with the newly fetched credentials. If existing token doesn't -// contain a refresh token, it returns an error. -func (c *Config) FetchToken(existing *Token) (*Token, error) { - if existing == nil || existing.RefreshToken == "" { - return nil, errors.New("oauth2: cannot fetch access token without refresh token") +// NewTransportFromToken returns a new Transport that is authorized +// and authenticated with the provided token. +func (f *Flow) NewTransportFromToken(t *Token) *Transport { + tr := f.opts.Transport + if tr == nil { + tr = http.DefaultTransport } - return c.retrieveToken(url.Values{ - "grant_type": {"refresh_token"}, - "refresh_token": {existing.RefreshToken}, - }) + return newTransport(tr, f.opts.TokenFetcherFunc, t) } -// Exchange exchanges the authorization code with the OAuth 2.0 provider -// to retrieve a new access token. -func (c *Config) Exchange(code string) (*Token, error) { - return c.retrieveToken(url.Values{ - "grant_type": {"authorization_code"}, - "code": {code}, - "redirect_uri": condVal(c.opts.RedirectURL), - "scope": condVal(strings.Join(c.opts.Scopes, " ")), - }) +// NewTransport returns a Transport. +func (f *Flow) NewTransport() *Transport { + return f.NewTransportFromToken(nil) } -func (c *Config) retrieveToken(v url.Values) (*Token, error) { - v.Set("client_id", c.opts.ClientID) - bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String()) - if bustedAuth && c.opts.ClientSecret != "" { - v.Set("client_secret", c.opts.ClientSecret) +func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) { + return func(t *Token) (*Token, error) { + if t == nil || t.RefreshToken == "" { + return nil, errors.New("oauth2: cannot fetch access token without refresh token") + } + return retrieveToken(o, url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {t.RefreshToken}, + }) } - req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode())) +} + +// Options represents an object to keep the state of the OAuth 2.0 flow. +type Options struct { + // ClientID is the OAuth client identifier used when communicating with + // the configured OAuth provider. + ClientID string + + // ClientSecret is the OAuth client secret used when communicating with + // the configured OAuth provider. + ClientSecret string + + // RedirectURL is the URL to which the user will be returned after + // granting (or denying) access. + RedirectURL string + + // Email is the OAuth client identifier used when communicating with + // the configured OAuth provider. + Email string + + // PrivateKey contains the contents of an RSA private key or the + // contents of a PEM file that contains a private key. The provided + // private key is used to sign JWT payloads. + // PEM containers with a passphrase are not supported. + // Use the following command to convert a PKCS 12 file into a PEM. + // + // $ openssl pkcs12 -in key.p12 -out key.pem -nodes + // + PrivateKey *rsa.PrivateKey + + // Scopes identify the level of access being requested. + Subject string + + // Scopes optionally specifies a list of requested permission scopes. + Scopes []string + + // AuthURL represents the authorization endpoint of the OAuth 2.0 provider. + AuthURL *url.URL + + // TokenURL represents the token endpoint of the OAuth 2.0 provider. + TokenURL *url.URL + + // AUD represents the token endpoint required to complete the 2-legged JWT flow. + AUD *url.URL + + TokenFetcherFunc func(t *Token) (*Token, error) + + Transport http.RoundTripper + Client *http.Client +} + +func retrieveToken(o *Options, v url.Values) (*Token, error) { + v.Set("client_id", o.ClientID) + bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String()) + if bustedAuth && o.ClientSecret != "" { + v.Set("client_secret", o.ClientSecret) + } + req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode())) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if !bustedAuth && c.opts.ClientSecret != "" { - req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret) + if !bustedAuth && o.ClientSecret != "" { + req.SetBasicAuth(o.ClientID, o.ClientSecret) } - r, err := c.client().Do(req) + c := o.Client + if c == nil { + c = &http.Client{} + } + r, err := c.Do(req) if err != nil { return nil, err } @@ -199,7 +289,7 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } - if c := r.StatusCode; c < 200 || c > 299 { + if code := r.StatusCode; code < 200 || code > 299 { return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) } @@ -255,20 +345,6 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { return token, nil } -func (c *Config) transport() http.RoundTripper { - if c.Transport != nil { - return c.Transport - } - return http.DefaultTransport -} - -func (c *Config) client() *http.Client { - if c.Client != nil { - return c.Client - } - return http.DefaultClient -} - func condVal(v string) []string { if v == "" { return nil diff --git a/oauth2_test.go b/oauth2_test.go index d633f86..92419c6 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -20,32 +20,30 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e return t.rt(req) } -func newTestConf(url string) *Config { - conf, _ := NewConfig(&Options{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: []string{ - "scope1", - "scope2", - }, - }, url+"/auth", url+"/token") - return conf +func newTestFlow(url string) *Flow { + f, _ := New( + Client("CLIENT_ID", "CLIENT_SECRET"), + RedirectURL("REDIRECT_URL"), + Scope("scope1", "scope2"), + Endpoint(url+"/auth", url+"/token"), + ) + return f } func TestAuthCodeURL(t *testing.T) { - conf := newTestConf("server") - url := conf.AuthCodeURL("foo", "offline", "force") + f := newTestFlow("server") + url := f.AuthCodeURL("foo", "offline", "force") if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { - t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) + t.Errorf("Auth code URL doesn't match the expected, found: %v", url) } } func TestAuthCodeURL_Optional(t *testing.T) { - conf, _ := NewConfig(&Options{ - ClientID: "CLIENT_ID", - }, "auth-url", "token-url") - url := conf.AuthCodeURL("", "", "") + f, _ := New( + Client("CLIENT_ID", ""), + Endpoint("auth-url", "token-token"), + ) + url := f.AuthCodeURL("", "", "") if url != "auth-url?client_id=CLIENT_ID&response_type=code" { t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) } @@ -75,11 +73,12 @@ func TestExchangeRequest(t *testing.T) { w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) })) defer ts.Close() - conf := newTestConf(ts.URL) - tok, err := conf.Exchange("exchange-code") + f := newTestFlow(ts.URL) + tr, err := f.NewTransportFromCode("exchange-code") if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Error(err) } + tok := tr.Token() if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -119,11 +118,12 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := newTestConf(ts.URL) - tok, err := conf.Exchange("exchange-code") + f := newTestFlow(ts.URL) + tr, err := f.NewTransportFromCode("exchange-code") if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Error(err) } + tok := tr.Token() if tok.Expired() { t.Errorf("Token shouldn't be expired.") } @@ -145,11 +145,12 @@ func TestExchangeRequest_BadResponse(t *testing.T) { w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := newTestConf(ts.URL) - tok, err := conf.Exchange("exchange-code") + f := newTestFlow(ts.URL) + tr, err := f.NewTransportFromCode("exchange-code") if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Error(err) } + tok := tr.Token() if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } @@ -161,21 +162,18 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := newTestConf(ts.URL) - tok, err := conf.Exchange("exchange-code") + f := newTestFlow(ts.URL) + tr, err := f.NewTransportFromCode("exchange-code") if err != nil { - t.Errorf("Failed retrieving token: %s.", err) + t.Error(err) } + tok := tr.Token() if tok.AccessToken != "" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } } func TestExchangeRequest_NonBasicAuth(t *testing.T) { - conf, _ := NewConfig(&Options{ - ClientID: "CLIENT_ID", - }, "https://accounts.google.com/auth", - "https://accounts.google.com/token") tr := &mockTransport{ rt: func(r *http.Request) (w *http.Response, err error) { headerAuth := r.Header.Get("Authorization") @@ -185,12 +183,20 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) { return nil, errors.New("no response") }, } - conf.Client = &http.Client{Transport: tr} - conf.Exchange("code") + c := &http.Client{Transport: tr} + f, _ := New( + Client("CLIENT_ID", ""), + Endpoint("https://accounts.google.com/auth", "https://accounts.google.com/token"), + HTTPClient(c), + ) + f.NewTransportFromCode("code") } func TestTokenRefreshRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() == "/somethingelse" { + return + } if r.URL.String() != "/token" { t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) } @@ -204,29 +210,35 @@ func TestTokenRefreshRequest(t *testing.T) { } })) defer ts.Close() - conf := newTestConf(ts.URL) - conf.FetchToken(&Token{RefreshToken: "REFRESH_TOKEN"}) -} - -func TestNewTransportWithCode(t *testing.T) { - exchanged := false - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.RequestURI() == "/token" { - exchanged = true - } - })) - defer ts.Close() - conf := newTestConf(ts.URL) - conf.NewTransportWithCode("exchange-code") - if !exchanged { - t.Errorf("NewTransportWithCode should have exchanged the code, but it didn't.") - } + f := newTestFlow(ts.URL) + tr := f.NewTransportFromToken(&Token{RefreshToken: "REFRESH_TOKEN"}) + c := http.Client{Transport: tr} + c.Get(ts.URL + "/somethingelse") } func TestFetchWithNoRefreshToken(t *testing.T) { - fetcher := newTestConf("") - _, err := fetcher.FetchToken(&Token{}) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() == "/somethingelse" { + return + } + if r.URL.String() != "/token" { + t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, _ := ioutil.ReadAll(r.Body) + if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { + t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + } + })) + defer ts.Close() + f := newTestFlow(ts.URL) + tr := f.NewTransportFromToken(&Token{}) + c := http.Client{Transport: tr} + _, err := c.Get(ts.URL + "/somethingelse") if err == nil { - t.Fatalf("Fetch should return an error if no refresh token is set") + t.Errorf("Fetch should return an error if no refresh token is set") } } diff --git a/transport.go b/transport.go index 4483c9c..e1a35b0 100644 --- a/transport.go +++ b/transport.go @@ -31,9 +31,6 @@ type Token struct { // The remaining lifetime of the access token. Expiry time.Time `json:"expiry,omitempty"` - // Subject is the user to impersonate. - Subject string `json:"subject,omitempty"` - // raw optionally contains extra metadata from the server // when updating a token. raw interface{} @@ -69,8 +66,8 @@ func (t *Token) Expired() bool { // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests. type Transport struct { - fetcher TokenFetcher - origTransport http.RoundTripper + fetcher func(t *Token) (*Token, error) + base http.RoundTripper mu sync.RWMutex token *Token @@ -79,8 +76,8 @@ type Transport struct { // NewTransport creates a new Transport that uses the provided // token fetcher as token retrieving strategy. It authenticates // the requests and delegates origTransport to make the actual requests. -func NewTransport(origTransport http.RoundTripper, fetcher TokenFetcher, token *Token) *Transport { - return &Transport{origTransport: origTransport, fetcher: fetcher, token: token} +func newTransport(base http.RoundTripper, fn func(t *Token) (*Token, error), token *Token) *Transport { + return &Transport{base: base, fetcher: fn, token: token} } // RoundTrip authorizes and authenticates the request with an @@ -93,7 +90,6 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error // Check if the token is refreshable. // If token is refreshable, don't return an error, // rather refresh. - if err := t.RefreshToken(); err != nil { return nil, err } @@ -108,11 +104,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if typ == "" { typ = defaultTokenType } - req.Header.Set("Authorization", typ+" "+token.AccessToken) - - // Make the HTTP request. - return t.origTransport.RoundTrip(req) + return t.base.RoundTrip(req) } // Token returns the token that authorizes and @@ -127,7 +120,6 @@ func (t *Transport) Token() *Token { func (t *Transport) SetToken(v *Token) { t.mu.Lock() defer t.mu.Unlock() - t.token = v } @@ -137,14 +129,11 @@ func (t *Transport) SetToken(v *Token) { func (t *Transport) RefreshToken() error { t.mu.Lock() defer t.mu.Unlock() - - token, err := t.fetcher.FetchToken(t.token) + token, err := t.fetcher(t.token) if err != nil { return err } - t.token = token - return nil } diff --git a/transport_test.go b/transport_test.go index 7128d45..f7cdbc4 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,12 +9,18 @@ import ( type mockTokenFetcher struct{ token *Token } +func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) { + return func(*Token) (*Token, error) { + return f.token, nil + } +} + func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) { return f.token, nil } func TestInitialTokenRead(t *testing.T) { - tr := NewTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) + tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the initial token") @@ -31,7 +37,7 @@ func TestTokenFetch(t *testing.T) { AccessToken: "abc", }, } - tr := NewTransport(http.DefaultTransport, fetcher, nil) + tr := newTransport(http.DefaultTransport, fetcher.Fn(), nil) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the fetched token")