oauth2: redesign the API

Tests and examples aren't updated yet. The tree will be broken after this,
but nobody should be using this yet anyway.

Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f
Reviewed-on: https://go-review.googlesource.com/1246
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick 2014-12-10 10:17:33 +11:00
parent b3f9a68f05
commit a568078818
4 changed files with 556 additions and 432 deletions

39
client_appengine.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine,!appenginevm
// App Engine hooks.
package oauth2
import (
"log"
"net/http"
"sync"
"appengine"
"appengine/urlfetch"
)
var warnOnce sync.Once
func init() {
registerContextClientFunc(contextClientAppEngine)
}
func contextClientAppEngine(ctx Context) (*http.Client, error) {
if actx, ok := ctx.(appengine.Context); ok {
return urlfetch.NewClient(actx), nil
}
// The user did it wrong. We'll log once (and hope they see it
// in dev_appserver), but stil return (nil, nil) in case some
// other contextClientFunc hook finds a way to proceed.
warnOnce.Do(gaeDoingItWrongHelp)
return nil, nil
}
func gaeDoingItWrongHelp() {
log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.")
}

195
jwt.go
View File

@ -5,15 +5,16 @@
package oauth2 package oauth2
import ( import (
"crypto/rsa"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws" "golang.org/x/oauth2/jws"
) )
@ -22,100 +23,126 @@ var (
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
) )
// JWTClient requires OAuth 2.0 JWT credentials. // JWTConfig is the configuration for using JWT to fetch tokens,
// Required for the 2-legged JWT flow. // commonly known as "two-legged OAuth".
func JWTClient(email string, key []byte) Option { type JWTConfig struct {
return func(o *Options) error { // Email is the OAuth client identifier used when communicating with
pk, err := internal.ParseKey(key) // the configured OAuth provider.
if err != nil { Email string
return err
} // PrivateKey contains the contents of an RSA private key or the
o.Email = email // contents of a PEM file that contains a private key. The provided
o.PrivateKey = pk // private key is used to sign JWT payloads.
return nil // PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Subject is the optional user to impersonate.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
TokenURL string
}
// TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context.
//
// The returned TokenSource only does JWT requests when necessary but
// otherwise returns the same token repeatedly until it expires.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
return &newWhenNeededSource{
t: initialToken,
new: jwtSource{ctx, c},
} }
} }
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider. // Client returns an HTTP client wrapping the context's
func JWTEndpoint(aud string) Option { // HTTP transport and adding Authorization headers with tokens
return func(o *Options) error { // obtained from c.
au, err := url.Parse(aud) //
if err != nil { // The provided initialToken may be nil, in which case the first
return err // call to TokenSource will do a new JWT request.
} //
o.AUD = au // The returned client and its Transport should not be modified.
return nil func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
return &http.Client{
Transport: &Transport{
Source: c.TokenSource(ctx, initialToken),
Base: contextTransport(ctx),
},
} }
} }
// Subject requires a user to impersonate. // jwtSource is a source that always does a signed JWT request for a token.
// Optional. // It should typically be wrapped with a newWhenNeededSource.
func Subject(user string) Option { type jwtSource struct {
return func(o *Options) error { ctx Context
o.Subject = user conf *JWTConfig
return nil
}
} }
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) { func (js jwtSource) Token() (*Token, error) {
return func(t *Token) (*Token, error) { hc, err := contextClient(js.ctx)
if t == nil { if err != nil {
t = &Token{} return nil, err
} }
claimSet := &jws.ClaimSet{ claimSet := &jws.ClaimSet{
Iss: o.Email, Iss: js.conf.Email,
Scope: strings.Join(o.Scopes, " "), Scope: strings.Join(js.conf.Scopes, " "),
Aud: o.AUD.String(), Aud: js.conf.TokenURL,
} }
if o.Subject != "" { if subject := js.conf.Subject; subject != "" {
claimSet.Sub = o.Subject claimSet.Sub = subject
// prn is the old name of sub. Keep setting it // prn is the old name of sub. Keep setting it
// to be compatible with legacy OAuth 2.0 providers. // to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = o.Subject claimSet.Prn = subject
} }
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey) payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
v := url.Values{} v := url.Values{}
v.Set("grant_type", defaultGrantType) v.Set("grant_type", defaultGrantType)
v.Set("assertion", payload) v.Set("assertion", payload)
c := o.Client resp, err := hc.PostForm(js.conf.TokenURL, v)
if c == nil { if err != nil {
c = &http.Client{} return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
resp, err := c.PostForm(o.AUD.String(), v) defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
}
b := make(map[string]interface{})
if err := json.Unmarshal(body, &b); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
token := &Token{}
token.AccessToken, _ = b["access_token"].(string)
token.TokenType, _ = b["token_type"].(string)
token.raw = b
if e, ok := b["expires_in"].(int); ok {
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
}
if idtoken, ok := b["id_token"].(string); ok {
// decode returned id token to get expiry
claimSet, err := jws.Decode(idtoken)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
defer resp.Body.Close() token.Expiry = time.Unix(claimSet.Exp, 0)
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
}
b := make(map[string]interface{})
if err := json.Unmarshal(body, &b); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
token := &Token{}
token.AccessToken, _ = b["access_token"].(string)
token.TokenType, _ = b["token_type"].(string)
token.raw = b
if e, ok := b["expires_in"].(int); ok {
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
}
if idtoken, ok := b["id_token"].(string); ok {
// decode returned id token to get expiry
claimSet, err := jws.Decode(idtoken)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
token.Expiry = time.Unix(claimSet.Exp, 0)
return token, nil
}
return token, nil return token, nil
} }
return token, nil
} }

