oauth2: redesign the API

Tests and examples aren't updated yet. The tree will be broken after this,
but nobody should be using this yet anyway.

Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f
Reviewed-on: https://go-review.googlesource.com/1246
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick 2014-12-10 10:17:33 +11:00
parent b3f9a68f05
commit a568078818
4 changed files with 556 additions and 432 deletions

39
client_appengine.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine,!appenginevm
// App Engine hooks.
package oauth2
import (
"log"
"net/http"
"sync"
"appengine"
"appengine/urlfetch"
)
var warnOnce sync.Once
func init() {
registerContextClientFunc(contextClientAppEngine)
}
func contextClientAppEngine(ctx Context) (*http.Client, error) {
if actx, ok := ctx.(appengine.Context); ok {
return urlfetch.NewClient(actx), nil
}
// The user did it wrong. We'll log once (and hope they see it
// in dev_appserver), but stil return (nil, nil) in case some
// other contextClientFunc hook finds a way to proceed.
warnOnce.Do(gaeDoingItWrongHelp)
return nil, nil
}
func gaeDoingItWrongHelp() {
log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.")
}

133
jwt.go
View File

@ -5,15 +5,16 @@
package oauth2
import (
"crypto/rsa"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)
@ -22,74 +23,101 @@ var (
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
)
// JWTClient requires OAuth 2.0 JWT credentials.
// Required for the 2-legged JWT flow.
func JWTClient(email string, key []byte) Option {
return func(o *Options) error {
pk, err := internal.ParseKey(key)
// JWTConfig is the configuration for using JWT to fetch tokens,
// commonly known as "two-legged OAuth".
type JWTConfig struct {
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string
// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Subject is the optional user to impersonate.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
TokenURL string
}
// TokenSource returns a JWT TokenSource using the configuration
// in c and the HTTP client from the provided context.
//
// The returned TokenSource only does JWT requests when necessary but
// otherwise returns the same token repeatedly until it expires.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
return &newWhenNeededSource{
t: initialToken,
new: jwtSource{ctx, c},
}
}
// Client returns an HTTP client wrapping the context's
// HTTP transport and adding Authorization headers with tokens
// obtained from c.
//
// The provided initialToken may be nil, in which case the first
// call to TokenSource will do a new JWT request.
//
// The returned client and its Transport should not be modified.
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
return &http.Client{
Transport: &Transport{
Source: c.TokenSource(ctx, initialToken),
Base: contextTransport(ctx),
},
}
}
// jwtSource is a source that always does a signed JWT request for a token.
// It should typically be wrapped with a newWhenNeededSource.
type jwtSource struct {
ctx Context
conf *JWTConfig
}
func (js jwtSource) Token() (*Token, error) {
hc, err := contextClient(js.ctx)
if err != nil {
return err
}
o.Email = email
o.PrivateKey = pk
return nil
}
}
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
func JWTEndpoint(aud string) Option {
return func(o *Options) error {
au, err := url.Parse(aud)
if err != nil {
return err
}
o.AUD = au
return nil
}
}
// Subject requires a user to impersonate.
// Optional.
func Subject(user string) Option {
return func(o *Options) error {
o.Subject = user
return nil
}
}
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil {
t = &Token{}
return nil, err
}
claimSet := &jws.ClaimSet{
Iss: o.Email,
Scope: strings.Join(o.Scopes, " "),
Aud: o.AUD.String(),
Iss: js.conf.Email,
Scope: strings.Join(js.conf.Scopes, " "),
Aud: js.conf.TokenURL,
}
if o.Subject != "" {
claimSet.Sub = o.Subject
if subject := js.conf.Subject; subject != "" {
claimSet.Sub = subject
// prn is the old name of sub. Keep setting it
// to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = o.Subject
claimSet.Prn = subject
}
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey)
if err != nil {
return nil, err
}
v := url.Values{}
v.Set("grant_type", defaultGrantType)
v.Set("assertion", payload)
c := o.Client
if c == nil {
c = &http.Client{}
}
resp, err := c.PostForm(o.AUD.String(), v)
resp, err := hc.PostForm(js.conf.TokenURL, v)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
@ -117,5 +145,4 @@ func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return token, nil
}
return token, nil
}
}

504
oauth2.go
View File

