From a568078818bb97b96ee7c7829b4575b81d4684f1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 10 Dec 2014 10:17:33 +1100 Subject: [PATCH] oauth2: redesign the API Tests and examples aren't updated yet. The tree will be broken after this, but nobody should be using this yet anyway. Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f Reviewed-on: https://go-review.googlesource.com/1246 Reviewed-by: Burcu Dogan --- client_appengine.go | 39 ++++ jwt.go | 195 +++++++++------- oauth2.go | 552 +++++++++++++++++++++++++------------------- transport.go | 202 ++++++++-------- 4 files changed, 556 insertions(+), 432 deletions(-) create mode 100644 client_appengine.go diff --git a/client_appengine.go b/client_appengine.go new file mode 100644 index 0000000..ff060ef --- /dev/null +++ b/client_appengine.go @@ -0,0 +1,39 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build appengine,!appenginevm + +// App Engine hooks. + +package oauth2 + +import ( + "log" + "net/http" + "sync" + + "appengine" + "appengine/urlfetch" +) + +var warnOnce sync.Once + +func init() { + registerContextClientFunc(contextClientAppEngine) +} + +func contextClientAppEngine(ctx Context) (*http.Client, error) { + if actx, ok := ctx.(appengine.Context); ok { + return urlfetch.NewClient(actx), nil + } + // The user did it wrong. We'll log once (and hope they see it + // in dev_appserver), but stil return (nil, nil) in case some + // other contextClientFunc hook finds a way to proceed. + warnOnce.Do(gaeDoingItWrongHelp) + return nil, nil +} + +func gaeDoingItWrongHelp() { + log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.") +} diff --git a/jwt.go b/jwt.go index 39cc0b1..d861e93 100644 --- a/jwt.go +++ b/jwt.go @@ -5,15 +5,16 @@ package oauth2 import ( + "crypto/rsa" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "net/url" "strings" "time" - "golang.org/x/oauth2/internal" "golang.org/x/oauth2/jws" ) @@ -22,100 +23,126 @@ var ( defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} ) -// JWTClient requires OAuth 2.0 JWT credentials. -// Required for the 2-legged JWT flow. -func JWTClient(email string, key []byte) Option { - return func(o *Options) error { - pk, err := internal.ParseKey(key) - if err != nil { - return err - } - o.Email = email - o.PrivateKey = pk - return nil +// JWTConfig is the configuration for using JWT to fetch tokens, +// commonly known as "two-legged OAuth". +type JWTConfig struct { + // Email is the OAuth client identifier used when communicating with + // the configured OAuth provider. + Email string + + // PrivateKey contains the contents of an RSA private key or the + // contents of a PEM file that contains a private key. The provided + // private key is used to sign JWT payloads. + // PEM containers with a passphrase are not supported. + // Use the following command to convert a PKCS 12 file into a PEM. + // + // $ openssl pkcs12 -in key.p12 -out key.pem -nodes + // + PrivateKey *rsa.PrivateKey + + // Subject is the optional user to impersonate. + Subject string + + // Scopes optionally specifies a list of requested permission scopes. + Scopes []string + + // TokenURL is the endpoint required to complete the 2-legged JWT flow. + TokenURL string +} + +// TokenSource returns a JWT TokenSource using the configuration +// in c and the HTTP client from the provided context. +// +// The returned TokenSource only does JWT requests when necessary but +// otherwise returns the same token repeatedly until it expires. +// +// The provided initialToken may be nil, in which case the first +// call to TokenSource will do a new JWT request. +func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource { + return &newWhenNeededSource{ + t: initialToken, + new: jwtSource{ctx, c}, } } -// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider. -func JWTEndpoint(aud string) Option { - return func(o *Options) error { - au, err := url.Parse(aud) - if err != nil { - return err - } - o.AUD = au - return nil +// Client returns an HTTP client wrapping the context's +// HTTP transport and adding Authorization headers with tokens +// obtained from c. +// +// The provided initialToken may be nil, in which case the first +// call to TokenSource will do a new JWT request. +// +// The returned client and its Transport should not be modified. +func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client { + return &http.Client{ + Transport: &Transport{ + Source: c.TokenSource(ctx, initialToken), + Base: contextTransport(ctx), + }, } } -// Subject requires a user to impersonate. -// Optional. -func Subject(user string) Option { - return func(o *Options) error { - o.Subject = user - return nil - } +// jwtSource is a source that always does a signed JWT request for a token. +// It should typically be wrapped with a newWhenNeededSource. +type jwtSource struct { + ctx Context + conf *JWTConfig } -func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) { - return func(t *Token) (*Token, error) { - if t == nil { - t = &Token{} - } - claimSet := &jws.ClaimSet{ - Iss: o.Email, - Scope: strings.Join(o.Scopes, " "), - Aud: o.AUD.String(), - } - if o.Subject != "" { - claimSet.Sub = o.Subject - // prn is the old name of sub. Keep setting it - // to be compatible with legacy OAuth 2.0 providers. - claimSet.Prn = o.Subject - } - payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey) - if err != nil { - return nil, err - } - v := url.Values{} - v.Set("grant_type", defaultGrantType) - v.Set("assertion", payload) - c := o.Client - if c == nil { - c = &http.Client{} - } - resp, err := c.PostForm(o.AUD.String(), v) +func (js jwtSource) Token() (*Token, error) { + hc, err := contextClient(js.ctx) + if err != nil { + return nil, err + } + claimSet := &jws.ClaimSet{ + Iss: js.conf.Email, + Scope: strings.Join(js.conf.Scopes, " "), + Aud: js.conf.TokenURL, + } + if subject := js.conf.Subject; subject != "" { + claimSet.Sub = subject + // prn is the old name of sub. Keep setting it + // to be compatible with legacy OAuth 2.0 providers. + claimSet.Prn = subject + } + payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey) + if err != nil { + return nil, err + } + v := url.Values{} + v.Set("grant_type", defaultGrantType) + v.Set("assertion", payload) + resp, err := hc.PostForm(js.conf.TokenURL, v) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + if c := resp.StatusCode; c < 200 || c > 299 { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) + } + b := make(map[string]interface{}) + if err := json.Unmarshal(body, &b); err != nil { + return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + } + token := &Token{} + token.AccessToken, _ = b["access_token"].(string) + token.TokenType, _ = b["token_type"].(string) + token.raw = b + if e, ok := b["expires_in"].(int); ok { + token.Expiry = time.Now().Add(time.Duration(e) * time.Second) + } + if idtoken, ok := b["id_token"].(string); ok { + // decode returned id token to get expiry + claimSet, err := jws.Decode(idtoken) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - if c := resp.StatusCode; c < 200 || c > 299 { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) - } - b := make(map[string]interface{}) - if err := json.Unmarshal(body, &b); err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - token := &Token{} - token.AccessToken, _ = b["access_token"].(string) - token.TokenType, _ = b["token_type"].(string) - token.raw = b - if e, ok := b["expires_in"].(int); ok { - token.Expiry = time.Now().Add(time.Duration(e) * time.Second) - } - if idtoken, ok := b["id_token"].(string); ok { - // decode returned id token to get expiry - claimSet, err := jws.Decode(idtoken) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - token.Expiry = time.Unix(claimSet.Exp, 0) - return token, nil - } + token.Expiry = time.Unix(claimSet.Exp, 0) return token, nil } + return token, nil } diff --git a/oauth2.go b/oauth2.go index 280ac24..88121f3 100644 --- a/oauth2.go +++ b/oauth2.go @@ -8,115 +8,168 @@ package oauth2 // import "golang.org/x/oauth2" import ( - "crypto/rsa" + "bytes" "encoding/json" "errors" "fmt" + "io" "io/ioutil" "mime" - "time" - "net/http" "net/url" "strconv" "strings" + "sync" + "time" + + "golang.org/x/net/context" ) -// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer. -type TokenStore interface { - // ReadToken reads the token from the store. - // If the read is successful, it should return the token and a nil error. - // The returned tokens may be expired tokens. - // If there is no token in the store, it should return a nil token and a nil error. - // It should return a non-nil error when an unrecoverable failure occurs. - ReadToken() (*Token, error) - // WriteToken writes the token to the cache. - WriteToken(*Token) +// Context can be an golang.org/x/net.Context, or an App Engine Context. +// In the future these will be unified. +// If you don't care and aren't running on App Engine, you may use nil. +type Context interface{} + +// Config describes a typical 3-legged OAuth2 flow, with both the +// client application information and the server's URLs. +type Config struct { + // ClientID is the application's ID. + ClientID string + + // ClientSecret is the application's secret. + ClientSecret string + + // Endpoint contains the resource server's token endpoint + // URLs. These are supplied by the server and are often + // available via site-specific packages (for example, + // google.Endpoint or github.Endpoint) + Endpoint Endpoint + + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURL string + + // Scope specifies optional requested permissions. + Scopes []string } -// Option represents a function that applies some state to -// an Options object. -type Option func(*Options) error +// A TokenSource is anything that can return a token. +type TokenSource interface { + // Token returns a token or an error. + Token() (*Token, error) +} -// Client requires the OAuth 2.0 client credentials. You need to provide -// the client identifier and optionally the client secret that are -// assigned to your application by the OAuth 2.0 provider. -func Client(id, secret string) Option { - return func(opts *Options) error { - opts.ClientID = id - opts.ClientSecret = secret - return nil +// Endpoint contains the OAuth 2.0 provider's authorization and token +// endpoint URLs. +type Endpoint struct { + AuthURL string + TokenURL string +} + +// Token represents the crendentials used to authorize +// the requests to access protected resources on the OAuth 2.0 +// provider's backend. +// +// Most users of this package should not access fields of Token +// directly. They're exported mostly for use by related packages +// implementing derivate OAuth2 flows. +type Token struct { + // AccessToken is the token that authorizes and authenticates + // the requests. + AccessToken string `json:"access_token"` + + // TokenType is the type of token. + // The Type method returns either this or "Bearer", the default. + TokenType string `json:"token_type,omitempty"` + + // RefreshToken is a token that's used by the application + // (as opposed to the user) to refresh the access token + // if it expires. + RefreshToken string `json:"refresh_token,omitempty"` + + // Expiry is the optional expiration time of the access token. + // + // If zero, TokenSource implementations will reuse the same + // token forever and RefreshToken or equivalent + // mechanisms for that TokenSource will not be used. + Expiry time.Time `json:"expiry,omitempty"` + + // raw optionally contains extra metadata from the server + // when updating a token. + raw interface{} +} + +// Type returns t.TokenType if non-empty, else "Bearer". +func (t *Token) Type() string { + if t.TokenType != "" { + return t.TokenType } + return "Bearer" } -// RedirectURL requires the URL to which the user will be returned after -// granting (or denying) access. -func RedirectURL(url string) Option { - return func(opts *Options) error { - opts.RedirectURL = url - return nil +// SetAuthHeader sets the Authorization header to r using the access +// token in t. +// +// This method is unnecessary when using Transport or an HTTP Client +// returned by this package. +func (t *Token) SetAuthHeader(r *http.Request) { + r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) +} + +// Extra returns an extra field returned from the server during token +// retrieval. +func (t *Token) Extra(key string) string { + if vals, ok := t.raw.(url.Values); ok { + return vals.Get(key) } -} - -// Scope requires a list of requested permission scopes. -// It is optinal to specify scopes. -func Scope(scopes ...string) Option { - return func(o *Options) error { - o.Scopes = scopes - return nil - } -} - -// Endpoint requires OAuth 2.0 provider's authorization and token endpoints. -func Endpoint(authURL, tokenURL string) Option { - return func(o *Options) error { - au, err := url.Parse(authURL) - if err != nil { - return err - } - tu, err := url.Parse(tokenURL) - if err != nil { - return err - } - o.AuthURL = au - o.TokenURL = tu - return nil - } -} - -// HTTPClient allows you to provide a custom http.Client to be -// used to retrieve tokens from the OAuth 2.0 provider. -func HTTPClient(c *http.Client) Option { - return func(o *Options) error { - o.Client = c - return nil - } -} - -// New builds a new options object and determines the type of the OAuth 2.0 -// (2-legged, 3-legged or custom) by looking at the provided options. -// If the flow type cannot determined automatically, an error is returned. -func New(option ...Option) (*Options, error) { - opts := &Options{} - for _, fn := range option { - if err := fn(opts); err != nil { - return nil, err + if raw, ok := t.raw.(map[string]interface{}); ok { + if val, ok := raw[key].(string); ok { + return val } } - switch { - case opts.TokenFetcherFunc != nil: - return opts, nil - case opts.AUD != nil: - // TODO(jbd): Assert the required JWT params. - opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts) - return opts, nil - case opts.AuthURL != nil && opts.TokenURL != nil: - // TODO(jbd): Assert the required OAuth2 params. - opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts) - return opts, nil - default: - return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens") + return "" +} + +// Expired returns true if there is no access token or the +// access token is expired. +func (t *Token) Expired() bool { + if t.AccessToken == "" { + return true } + if t.Expiry.IsZero() { + return false + } + return t.Expiry.Before(time.Now()) +} + +var ( + // AccessTypeOnline and AccessTypeOffline are options passed + // to the Options.AuthCodeURL method. They modify the + // "access_type" field that gets sent in the URL returned by + // AuthCodeURL. + // + // Online (the default if neither is specified) is the default. + // If your application needs to refresh access tokens when the + // user is not present at the browser, then use offline. This + // will result in your application obtaining a refresh token + // the first time your application exchanges an authorization + // code for a user. + AccessTypeOnline AuthCodeOption = setParam{"access_type", "online"} + AccessTypeOffline AuthCodeOption = setParam{"access_type", "offline"} + + // ApprovalForce forces the users to view the consent dialog + // and confirm the permissions request at the URL returned + // from AuthCodeURL, even if they've already done so. + ApprovalForce AuthCodeOption = setParam{"approval_prompt", "force"} +) + +type setParam struct{ k, v string } + +func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } + +// An AuthCodeOption is passed to Config.AuthCodeURL. +type AuthCodeOption interface { + setValue(url.Values) } // AuthCodeURL returns a URL to OAuth 2.0 provider's consent page @@ -127,185 +180,195 @@ func New(option ...Option) (*Options, error) { // the state query parameter on your redirect callback. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. // -// Access type is an OAuth extension that gets sent as the -// "access_type" field in the URL from AuthCodeURL. -// It may be "online" (default) or "offline". -// If your application needs to refresh access tokens when the -// user is not present at the browser, then use offline. This -// will result in your application obtaining a refresh token -// the first time your application exchanges an authorization -// code for a user. -// -// Approval prompt indicates whether the user should be -// re-prompted for consent. If set to "auto" (default) the -// user will be prompted only if they haven't previously -// granted consent and the code can only be exchanged for an -// access token. If set to "force" the user will always be prompted, -// and the code can be exchanged for a refresh token. -func (o *Options) AuthCodeURL(state, accessType, prompt string) string { - u := *o.AuthURL +// Opts may include AccessTypeOnline or AccessTypeOffline, as well +// as ApprovalForce. +func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { + var buf bytes.Buffer + buf.WriteString(c.Endpoint.AuthURL) v := url.Values{ - "response_type": {"code"}, - "client_id": {o.ClientID}, - "redirect_uri": condVal(o.RedirectURL), - "scope": condVal(strings.Join(o.Scopes, " ")), - "state": condVal(state), - "access_type": condVal(accessType), - "approval_prompt": condVal(prompt), + "response_type": {"code"}, + "client_id": {c.ClientID}, + "redirect_uri": condVal(c.RedirectURL), + "scope": condVal(strings.Join(c.Scopes, " ")), + "state": condVal(state), } - q := v.Encode() - if u.RawQuery == "" { - u.RawQuery = q + for _, opt := range opts { + opt.setValue(v) + } + if strings.Contains(c.Endpoint.AuthURL, "?") { + buf.WriteByte('&') } else { - u.RawQuery += "&" + q + buf.WriteByte('?') } - return u.String() + buf.WriteString(v.Encode()) + return buf.String() } -// exchange exchanges the authorization code with the OAuth 2.0 provider -// to retrieve a new access token. -func (o *Options) exchange(code string) (*Token, error) { - return retrieveToken(o, url.Values{ +// Exchange converts an authorization code into a token. +// +// It is used after a resource provider redirects the user back +// to the Redirect URI (the URL obtained from AuthCodeURL). +// +// The HTTP client to use is derived from the context. If nil, +// http.DefaultClient is used. See the Context type's documentation. +// +// The code will be in the *http.Request.FormValue("code"). Before +// calling Exchange, be sure to validate FormValue("state"). +func (c *Config) Exchange(ctx Context, code string) (*Token, error) { + return retrieveToken(ctx, c, url.Values{ "grant_type": {"authorization_code"}, "code": {code}, - "redirect_uri": condVal(o.RedirectURL), - "scope": condVal(strings.Join(o.Scopes, " ")), + "redirect_uri": condVal(c.RedirectURL), + "scope": condVal(strings.Join(c.Scopes, " ")), }) } -// NewTransportFromTokenStore reads the token from the store and returns -// a Transport that is authorized and the authenticated -// by the returned token. -func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) { - tok, err := store.ReadToken() - if err != nil { - return nil, err - } - o.TokenStore = store - if tok == nil { - return nil, nil - } - return o.newTransportFromToken(tok), nil +// contextClientFunc is a func which tries to return an *http.Client +// given a Context value. If it returns an error, the search stops +// with that error. If it returns (nil, nil), the search continues +// down the list of registered funcs. +type contextClientFunc func(Context) (*http.Client, error) + +var contextClientFuncs []contextClientFunc + +func registerContextClietnFunc(fn contextClientFunc) { + contextClientFuncs = append(contextClientFuncs, fn) } -// NewTransportFromCode exchanges the code to retrieve a new access token -// and returns an authorized and authenticated Transport. -func (o *Options) NewTransportFromCode(code string) (*Transport, error) { - token, err := o.exchange(code) - if err != nil { - return nil, err - } - return o.newTransportFromToken(token), nil -} - -// NewTransport returns a Transport. -func (o *Options) NewTransport() *Transport { - return o.newTransportFromToken(nil) -} - -// newTransportFromToken returns a new Transport that is authorized -// and authenticated with the provided token. -func (o *Options) newTransportFromToken(t *Token) *Transport { - // TODO(jbd): App Engine options initiate an http.Client that - // depends on the urlfetcher, but it breaks the promise we made - // that the options object should be working finely with nil-values - // for the http.Client. - tr := http.DefaultTransport - if o.Client != nil && o.Client.Transport != nil { - tr = o.Client.Transport - } - return newTransport(tr, o, t) -} - -func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) { - return func(t *Token) (*Token, error) { - if t == nil || t.RefreshToken == "" { - return nil, errors.New("oauth2: cannot fetch access token without refresh token") +func contextClient(ctx Context) (*http.Client, error) { + for _, fn := range contextClientFuncs { + c, err := fn(ctx) + if err != nil { + return nil, err } - return retrieveToken(o, url.Values{ - "grant_type": {"refresh_token"}, - "refresh_token": {t.RefreshToken}, - }) + if c != nil { + return c, nil + } + } + if xc, ok := ctx.(context.Context); ok { + if hc, ok := xc.Value(HTTPClient).(*http.Client); ok { + return hc, nil + } + } + return http.DefaultClient, nil +} + +func contextTransport(ctx Context) http.RoundTripper { + hc, err := contextClient(ctx) + if err != nil { + // This is a rare error case (somebody using nil on App Engine), + // so I'd rather not everybody do an error check on this Client + // method. They can get the error that they're doing it wrong + // later, at client.Get/PostForm time. + return errorTransport{err} + } + return hc.Transport +} + +// Client returns an HTTP client using the provided token. +// The token will auto-refresh as necessary. The underlying +// HTTP transport will be obtained using the provided context. +// The returned client and its Transport should not be modified. +func (c *Config) Client(ctx Context, t *Token) *http.Client { + return &http.Client{ + Transport: &Transport{ + Source: c.TokenSource(ctx, t), + Base: contextTransport(ctx), + }, } } -// Options represents an object to keep the state of the OAuth 2.0 flow. -type Options struct { - // ClientID is the OAuth client identifier used when communicating with - // the configured OAuth provider. - ClientID string - - // ClientSecret is the OAuth client secret used when communicating with - // the configured OAuth provider. - ClientSecret string - - // RedirectURL is the URL to which the user will be returned after - // granting (or denying) access. - RedirectURL string - - // Email is the OAuth client identifier used when communicating with - // the configured OAuth provider. - Email string - - // PrivateKey contains the contents of an RSA private key or the - // contents of a PEM file that contains a private key. The provided - // private key is used to sign JWT payloads. - // PEM containers with a passphrase are not supported. - // Use the following command to convert a PKCS 12 file into a PEM. - // - // $ openssl pkcs12 -in key.p12 -out key.pem -nodes - // - PrivateKey *rsa.PrivateKey - - // Scopes identify the level of access being requested. - Subject string - - // Scopes optionally specifies a list of requested permission scopes. - Scopes []string - - // AuthURL represents the authorization endpoint of the OAuth 2.0 provider. - AuthURL *url.URL - - // TokenURL represents the token endpoint of the OAuth 2.0 provider. - TokenURL *url.URL - - // AUD represents the token endpoint required to complete the 2-legged JWT flow. - AUD *url.URL - - // TokenStore reads a token from the store and writes it back to the store - // if a token refresh occurs. - // Optional. - TokenStore TokenStore - - TokenFetcherFunc func(t *Token) (*Token, error) - - Client *http.Client +// TokenSource returns a TokenSource that returns t until t expires, +// automatically refreshing it as necessary using the provided context. +// See the the Context documentation. +// +// Most users will use Config.Client instead. +func (c *Config) TokenSource(ctx Context, t *Token) TokenSource { + nwn := &newWhenNeededSource{t: t} + nwn.new = tokenRefresher{ + ctx: ctx, + conf: c, + oldToken: &nwn.t, + } + return nwn } -func retrieveToken(o *Options, v url.Values) (*Token, error) { - v.Set("client_id", o.ClientID) - bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String()) - if bustedAuth && o.ClientSecret != "" { - v.Set("client_secret", o.ClientSecret) +// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" +// HTTP requests to renew a token using a RefreshToken. +type tokenRefresher struct { + ctx Context // used to get HTTP requests + conf *Config + oldToken **Token // pointer to old *Token w/ RefreshToken +} + +func (tf tokenRefresher) Token() (*Token, error) { + t := *tf.oldToken + if t == nil { + return nil, errors.New("oauth2: attempted use of nil Token") } - req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode())) + if t.RefreshToken == "" { + return nil, errors.New("oauth2: token expired and refresh token is not set") + } + return retrieveToken(tf.ctx, tf.conf, url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {t.RefreshToken}, + }) +} + +// newWhenNeededSource is a TokenSource that holds a single token in memory +// and validates its expiry before each call to retrieve it with +// Token. If it's expired, it will be auto-refreshed using the +// new TokenSource. +// +// The first call to TokenRefresher must be SetToken. +type newWhenNeededSource struct { + new TokenSource // called when t is expired. + + mu sync.Mutex // guards t + t *Token +} + +// Token returns the current token if it's still valid, else will +// refresh the current token (using r.Context for HTTP client +// information) and return the new one. +func (s *newWhenNeededSource) Token() (*Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.t != nil && !s.t.Expired() { + return s.t, nil + } + t, err := s.new.Token() + if err != nil { + return nil, err + } + s.t = t + return t, nil +} + +func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) { + hc, err := contextClient(ctx) + if err != nil { + return nil, err + } + v.Set("client_id", c.ClientID) + bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL) + if bustedAuth && c.ClientSecret != "" { + v.Set("client_secret", c.ClientSecret) + } + req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode())) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if !bustedAuth && o.ClientSecret != "" { - req.SetBasicAuth(o.ClientID, o.ClientSecret) + if !bustedAuth && c.ClientSecret != "" { + req.SetBasicAuth(c.ClientID, c.ClientSecret) } - c := o.Client - if c == nil { - c = &http.Client{} - } - r, err := c.Do(req) + r, err := hc.Do(req) if err != nil { return nil, err } defer r.Body.Close() - body, err := ioutil.ReadAll(r.Body) + body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -314,7 +377,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) { } token := &Token{} - expires := int(0) + expires := 0 content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch content { case "application/x-www-form-urlencoded", "text/plain": @@ -335,7 +398,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) { } expires, _ = strconv.Atoi(e) default: - b := make(map[string]interface{}) + b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type if err = json.Unmarshal(body, &b); err != nil { return nil, err } @@ -397,3 +460,12 @@ func providerAuthHeaderWorks(tokenURL string) bool { // to this package to let users select server bugs. return true } + +// HTTPClient is the context key to use with golang.org/x/net/context's +// WithValue function to associate an *http.Client value with a context. +var HTTPClient contextKey + +// contextKey is just an empty struct. It exists so HTTPClient can be +// an immutable public variable with a unique type. It's immutable +// because nobody else can create a contextKey, being unexported. +type contextKey struct{} diff --git a/transport.go b/transport.go index 0192aac..9e7b106 100644 --- a/transport.go +++ b/transport.go @@ -5,136 +5,90 @@ package oauth2 import ( + "errors" + "io" "net/http" - "net/url" "sync" - "time" ) -const ( - defaultTokenType = "Bearer" -) - -// Token represents the crendentials used to authorize -// the requests to access protected resources on the OAuth 2.0 -// provider's backend. -type Token struct { - // AccessToken is the token that authorizes and authenticates the requests. - AccessToken string `json:"access_token"` - - // TokenType identifies the type of token returned. - TokenType string `json:"token_type,omitempty"` - - // RefreshToken is a token that may be used to obtain a new access token. - RefreshToken string `json:"refresh_token,omitempty"` - - // Expiry is the expiration datetime of the access token. - Expiry time.Time `json:"expiry,omitempty"` - - // raw optionally contains extra metadata from the server - // when updating a token. - raw interface{} -} - -// Extra returns an extra field returned from the server during token retrieval. -// E.g. -// idToken := token.Extra("id_token") +// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests, +// wrapping a base RoundTripper and adding an Authorization header +// with a token from the supplied Sources. // -func (t *Token) Extra(key string) string { - if vals, ok := t.raw.(url.Values); ok { - return vals.Get(key) - } - if raw, ok := t.raw.(map[string]interface{}); ok { - if val, ok := raw[key].(string); ok { - return val - } - } - return "" -} - -// Expired returns true if there is no access token or the -// access token is expired. -func (t *Token) Expired() bool { - if t.AccessToken == "" { - return true - } - if t.Expiry.IsZero() { - return false - } - return t.Expiry.Before(time.Now()) -} - -// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests. +// Transport is a low-level mechanism. Most code will use the +// higher-level Config.Client method instead. type Transport struct { - opts *Options - base http.RoundTripper + // Source supplies the token to add to outgoing requests' + // Authorization headers. + Source TokenSource - mu sync.RWMutex - token *Token -} + // Base is the base RoundTripper used to make HTTP requests. + // If nil, http.DefaultTransport is used. + Base http.RoundTripper -// NewTransport creates a new Transport that uses the provided -// token fetcher as token retrieving strategy. It authenticates -// the requests and delegates origTransport to make the actual requests. -func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport { - return &Transport{ - base: base, - opts: opts, - token: token, - } + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified } // RoundTrip authorizes and authenticates the request with an // access token. If no token exists or token is expired, // tries to refresh/fetch a new token. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - token := t.token - - if token == nil || token.Expired() { - // Check if the token is refreshable. - // If token is refreshable, don't return an error, - // rather refresh. - if err := t.refreshToken(); err != nil { - return nil, err - } - token = t.token - if t.opts.TokenStore != nil { - t.opts.TokenStore.WriteToken(token) - } + if t.Source == nil { + return nil, errors.New("oauth2: Transport's Source is nil") + } + token, err := t.Source.Token() + if err != nil { + return nil, err } - // To set the Authorization header, we must make a copy of the Request - // so that we don't modify the Request we were given. - // This is required by the specification of http.RoundTripper. - req = cloneRequest(req) - typ := token.TokenType - if typ == "" { - typ = defaultTokenType + req2 := cloneRequest(req) // per RoundTripper contract + token.SetAuthHeader(req2) + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err } - req.Header.Set("Authorization", typ+" "+token.AccessToken) - return t.base.RoundTrip(req) + res.Body = &onEOFReader{ + rc: res.Body, + fn: func() { t.setModReq(req, nil) }, + } + return res, nil } -// Token returns the token that authorizes and -// authenticates the transport. -func (t *Transport) Token() *Token { - t.mu.RLock() - defer t.mu.RUnlock() - return t.token +// CancelRequest cancels an in-flight request by closing its connection. +func (t *Transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.Base.(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } } -// refreshToken retrieves a new token, if a refreshing/fetching -// method is known and required credentials are presented -// (such as a refresh token). -func (t *Transport) refreshToken() error { +func (t *Transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *Transport) setModReq(orig, mod *http.Request) { t.mu.Lock() defer t.mu.Unlock() - token, err := t.opts.TokenFetcherFunc(t.token) - if err != nil { - return err + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod } - t.token = token - return nil } // cloneRequest returns a clone of the provided *http.Request. @@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request { r2 := new(http.Request) *r2 = *r // deep copy of the Header - r2.Header = make(http.Header) + r2.Header = make(http.Header, len(r.Header)) for k, s := range r.Header { - r2.Header[k] = s + r2.Header[k] = append([]string(nil), s...) } return r2 } + +type onEOFReader struct { + rc io.ReadCloser + fn func() +} + +func (r *onEOFReader) Read(p []byte) (n int, err error) { + n, err = r.rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *onEOFReader) Close() error { + err := r.rc.Close() + r.runFunc() + return err +} + +func (r *onEOFReader) runFunc() { + if fn := r.fn; fn != nil { + fn() + r.fn = nil + } +} + +type errorTransport struct{ err error } + +func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) { + return nil, t.err +}