552
oauth2.go
View File

@ -8,115 +8,168 @@
package oauth2 // import "golang.org/x/oauth2" package oauth2 // import "golang.org/x/oauth2"
import ( import (
"crypto/rsa" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"mime" "mime"
"time"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time"
"golang.org/x/net/context"
) )
// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer. // Context can be an golang.org/x/net.Context, or an App Engine Context.
type TokenStore interface { // In the future these will be unified.
// ReadToken reads the token from the store. // If you don't care and aren't running on App Engine, you may use nil.
// If the read is successful, it should return the token and a nil error. type Context interface{}
// The returned tokens may be expired tokens.
// If there is no token in the store, it should return a nil token and a nil error. // Config describes a typical 3-legged OAuth2 flow, with both the
// It should return a non-nil error when an unrecoverable failure occurs. // client application information and the server's URLs.
ReadToken() (*Token, error) type Config struct {
// WriteToken writes the token to the cache. // ClientID is the application's ID.
WriteToken(*Token) ClientID string
// ClientSecret is the application's secret.
ClientSecret string
// Endpoint contains the resource server's token endpoint
// URLs. These are supplied by the server and are often
// available via site-specific packages (for example,
// google.Endpoint or github.Endpoint)
Endpoint Endpoint
// RedirectURL is the URL to redirect users going through
// the OAuth flow, after the resource owner's URLs.
RedirectURL string
// Scope specifies optional requested permissions.
Scopes []string
} }
// Option represents a function that applies some state to // A TokenSource is anything that can return a token.
// an Options object. type TokenSource interface {
type Option func(*Options) error // Token returns a token or an error.
Token() (*Token, error)
}
// Client requires the OAuth 2.0 client credentials. You need to provide // Endpoint contains the OAuth 2.0 provider's authorization and token
// the client identifier and optionally the client secret that are // endpoint URLs.
// assigned to your application by the OAuth 2.0 provider. type Endpoint struct {
func Client(id, secret string) Option { AuthURL string
return func(opts *Options) error { TokenURL string
opts.ClientID = id }
opts.ClientSecret = secret
return nil // Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
// Most users of this package should not access fields of Token
// directly. They're exported mostly for use by related packages
// implementing derivate OAuth2 flows.
type Token struct {
// AccessToken is the token that authorizes and authenticates
// the requests.
AccessToken string `json:"access_token"`
// TokenType is the type of token.
// The Type method returns either this or "Bearer", the default.
TokenType string `json:"token_type,omitempty"`
// RefreshToken is a token that's used by the application
// (as opposed to the user) to refresh the access token
// if it expires.
RefreshToken string `json:"refresh_token,omitempty"`
// Expiry is the optional expiration time of the access token.
//
// If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Type returns t.TokenType if non-empty, else "Bearer".
func (t *Token) Type() string {
if t.TokenType != "" {
return t.TokenType
} }
return "Bearer"
} }
// RedirectURL requires the URL to which the user will be returned after // SetAuthHeader sets the Authorization header to r using the access
// granting (or denying) access. // token in t.
func RedirectURL(url string) Option { //
return func(opts *Options) error { // This method is unnecessary when using Transport or an HTTP Client
opts.RedirectURL = url // returned by this package.
return nil func (t *Token) SetAuthHeader(r *http.Request) {
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
}
// Extra returns an extra field returned from the server during token
// retrieval.
func (t *Token) Extra(key string) string {
if vals, ok := t.raw.(url.Values); ok {
return vals.Get(key)
} }
} if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
// Scope requires a list of requested permission scopes. return val
// It is optinal to specify scopes.
func Scope(scopes ...string) Option {
return func(o *Options) error {
o.Scopes = scopes
return nil
}
}
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
func Endpoint(authURL, tokenURL string) Option {
return func(o *Options) error {
au, err := url.Parse(authURL)
if err != nil {
return err
}
tu, err := url.Parse(tokenURL)
if err != nil {
return err
}
o.AuthURL = au
o.TokenURL = tu
return nil
}
}
// HTTPClient allows you to provide a custom http.Client to be
// used to retrieve tokens from the OAuth 2.0 provider.
func HTTPClient(c *http.Client) Option {
return func(o *Options) error {
o.Client = c
return nil
}
}
// New builds a new options object and determines the type of the OAuth 2.0
// (2-legged, 3-legged or custom) by looking at the provided options.
// If the flow type cannot determined automatically, an error is returned.
func New(option ...Option) (*Options, error) {
opts := &Options{}
for _, fn := range option {
if err := fn(opts); err != nil {
return nil, err
} }
} }
switch { return ""
case opts.TokenFetcherFunc != nil: }
return opts, nil
case opts.AUD != nil: // Expired returns true if there is no access token or the
// TODO(jbd): Assert the required JWT params. // access token is expired.
opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts) func (t *Token) Expired() bool {
return opts, nil if t.AccessToken == "" {
case opts.AuthURL != nil && opts.TokenURL != nil: return true
// TODO(jbd): Assert the required OAuth2 params.
opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts)
return opts, nil
default:
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
} }
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
var (
// AccessTypeOnline and AccessTypeOffline are options passed
// to the Options.AuthCodeURL method. They modify the
// "access_type" field that gets sent in the URL returned by
// AuthCodeURL.
//
// Online (the default if neither is specified) is the default.
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
AccessTypeOnline AuthCodeOption = setParam{"access_type", "online"}
AccessTypeOffline AuthCodeOption = setParam{"access_type", "offline"}
// ApprovalForce forces the users to view the consent dialog
// and confirm the permissions request at the URL returned
// from AuthCodeURL, even if they've already done so.
ApprovalForce AuthCodeOption = setParam{"approval_prompt", "force"}
)
type setParam struct{ k, v string }
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
// An AuthCodeOption is passed to Config.AuthCodeURL.
type AuthCodeOption interface {
setValue(url.Values)
} }
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
@ -127,185 +180,195 @@ func New(option ...Option) (*Options, error) {
// the state query parameter on your redirect callback. // the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// //
// Access type is an OAuth extension that gets sent as the // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// "access_type" field in the URL from AuthCodeURL. // as ApprovalForce.
// It may be "online" (default) or "offline". func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
// If your application needs to refresh access tokens when the var buf bytes.Buffer
// user is not present at the browser, then use offline. This buf.WriteString(c.Endpoint.AuthURL)
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
//
// Approval prompt indicates whether the user should be
// re-prompted for consent. If set to "auto" (default) the
// user will be prompted only if they haven't previously
// granted consent and the code can only be exchanged for an
// access token. If set to "force" the user will always be prompted,
// and the code can be exchanged for a refresh token.
func (o *Options) AuthCodeURL(state, accessType, prompt string) string {
u := *o.AuthURL
v := url.Values{ v := url.Values{
"response_type": {"code"}, "response_type": {"code"},
"client_id": {o.ClientID}, "client_id": {c.ClientID},
"redirect_uri": condVal(o.RedirectURL), "redirect_uri": condVal(c.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")), "scope": condVal(strings.Join(c.Scopes, " ")),
"state": condVal(state), "state": condVal(state),
"access_type": condVal(accessType),
"approval_prompt": condVal(prompt),
} }
q := v.Encode() for _, opt := range opts {
if u.RawQuery == "" { opt.setValue(v)
u.RawQuery = q }
if strings.Contains(c.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else { } else {
u.RawQuery += "&" + q buf.WriteByte('?')
} }
return u.String() buf.WriteString(v.Encode())
return buf.String()
} }
// exchange exchanges the authorization code with the OAuth 2.0 provider // Exchange converts an authorization code into a token.
// to retrieve a new access token. //
func (o *Options) exchange(code string) (*Token, error) { // It is used after a resource provider redirects the user back
return retrieveToken(o, url.Values{ // to the Redirect URI (the URL obtained from AuthCodeURL).
//
// The HTTP client to use is derived from the context. If nil,
// http.DefaultClient is used. See the Context type's documentation.
//
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state").
func (c *Config) Exchange(ctx Context, code string) (*Token, error) {
return retrieveToken(ctx, c, url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"code": {code}, "code": {code},
"redirect_uri": condVal(o.RedirectURL), "redirect_uri": condVal(c.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")), "scope": condVal(strings.Join(c.Scopes, " ")),
}) })
} }
// NewTransportFromTokenStore reads the token from the store and returns // contextClientFunc is a func which tries to return an *http.Client
// a Transport that is authorized and the authenticated // given a Context value. If it returns an error, the search stops
// by the returned token. // with that error. If it returns (nil, nil), the search continues
func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) { // down the list of registered funcs.
tok, err := store.ReadToken() type contextClientFunc func(Context) (*http.Client, error)
if err != nil {
return nil, err var contextClientFuncs []contextClientFunc
}
o.TokenStore = store func registerContextClietnFunc(fn contextClientFunc) {
if tok == nil { contextClientFuncs = append(contextClientFuncs, fn)
return nil, nil
}
return o.newTransportFromToken(tok), nil
} }
// NewTransportFromCode exchanges the code to retrieve a new access token func contextClient(ctx Context) (*http.Client, error) {
// and returns an authorized and authenticated Transport. for _, fn := range contextClientFuncs {
func (o *Options) NewTransportFromCode(code string) (*Transport, error) { c, err := fn(ctx)
token, err := o.exchange(code) if err != nil {
if err != nil { return nil, err
return nil, err
}
return o.newTransportFromToken(token), nil
}
// NewTransport returns a Transport.
func (o *Options) NewTransport() *Transport {
return o.newTransportFromToken(nil)
}
// newTransportFromToken returns a new Transport that is authorized
// and authenticated with the provided token.
func (o *Options) newTransportFromToken(t *Token) *Transport {
// TODO(jbd): App Engine options initiate an http.Client that
// depends on the urlfetcher, but it breaks the promise we made
// that the options object should be working finely with nil-values
// for the http.Client.
tr := http.DefaultTransport
if o.Client != nil && o.Client.Transport != nil {
tr = o.Client.Transport
}
return newTransport(tr, o, t)
}
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
} }
return retrieveToken(o, url.Values{ if c != nil {
"grant_type": {"refresh_token"}, return c, nil
"refresh_token": {t.RefreshToken}, }
}) }
if xc, ok := ctx.(context.Context); ok {
if hc, ok := xc.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
}
return http.DefaultClient, nil
}
func contextTransport(ctx Context) http.RoundTripper {
hc, err := contextClient(ctx)
if err != nil {
// This is a rare error case (somebody using nil on App Engine),
// so I'd rather not everybody do an error check on this Client
// method. They can get the error that they're doing it wrong
// later, at client.Get/PostForm time.
return errorTransport{err}
}
return hc.Transport
}
// Client returns an HTTP client using the provided token.
// The token will auto-refresh as necessary. The underlying
// HTTP transport will be obtained using the provided context.
// The returned client and its Transport should not be modified.
func (c *Config) Client(ctx Context, t *Token) *http.Client {
return &http.Client{
Transport: &Transport{
Source: c.TokenSource(ctx, t),
Base: contextTransport(ctx),
},
} }
} }
// Options represents an object to keep the state of the OAuth 2.0 flow. // TokenSource returns a TokenSource that returns t until t expires,
type Options struct { // automatically refreshing it as necessary using the provided context.
// ClientID is the OAuth client identifier used when communicating with // See the the Context documentation.
// the configured OAuth provider. //
ClientID string // Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
// ClientSecret is the OAuth client secret used when communicating with nwn := &newWhenNeededSource{t: t}
// the configured OAuth provider. nwn.new = tokenRefresher{
ClientSecret string ctx: ctx,
conf: c,
// RedirectURL is the URL to which the user will be returned after oldToken: &nwn.t,
// granting (or denying) access. }
RedirectURL string return nwn
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string
// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Scopes identify the level of access being requested.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
AuthURL *url.URL
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
TokenURL *url.URL
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
AUD *url.URL
// TokenStore reads a token from the store and writes it back to the store
// if a token refresh occurs.
// Optional.
TokenStore TokenStore
TokenFetcherFunc func(t *Token) (*Token, error)
Client *http.Client
} }
func retrieveToken(o *Options, v url.Values) (*Token, error) { // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
v.Set("client_id", o.ClientID) // HTTP requests to renew a token using a RefreshToken.
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String()) type tokenRefresher struct {
if bustedAuth && o.ClientSecret != "" { ctx Context // used to get HTTP requests
v.Set("client_secret", o.ClientSecret) conf *Config
oldToken **Token // pointer to old *Token w/ RefreshToken
}
func (tf tokenRefresher) Token() (*Token, error) {
t := *tf.oldToken
if t == nil {
return nil, errors.New("oauth2: attempted use of nil Token")
} }
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode())) if t.RefreshToken == "" {
return nil, errors.New("oauth2: token expired and refresh token is not set")
}
return retrieveToken(tf.ctx, tf.conf, url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {t.RefreshToken},
})
}
// newWhenNeededSource is a TokenSource that holds a single token in memory
// and validates its expiry before each call to retrieve it with
// Token. If it's expired, it will be auto-refreshed using the
// new TokenSource.
//
// The first call to TokenRefresher must be SetToken.
type newWhenNeededSource struct {
new TokenSource // called when t is expired.
mu sync.Mutex // guards t
t *Token
}
// Token returns the current token if it's still valid, else will
// refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *newWhenNeededSource) Token() (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t != nil && !s.t.Expired() {
return s.t, nil
}
t, err := s.new.Token()
if err != nil {
return nil, err
}
s.t = t
return t, nil
}
func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
hc, err := contextClient(ctx)
if err != nil {
return nil, err
}
v.Set("client_id", c.ClientID)
bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL)
if bustedAuth && c.ClientSecret != "" {
v.Set("client_secret", c.ClientSecret)
}
req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth && o.ClientSecret != "" { if !bustedAuth && c.ClientSecret != "" {
req.SetBasicAuth(o.ClientID, o.ClientSecret) req.SetBasicAuth(c.ClientID, c.ClientSecret)
} }
c := o.Client r, err := hc.Do(req)
if c == nil {
c = &http.Client{}
}
r, err := c.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Body.Close() defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
@ -314,7 +377,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
} }
token := &Token{} token := &Token{}
expires := int(0) expires := 0
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content { switch content {
case "application/x-www-form-urlencoded", "text/plain": case "application/x-www-form-urlencoded", "text/plain":
@ -335,7 +398,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
} }
expires, _ = strconv.Atoi(e) expires, _ = strconv.Atoi(e)
default: default:
b := make(map[string]interface{}) b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
if err = json.Unmarshal(body, &b); err != nil { if err = json.Unmarshal(body, &b); err != nil {
return nil, err return nil, err
} }
@ -397,3 +460,12 @@ func providerAuthHeaderWorks(tokenURL string) bool {
// to this package to let users select server bugs. // to this package to let users select server bugs.
return true return true
} }
// HTTPClient is the context key to use with golang.org/x/net/context's
// WithValue function to associate an *http.Client value with a context.
var HTTPClient contextKey
// contextKey is just an empty struct. It exists so HTTPClient can be
// an immutable public variable with a unique type. It's immutable
// because nobody else can create a contextKey, being unexported.
type contextKey struct{}