@ -8,115 +8,168 @@
package oauth2 // import "golang.org/x/oauth2"
import (
"crypto/rsa"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"mime"
"time"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/context"
)
// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer.
type TokenStore interface {
// ReadToken reads the token from the store.
// If the read is successful, it should return the token and a nil error.
// The returned tokens may be expired tokens.
// If there is no token in the store, it should return a nil token and a nil error.
// It should return a non-nil error when an unrecoverable failure occurs.
ReadToken() (*Token, error)
// WriteToken writes the token to the cache.
WriteToken(*Token)
// Context can be an golang.org/x/net.Context, or an App Engine Context.
// In the future these will be unified.
// If you don't care and aren't running on App Engine, you may use nil.
type Context interface{}
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's URLs.
type Config struct {
// ClientID is the application's ID.
ClientID string
// ClientSecret is the application's secret.
ClientSecret string
// Endpoint contains the resource server's token endpoint
// URLs. These are supplied by the server and are often
// available via site-specific packages (for example,
// google.Endpoint or github.Endpoint)
Endpoint Endpoint
// RedirectURL is the URL to redirect users going through
// the OAuth flow, after the resource owner's URLs.
RedirectURL string
// Scope specifies optional requested permissions.
Scopes []string
}
// Option represents a function that applies some state to
// an Options object.
type Option func(*Options) error
// Client requires the OAuth 2.0 client credentials. You need to provide
// the client identifier and optionally the client secret that are
// assigned to your application by the OAuth 2.0 provider.
func Client(id, secret string) Option {
return func(opts *Options) error {
opts.ClientID = id
opts.ClientSecret = secret
return nil
}
// A TokenSource is anything that can return a token.
type TokenSource interface {
// Token returns a token or an error.
Token() (*Token, error)
}
// RedirectURL requires the URL to which the user will be returned after
// granting (or denying) access.
func RedirectURL(url string) Option {
return func(opts *Options) error {
opts.RedirectURL = url
return nil
}
// Endpoint contains the OAuth 2.0 provider's authorization and token
// endpoint URLs.
type Endpoint struct {
AuthURL string
TokenURL string
}
// Scope requires a list of requested permission scopes.
// It is optinal to specify scopes.
func Scope(scopes ...string) Option {
return func(o *Options) error {
o.Scopes = scopes
return nil
}
// Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
// Most users of this package should not access fields of Token
// directly. They're exported mostly for use by related packages
// implementing derivate OAuth2 flows.
type Token struct {
// AccessToken is the token that authorizes and authenticates
// the requests.
AccessToken string `json:"access_token"`
// TokenType is the type of token.
// The Type method returns either this or "Bearer", the default.
TokenType string `json:"token_type,omitempty"`
// RefreshToken is a token that's used by the application
// (as opposed to the user) to refresh the access token
// if it expires.
RefreshToken string `json:"refresh_token,omitempty"`
// Expiry is the optional expiration time of the access token.
//
// If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
func Endpoint(authURL, tokenURL string) Option {
return func(o *Options) error {
au, err := url.Parse(authURL)
if err != nil {
return err
}
tu, err := url.Parse(tokenURL)
if err != nil {
return err
}
o.AuthURL = au
o.TokenURL = tu
return nil
// Type returns t.TokenType if non-empty, else "Bearer".
func (t *Token) Type() string {
if t.TokenType != "" {
return t.TokenType
}
return "Bearer"
}
// HTTPClient allows you to provide a custom http.Client to be
// used to retrieve tokens from the OAuth 2.0 provider.
func HTTPClient(c *http.Client) Option {
return func(o *Options) error {
o.Client = c
return nil
}
// SetAuthHeader sets the Authorization header to r using the access
// token in t.
//
// This method is unnecessary when using Transport or an HTTP Client
// returned by this package.
func (t *Token) SetAuthHeader(r *http.Request) {
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
}
// New builds a new options object and determines the type of the OAuth 2.0
// (2-legged, 3-legged or custom) by looking at the provided options.
// If the flow type cannot determined automatically, an error is returned.
func New(option ...Option) (*Options, error) {
opts := &Options{}
for _, fn := range option {
if err := fn(opts); err != nil {
return nil, err
// Extra returns an extra field returned from the server during token
// retrieval.
func (t *Token) Extra(key string) string {
if vals, ok := t.raw.(url.Values); ok {
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
return val
}
}
switch {
case opts.TokenFetcherFunc != nil:
return opts, nil
case opts.AUD != nil:
// TODO(jbd): Assert the required JWT params.
opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts)
return opts, nil
case opts.AuthURL != nil && opts.TokenURL != nil:
// TODO(jbd): Assert the required OAuth2 params.
opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts)
return opts, nil
default:
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
var (
// AccessTypeOnline and AccessTypeOffline are options passed
// to the Options.AuthCodeURL method. They modify the
// "access_type" field that gets sent in the URL returned by
// AuthCodeURL.
//
// Online (the default if neither is specified) is the default.
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
AccessTypeOnline AuthCodeOption = setParam{"access_type", "online"}
AccessTypeOffline AuthCodeOption = setParam{"access_type", "offline"}
// ApprovalForce forces the users to view the consent dialog
// and confirm the permissions request at the URL returned
// from AuthCodeURL, even if they've already done so.
ApprovalForce AuthCodeOption = setParam{"approval_prompt", "force"}
)
type setParam struct{ k, v string }
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
// An AuthCodeOption is passed to Config.AuthCodeURL.
type AuthCodeOption interface {
setValue(url.Values)
}
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
@ -127,185 +180,195 @@ func New(option ...Option) (*Options, error) {
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
//
// Access type is an OAuth extension that gets sent as the
// "access_type" field in the URL from AuthCodeURL.
// It may be "online" (default) or "offline".
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
//
// Approval prompt indicates whether the user should be
// re-prompted for consent. If set to "auto" (default) the
// user will be prompted only if they haven't previously
// granted consent and the code can only be exchanged for an
// access token. If set to "force" the user will always be prompted,
// and the code can be exchanged for a refresh token.
func (o *Options) AuthCodeURL(state, accessType, prompt string) string {
u := *o.AuthURL
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce.
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
v := url.Values{
"response_type": {"code"},
"client_id": {o.ClientID},
"redirect_uri": condVal(o.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")),
"client_id": {c.ClientID},
"redirect_uri": condVal(c.RedirectURL),
"scope": condVal(strings.Join(c.Scopes, " ")),
"state": condVal(state),
"access_type": condVal(accessType),
"approval_prompt": condVal(prompt),
}
q := v.Encode()
if u.RawQuery == "" {
u.RawQuery = q
for _, opt := range opts {
opt.setValue(v)
}
if strings.Contains(c.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else {
u.RawQuery += "&" + q
buf.WriteByte('?')
}
return u.String()
buf.WriteString(v.Encode())
return buf.String()
}
// exchange exchanges the authorization code with the OAuth 2.0 provider
// to retrieve a new access token.
func (o *Options) exchange(code string) (*Token, error) {
return retrieveToken(o, url.Values{
// Exchange converts an authorization code into a token.
//
// It is used after a resource provider redirects the user back
// to the Redirect URI (the URL obtained from AuthCodeURL).
//
// The HTTP client to use is derived from the context. If nil,
// http.DefaultClient is used. See the Context type's documentation.
//
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state").
func (c *Config) Exchange(ctx Context, code string) (*Token, error) {
return retrieveToken(ctx, c, url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": condVal(o.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")),
"redirect_uri": condVal(c.RedirectURL),
"scope": condVal(strings.Join(c.Scopes, " ")),
})
}
// NewTransportFromTokenStore reads the token from the store and returns
// a Transport that is authorized and the authenticated
// by the returned token.
func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) {
tok, err := store.ReadToken()
// contextClientFunc is a func which tries to return an *http.Client
// given a Context value. If it returns an error, the search stops
// with that error. If it returns (nil, nil), the search continues
// down the list of registered funcs.
type contextClientFunc func(Context) (*http.Client, error)
var contextClientFuncs []contextClientFunc
func registerContextClietnFunc(fn contextClientFunc) {
contextClientFuncs = append(contextClientFuncs, fn)
}
func contextClient(ctx Context) (*http.Client, error) {
for _, fn := range contextClientFuncs {
c, err := fn(ctx)
if err != nil {
return nil, err
}
o.TokenStore = store
if tok == nil {
return nil, nil
if c != nil {
return c, nil
}
return o.newTransportFromToken(tok), nil
}
if xc, ok := ctx.(context.Context); ok {
if hc, ok := xc.Value(HTTPClient).(*http.Client); ok {
return hc, nil
}
}
return http.DefaultClient, nil
}
// NewTransportFromCode exchanges the code to retrieve a new access token
// and returns an authorized and authenticated Transport.
func (o *Options) NewTransportFromCode(code string) (*Transport, error) {
token, err := o.exchange(code)
func contextTransport(ctx Context) http.RoundTripper {
hc, err := contextClient(ctx)
if err != nil {
return nil, err
// This is a rare error case (somebody using nil on App Engine),
// so I'd rather not everybody do an error check on this Client
// method. They can get the error that they're doing it wrong
// later, at client.Get/PostForm time.
return errorTransport{err}
}
return o.newTransportFromToken(token), nil
return hc.Transport
}
// NewTransport returns a Transport.
func (o *Options) NewTransport() *Transport {
return o.newTransportFromToken(nil)
// Client returns an HTTP client using the provided token.
// The token will auto-refresh as necessary. The underlying
// HTTP transport will be obtained using the provided context.
// The returned client and its Transport should not be modified.
func (c *Config) Client(ctx Context, t *Token) *http.Client {
return &http.Client{
Transport: &Transport{
Source: c.TokenSource(ctx, t),
Base: contextTransport(ctx),
},
}
}
// newTransportFromToken returns a new Transport that is authorized
// and authenticated with the provided token.
func (o *Options) newTransportFromToken(t *Token) *Transport {
// TODO(jbd): App Engine options initiate an http.Client that
// depends on the urlfetcher, but it breaks the promise we made
// that the options object should be working finely with nil-values
// for the http.Client.
tr := http.DefaultTransport
if o.Client != nil && o.Client.Transport != nil {
tr = o.Client.Transport
// TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context.
// See the the Context documentation.
//
// Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
nwn := &newWhenNeededSource{t: t}
nwn.new = tokenRefresher{
ctx: ctx,
conf: c,
oldToken: &nwn.t,
}
return newTransport(tr, o, t)
return nwn
}
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
// HTTP requests to renew a token using a RefreshToken.
type tokenRefresher struct {
ctx Context // used to get HTTP requests
conf *Config
oldToken **Token // pointer to old *Token w/ RefreshToken
}
func (tf tokenRefresher) Token() (*Token, error) {
t := *tf.oldToken
if t == nil {
return nil, errors.New("oauth2: attempted use of nil Token")
}
return retrieveToken(o, url.Values{
if t.RefreshToken == "" {
return nil, errors.New("oauth2: token expired and refresh token is not set")
}
return retrieveToken(tf.ctx, tf.conf, url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {t.RefreshToken},
})
}
}
// Options represents an object to keep the state of the OAuth 2.0 flow.
type Options struct {
// ClientID is the OAuth client identifier used when communicating with
// the configured OAuth provider.
ClientID string
// newWhenNeededSource is a TokenSource that holds a single token in memory
// and validates its expiry before each call to retrieve it with
// Token. If it's expired, it will be auto-refreshed using the
// new TokenSource.
//
// The first call to TokenRefresher must be SetToken.
type newWhenNeededSource struct {
new TokenSource // called when t is expired.
// ClientSecret is the OAuth client secret used when communicating with
// the configured OAuth provider.
ClientSecret string
// RedirectURL is the URL to which the user will be returned after
// granting (or denying) access.
RedirectURL string
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string
// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Scopes identify the level of access being requested.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
AuthURL *url.URL
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
TokenURL *url.URL
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
AUD *url.URL
// TokenStore reads a token from the store and writes it back to the store
// if a token refresh occurs.
// Optional.
TokenStore TokenStore
TokenFetcherFunc func(t *Token) (*Token, error)
Client *http.Client
mu sync.Mutex // guards t
t *Token
}
func retrieveToken(o *Options, v url.Values) (*Token, error) {
v.Set("client_id", o.ClientID)
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
if bustedAuth && o.ClientSecret != "" {
v.Set("client_secret", o.ClientSecret)
// Token returns the current token if it's still valid, else will
// refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *newWhenNeededSource) Token() (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t != nil && !s.t.Expired() {
return s.t, nil
}
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
t, err := s.new.Token()
if err != nil {
return nil, err
}
s.t = t
return t, nil
}
func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
hc, err := contextClient(ctx)
if err != nil {
return nil, err
}
v.Set("client_id", c.ClientID)
bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL)
if bustedAuth && c.ClientSecret != "" {
v.Set("client_secret", c.ClientSecret)
}
req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth && o.ClientSecret != "" {
req.SetBasicAuth(o.ClientID, o.ClientSecret)
if !bustedAuth && c.ClientSecret != "" {
req.SetBasicAuth(c.ClientID, c.ClientSecret)
}
c := o.Client
if c == nil {
c = &http.Client{}
}
r, err := c.Do(req)
r, err := hc.Do(req)
if err != nil {
return nil, err
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
@ -314,7 +377,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
}
token := &Token{}
expires := int(0)
expires := 0
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content {
case "application/x-www-form-urlencoded", "text/plain":
@ -335,7 +398,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
}
expires, _ = strconv.Atoi(e)
default:
b := make(map[string]interface{})
b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
if err = json.Unmarshal(body, &b); err != nil {
return nil, err
}
@ -397,3 +460,12 @@ func providerAuthHeaderWorks(tokenURL string) bool {
// to this package to let users select server bugs.
return true
}
// HTTPClient is the context key to use with golang.org/x/net/context's
// WithValue function to associate an *http.Client value with a context.
var HTTPClient contextKey
// contextKey is just an empty struct. It exists so HTTPClient can be
// an immutable public variable with a unique type. It's immutable
// because nobody else can create a contextKey, being unexported.
type contextKey struct{}

View File

@ -5,136 +5,90 @@
package oauth2
import (
"errors"
"io"
"net/http"
"net/url"
"sync"
"time"
)
const (
defaultTokenType = "Bearer"
)
// Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
type Token struct {
// AccessToken is the token that authorizes and authenticates the requests.
AccessToken string `json:"access_token"`
// TokenType identifies the type of token returned.
TokenType string `json:"token_type,omitempty"`
// RefreshToken is a token that may be used to obtain a new access token.
RefreshToken string `json:"refresh_token,omitempty"`
// Expiry is the expiration datetime of the access token.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Extra returns an extra field returned from the server during token retrieval.
// E.g.
// idToken := token.Extra("id_token")
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied Sources.
//
func (t *Token) Extra(key string) string {
if vals, ok := t.raw.(url.Values); ok {
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
return val
}
}
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
// Transport is a low-level mechanism. Most code will use the
// higher-level Config.Client method instead.
type Transport struct {
opts *Options
base http.RoundTripper
// Source supplies the token to add to outgoing requests'
// Authorization headers.
Source TokenSource
mu sync.RWMutex
token *Token
}
// Base is the base RoundTripper used to make HTTP requests.
// If nil, http.DefaultTransport is used.
Base http.RoundTripper
// NewTransport creates a new Transport that uses the provided
// token fetcher as token retrieving strategy. It authenticates
// the requests and delegates origTransport to make the actual requests.
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
return &Transport{
base: base,
opts: opts,
token: token,
}
mu sync.Mutex // guards modReq
modReq map[*http.Request]*http.Request // original -> modified
}
// RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.token
if token == nil || token.Expired() {
// Check if the token is refreshable.
// If token is refreshable, don't return an error,
// rather refresh.
if err := t.refreshToken(); err != nil {
if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
token, err := t.Source.Token()
if err != nil {
return nil, err
}
token = t.token
if t.opts.TokenStore != nil {
t.opts.TokenStore.WriteToken(token)
}
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
typ := token.TokenType
if typ == "" {
typ = defaultTokenType
req2 := cloneRequest(req) // per RoundTripper contract
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
if err != nil {
t.setModReq(req, nil)
return nil, err
}
req.Header.Set("Authorization", typ+" "+token.AccessToken)
return t.base.RoundTrip(req)
res.Body = &onEOFReader{
rc: res.Body,
fn: func() { t.setModReq(req, nil) },
}
return res, nil
}
// Token returns the token that authorizes and
// authenticates the transport.
func (t *Transport) Token() *Token {
t.mu.RLock()
defer t.mu.RUnlock()
return t.token
// CancelRequest cancels an in-flight request by closing its connection.
func (t *Transport) CancelRequest(req *http.Request) {
type canceler interface {
CancelRequest(*http.Request)
}
if cr, ok := t.Base.(canceler); ok {
t.mu.Lock()
modReq := t.modReq[req]
delete(t.modReq, req)
t.mu.Unlock()
cr.CancelRequest(modReq)
}
}
// refreshToken retrieves a new token, if a refreshing/fetching
// method is known and required credentials are presented
// (such as a refresh token).
func (t *Transport) refreshToken() error {
func (t *Transport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func (t *Transport) setModReq(orig, mod *http.Request) {
t.mu.Lock()
defer t.mu.Unlock()
token, err := t.opts.TokenFetcherFunc(t.token)
if err != nil {
return err
if t.modReq == nil {
t.modReq = make(map[*http.Request]*http.Request)
}
if mod == nil {
delete(t.modReq, orig)
} else {
t.modReq[orig] = mod
}
t.token = token
return nil
}
// cloneRequest returns a clone of the provided *http.Request.
@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = s
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
type onEOFReader struct {
rc io.ReadCloser
fn func()
}
func (r *onEOFReader) Read(p []byte) (n int, err error) {
n, err = r.rc.Read(p)
if err == io.EOF {
r.runFunc()
}
return
}
func (r *onEOFReader) Close() error {
err := r.rc.Close()
r.runFunc()
return err
}
func (r *onEOFReader) runFunc() {
if fn := r.fn; fn != nil {
fn()
r.fn = nil
}
}
type errorTransport struct{ err error }
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, t.err
}