View File

@ -5,136 +5,90 @@
package oauth2 package oauth2
import ( import (
"errors"
"io"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time"
) )
const ( // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
defaultTokenType = "Bearer" // wrapping a base RoundTripper and adding an Authorization header
) // with a token from the supplied Sources.
// Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
type Token struct {
// AccessToken is the token that authorizes and authenticates the requests.
AccessToken string `json:"access_token"`
// TokenType identifies the type of token returned.
TokenType string `json:"token_type,omitempty"`
// RefreshToken is a token that may be used to obtain a new access token.
RefreshToken string `json:"refresh_token,omitempty"`
// Expiry is the expiration datetime of the access token.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Extra returns an extra field returned from the server during token retrieval.
// E.g.
// idToken := token.Extra("id_token")
// //
func (t *Token) Extra(key string) string { // Transport is a low-level mechanism. Most code will use the
if vals, ok := t.raw.(url.Values); ok { // higher-level Config.Client method instead.
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
return val
}
}
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
type Transport struct { type Transport struct {
opts *Options // Source supplies the token to add to outgoing requests'
base http.RoundTripper // Authorization headers.
Source TokenSource
mu sync.RWMutex // Base is the base RoundTripper used to make HTTP requests.
token *Token // If nil, http.DefaultTransport is used.
} Base http.RoundTripper
// NewTransport creates a new Transport that uses the provided mu sync.Mutex // guards modReq
// token fetcher as token retrieving strategy. It authenticates modReq map[*http.Request]*http.Request // original -> modified
// the requests and delegates origTransport to make the actual requests.
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
return &Transport{
base: base,
opts: opts,
token: token,
}
} }
// RoundTrip authorizes and authenticates the request with an // RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired, // access token. If no token exists or token is expired,
// tries to refresh/fetch a new token. // tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.token if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
if token == nil || token.Expired() { }
// Check if the token is refreshable. token, err := t.Source.Token()
// If token is refreshable, don't return an error, if err != nil {
// rather refresh. return nil, err
if err := t.refreshToken(); err != nil {
return nil, err
}
token = t.token
if t.opts.TokenStore != nil {
t.opts.TokenStore.WriteToken(token)
}
} }
// To set the Authorization header, we must make a copy of the Request req2 := cloneRequest(req) // per RoundTripper contract
// so that we don't modify the Request we were given. token.SetAuthHeader(req2)
// This is required by the specification of http.RoundTripper. t.setModReq(req, req2)
req = cloneRequest(req) res, err := t.base().RoundTrip(req2)
typ := token.TokenType if err != nil {
if typ == "" { t.setModReq(req, nil)
typ = defaultTokenType return nil, err
} }
req.Header.Set("Authorization", typ+" "+token.AccessToken) res.Body = &onEOFReader{
return t.base.RoundTrip(req) rc: res.Body,
fn: func() { t.setModReq(req, nil) },
}
return res, nil
} }
// Token returns the token that authorizes and // CancelRequest cancels an in-flight request by closing its connection.
// authenticates the transport. func (t *Transport) CancelRequest(req *http.Request) {
func (t *Transport) Token() *Token { type canceler interface {
t.mu.RLock() CancelRequest(*http.Request)
defer t.mu.RUnlock() }
return t.token if cr, ok := t.Base.(canceler); ok {
t.mu.Lock()
modReq := t.modReq[req]
delete(t.modReq, req)
t.mu.Unlock()
cr.CancelRequest(modReq)
}
} }
// refreshToken retrieves a new token, if a refreshing/fetching func (t *Transport) base() http.RoundTripper {
// method is known and required credentials are presented if t.Base != nil {
// (such as a refresh token). return t.Base
func (t *Transport) refreshToken() error { }
return http.DefaultTransport
}
func (t *Transport) setModReq(orig, mod *http.Request) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
token, err := t.opts.TokenFetcherFunc(t.token) if t.modReq == nil {
if err != nil { t.modReq = make(map[*http.Request]*http.Request)
return err }
if mod == nil {
delete(t.modReq, orig)
} else {
t.modReq[orig] = mod
} }
t.token = token
return nil
} }
// cloneRequest returns a clone of the provided *http.Request. // cloneRequest returns a clone of the provided *http.Request.
@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request) r2 := new(http.Request)
*r2 = *r *r2 = *r
// deep copy of the Header // deep copy of the Header
r2.Header = make(http.Header) r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header { for k, s := range r.Header {
r2.Header[k] = s r2.Header[k] = append([]string(nil), s...)
} }
return r2 return r2
} }
type onEOFReader struct {
rc io.ReadCloser
fn func()
}
func (r *onEOFReader) Read(p []byte) (n int, err error) {
n, err = r.rc.Read(p)
if err == io.EOF {
r.runFunc()
}
return
}
func (r *onEOFReader) Close() error {
err := r.rc.Close()
r.runFunc()
return err
}
func (r *onEOFReader) runFunc() {
if fn := r.fn; fn != nil {
fn()
r.fn = nil
}
}
type errorTransport struct{ err error }
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, t.err
}