forked from Mirrors/oauth2
oauth2: redesign the API
Tests and examples aren't updated yet. The tree will be broken after this, but nobody should be using this yet anyway. Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f Reviewed-on: https://go-review.googlesource.com/1246 Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
parent
b3f9a68f05
commit
a568078818
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2014 The oauth2 Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build appengine,!appenginevm
|
||||
|
||||
// App Engine hooks.
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"appengine"
|
||||
"appengine/urlfetch"
|
||||
)
|
||||
|
||||
var warnOnce sync.Once
|
||||
|
||||
func init() {
|
||||
registerContextClientFunc(contextClientAppEngine)
|
||||
}
|
||||
|
||||
func contextClientAppEngine(ctx Context) (*http.Client, error) {
|
||||
if actx, ok := ctx.(appengine.Context); ok {
|
||||
return urlfetch.NewClient(actx), nil
|
||||
}
|
||||
// The user did it wrong. We'll log once (and hope they see it
|
||||
// in dev_appserver), but stil return (nil, nil) in case some
|
||||
// other contextClientFunc hook finds a way to proceed.
|
||||
warnOnce.Do(gaeDoingItWrongHelp)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func gaeDoingItWrongHelp() {
|
||||
log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.")
|
||||
}
|
195
jwt.go
195
jwt.go
|
@ -5,15 +5,16 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2/internal"
|
||||
"golang.org/x/oauth2/jws"
|
||||
)
|
||||
|
||||
|
@ -22,100 +23,126 @@ var (
|
|||
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
||||
)
|
||||
|
||||
// JWTClient requires OAuth 2.0 JWT credentials.
|
||||
// Required for the 2-legged JWT flow.
|
||||
func JWTClient(email string, key []byte) Option {
|
||||
return func(o *Options) error {
|
||||
pk, err := internal.ParseKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.Email = email
|
||||
o.PrivateKey = pk
|
||||
return nil
|
||||
// JWTConfig is the configuration for using JWT to fetch tokens,
|
||||
// commonly known as "two-legged OAuth".
|
||||
type JWTConfig struct {
|
||||
// Email is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
Email string
|
||||
|
||||
// PrivateKey contains the contents of an RSA private key or the
|
||||
// contents of a PEM file that contains a private key. The provided
|
||||
// private key is used to sign JWT payloads.
|
||||
// PEM containers with a passphrase are not supported.
|
||||
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||
//
|
||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||
//
|
||||
PrivateKey *rsa.PrivateKey
|
||||
|
||||
// Subject is the optional user to impersonate.
|
||||
Subject string
|
||||
|
||||
// Scopes optionally specifies a list of requested permission scopes.
|
||||
Scopes []string
|
||||
|
||||
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
|
||||
TokenURL string
|
||||
}
|
||||
|
||||
// TokenSource returns a JWT TokenSource using the configuration
|
||||
// in c and the HTTP client from the provided context.
|
||||
//
|
||||
// The returned TokenSource only does JWT requests when necessary but
|
||||
// otherwise returns the same token repeatedly until it expires.
|
||||
//
|
||||
// The provided initialToken may be nil, in which case the first
|
||||
// call to TokenSource will do a new JWT request.
|
||||
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
|
||||
return &newWhenNeededSource{
|
||||
t: initialToken,
|
||||
new: jwtSource{ctx, c},
|
||||
}
|
||||
}
|
||||
|
||||
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
|
||||
func JWTEndpoint(aud string) Option {
|
||||
return func(o *Options) error {
|
||||
au, err := url.Parse(aud)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.AUD = au
|
||||
return nil
|
||||
// Client returns an HTTP client wrapping the context's
|
||||
// HTTP transport and adding Authorization headers with tokens
|
||||
// obtained from c.
|
||||
//
|
||||
// The provided initialToken may be nil, in which case the first
|
||||
// call to TokenSource will do a new JWT request.
|
||||
//
|
||||
// The returned client and its Transport should not be modified.
|
||||
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &Transport{
|
||||
Source: c.TokenSource(ctx, initialToken),
|
||||
Base: contextTransport(ctx),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Subject requires a user to impersonate.
|
||||
// Optional.
|
||||
func Subject(user string) Option {
|
||||
return func(o *Options) error {
|
||||
o.Subject = user
|
||||
return nil
|
||||
}
|
||||
// jwtSource is a source that always does a signed JWT request for a token.
|
||||
// It should typically be wrapped with a newWhenNeededSource.
|
||||
type jwtSource struct {
|
||||
ctx Context
|
||||
conf *JWTConfig
|
||||
}
|
||||
|
||||
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||
return func(t *Token) (*Token, error) {
|
||||
if t == nil {
|
||||
t = &Token{}
|
||||
}
|
||||
claimSet := &jws.ClaimSet{
|
||||
Iss: o.Email,
|
||||
Scope: strings.Join(o.Scopes, " "),
|
||||
Aud: o.AUD.String(),
|
||||
}
|
||||
if o.Subject != "" {
|
||||
claimSet.Sub = o.Subject
|
||||
// prn is the old name of sub. Keep setting it
|
||||
// to be compatible with legacy OAuth 2.0 providers.
|
||||
claimSet.Prn = o.Subject
|
||||
}
|
||||
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("grant_type", defaultGrantType)
|
||||
v.Set("assertion", payload)
|
||||
c := o.Client
|
||||
if c == nil {
|
||||
c = &http.Client{}
|
||||
}
|
||||
resp, err := c.PostForm(o.AUD.String(), v)
|
||||
func (js jwtSource) Token() (*Token, error) {
|
||||
hc, err := contextClient(js.ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimSet := &jws.ClaimSet{
|
||||
Iss: js.conf.Email,
|
||||
Scope: strings.Join(js.conf.Scopes, " "),
|
||||
Aud: js.conf.TokenURL,
|
||||
}
|
||||
if subject := js.conf.Subject; subject != "" {
|
||||
claimSet.Sub = subject
|
||||
// prn is the old name of sub. Keep setting it
|
||||
// to be compatible with legacy OAuth 2.0 providers.
|
||||
claimSet.Prn = subject
|
||||
}
|
||||
payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := url.Values{}
|
||||
v.Set("grant_type", defaultGrantType)
|
||||
v.Set("assertion", payload)
|
||||
resp, err := hc.PostForm(js.conf.TokenURL, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
if c := resp.StatusCode; c < 200 || c > 299 {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
||||
}
|
||||
b := make(map[string]interface{})
|
||||
if err := json.Unmarshal(body, &b); err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
token := &Token{}
|
||||
token.AccessToken, _ = b["access_token"].(string)
|
||||
token.TokenType, _ = b["token_type"].(string)
|
||||
token.raw = b
|
||||
if e, ok := b["expires_in"].(int); ok {
|
||||
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
||||
}
|
||||
if idtoken, ok := b["id_token"].(string); ok {
|
||||
// decode returned id token to get expiry
|
||||
claimSet, err := jws.Decode(idtoken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
if c := resp.StatusCode; c < 200 || c > 299 {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
||||
}
|
||||
b := make(map[string]interface{})
|
||||
if err := json.Unmarshal(body, &b); err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
token := &Token{}
|
||||
token.AccessToken, _ = b["access_token"].(string)
|
||||
token.TokenType, _ = b["token_type"].(string)
|
||||
token.raw = b
|
||||
if e, ok := b["expires_in"].(int); ok {
|
||||
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
||||
}
|
||||
if idtoken, ok := b["id_token"].(string); ok {
|
||||
// decode returned id token to get expiry
|
||||
claimSet, err := jws.Decode(idtoken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
token.Expiry = time.Unix(claimSet.Exp, 0)
|
||||
return token, nil
|
||||
}
|
||||
token.Expiry = time.Unix(claimSet.Exp, 0)
|
||||
return token, nil
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
|
552
oauth2.go
552
oauth2.go
|
@ -8,115 +8,168 @@
|
|||
package oauth2 // import "golang.org/x/oauth2"
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer.
|
||||
type TokenStore interface {
|
||||
// ReadToken reads the token from the store.
|
||||
// If the read is successful, it should return the token and a nil error.
|
||||
// The returned tokens may be expired tokens.
|
||||
// If there is no token in the store, it should return a nil token and a nil error.
|
||||
// It should return a non-nil error when an unrecoverable failure occurs.
|
||||
ReadToken() (*Token, error)
|
||||
// WriteToken writes the token to the cache.
|
||||
WriteToken(*Token)
|
||||
// Context can be an golang.org/x/net.Context, or an App Engine Context.
|
||||
// In the future these will be unified.
|
||||
// If you don't care and aren't running on App Engine, you may use nil.
|
||||
type Context interface{}
|
||||
|
||||
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||
// client application information and the server's URLs.
|
||||
type Config struct {
|
||||
// ClientID is the application's ID.
|
||||
ClientID string
|
||||
|
||||
// ClientSecret is the application's secret.
|
||||
ClientSecret string
|
||||
|
||||
// Endpoint contains the resource server's token endpoint
|
||||
// URLs. These are supplied by the server and are often
|
||||
// available via site-specific packages (for example,
|
||||
// google.Endpoint or github.Endpoint)
|
||||
Endpoint Endpoint
|
||||
|
||||
// RedirectURL is the URL to redirect users going through
|
||||
// the OAuth flow, after the resource owner's URLs.
|
||||
RedirectURL string
|
||||
|
||||
// Scope specifies optional requested permissions.
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
// Option represents a function that applies some state to
|
||||
// an Options object.
|
||||
type Option func(*Options) error
|
||||
// A TokenSource is anything that can return a token.
|
||||
type TokenSource interface {
|
||||
// Token returns a token or an error.
|
||||
Token() (*Token, error)
|
||||
}
|
||||
|
||||
// Client requires the OAuth 2.0 client credentials. You need to provide
|
||||
// the client identifier and optionally the client secret that are
|
||||
// assigned to your application by the OAuth 2.0 provider.
|
||||
func Client(id, secret string) Option {
|
||||
return func(opts *Options) error {
|
||||
opts.ClientID = id
|
||||
opts.ClientSecret = secret
|
||||
return nil
|
||||
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
||||
// endpoint URLs.
|
||||
type Endpoint struct {
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
}
|
||||
|
||||
// Token represents the crendentials used to authorize
|
||||
// the requests to access protected resources on the OAuth 2.0
|
||||
// provider's backend.
|
||||
//
|
||||
// Most users of this package should not access fields of Token
|
||||
// directly. They're exported mostly for use by related packages
|
||||
// implementing derivate OAuth2 flows.
|
||||
type Token struct {
|
||||
// AccessToken is the token that authorizes and authenticates
|
||||
// the requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// TokenType is the type of token.
|
||||
// The Type method returns either this or "Bearer", the default.
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
|
||||
// RefreshToken is a token that's used by the application
|
||||
// (as opposed to the user) to refresh the access token
|
||||
// if it expires.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
// Expiry is the optional expiration time of the access token.
|
||||
//
|
||||
// If zero, TokenSource implementations will reuse the same
|
||||
// token forever and RefreshToken or equivalent
|
||||
// mechanisms for that TokenSource will not be used.
|
||||
Expiry time.Time `json:"expiry,omitempty"`
|
||||
|
||||
// raw optionally contains extra metadata from the server
|
||||
// when updating a token.
|
||||
raw interface{}
|
||||
}
|
||||
|
||||
// Type returns t.TokenType if non-empty, else "Bearer".
|
||||
func (t *Token) Type() string {
|
||||
if t.TokenType != "" {
|
||||
return t.TokenType
|
||||
}
|
||||
return "Bearer"
|
||||
}
|
||||
|
||||
// RedirectURL requires the URL to which the user will be returned after
|
||||
// granting (or denying) access.
|
||||
func RedirectURL(url string) Option {
|
||||
return func(opts *Options) error {
|
||||
opts.RedirectURL = url
|
||||
return nil
|
||||
// SetAuthHeader sets the Authorization header to r using the access
|
||||
// token in t.
|
||||
//
|
||||
// This method is unnecessary when using Transport or an HTTP Client
|
||||
// returned by this package.
|
||||
func (t *Token) SetAuthHeader(r *http.Request) {
|
||||
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
|
||||
}
|
||||
|
||||
// Extra returns an extra field returned from the server during token
|
||||
// retrieval.
|
||||
func (t *Token) Extra(key string) string {
|
||||
if vals, ok := t.raw.(url.Values); ok {
|
||||
return vals.Get(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Scope requires a list of requested permission scopes.
|
||||
// It is optinal to specify scopes.
|
||||
func Scope(scopes ...string) Option {
|
||||
return func(o *Options) error {
|
||||
o.Scopes = scopes
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
|
||||
func Endpoint(authURL, tokenURL string) Option {
|
||||
return func(o *Options) error {
|
||||
au, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tu, err := url.Parse(tokenURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.AuthURL = au
|
||||
o.TokenURL = tu
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient allows you to provide a custom http.Client to be
|
||||
// used to retrieve tokens from the OAuth 2.0 provider.
|
||||
func HTTPClient(c *http.Client) Option {
|
||||
return func(o *Options) error {
|
||||
o.Client = c
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// New builds a new options object and determines the type of the OAuth 2.0
|
||||
// (2-legged, 3-legged or custom) by looking at the provided options.
|
||||
// If the flow type cannot determined automatically, an error is returned.
|
||||
func New(option ...Option) (*Options, error) {
|
||||
opts := &Options{}
|
||||
for _, fn := range option {
|
||||
if err := fn(opts); err != nil {
|
||||
return nil, err
|
||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||
if val, ok := raw[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case opts.TokenFetcherFunc != nil:
|
||||
return opts, nil
|
||||
case opts.AUD != nil:
|
||||
// TODO(jbd): Assert the required JWT params.
|
||||
opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts)
|
||||
return opts, nil
|
||||
case opts.AuthURL != nil && opts.TokenURL != nil:
|
||||
// TODO(jbd): Assert the required OAuth2 params.
|
||||
opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts)
|
||||
return opts, nil
|
||||
default:
|
||||
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
|
||||
return ""
|
||||
}
|
||||
|
||||
// Expired returns true if there is no access token or the
|
||||
// access token is expired.
|
||||
func (t *Token) Expired() bool {
|
||||
if t.AccessToken == "" {
|
||||
return true
|
||||
}
|
||||
if t.Expiry.IsZero() {
|
||||
return false
|
||||
}
|
||||
return t.Expiry.Before(time.Now())
|
||||
}
|
||||
|
||||
var (
|
||||
// AccessTypeOnline and AccessTypeOffline are options passed
|
||||
// to the Options.AuthCodeURL method. They modify the
|
||||
// "access_type" field that gets sent in the URL returned by
|
||||
// AuthCodeURL.
|
||||
//
|
||||
// Online (the default if neither is specified) is the default.
|
||||
// If your application needs to refresh access tokens when the
|
||||
// user is not present at the browser, then use offline. This
|
||||
// will result in your application obtaining a refresh token
|
||||
// the first time your application exchanges an authorization
|
||||
// code for a user.
|
||||
AccessTypeOnline AuthCodeOption = setParam{"access_type", "online"}
|
||||
AccessTypeOffline AuthCodeOption = setParam{"access_type", "offline"}
|
||||
|
||||
// ApprovalForce forces the users to view the consent dialog
|
||||
// and confirm the permissions request at the URL returned
|
||||
// from AuthCodeURL, even if they've already done so.
|
||||
ApprovalForce AuthCodeOption = setParam{"approval_prompt", "force"}
|
||||
)
|
||||
|
||||
type setParam struct{ k, v string }
|
||||
|
||||
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
|
||||
|
||||
// An AuthCodeOption is passed to Config.AuthCodeURL.
|
||||
type AuthCodeOption interface {
|
||||
setValue(url.Values)
|
||||
}
|
||||
|
||||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||
|
@ -127,185 +180,195 @@ func New(option ...Option) (*Options, error) {
|
|||
// the state query parameter on your redirect callback.
|
||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
//
|
||||
// Access type is an OAuth extension that gets sent as the
|
||||
// "access_type" field in the URL from AuthCodeURL.
|
||||
// It may be "online" (default) or "offline".
|
||||
// If your application needs to refresh access tokens when the
|
||||
// user is not present at the browser, then use offline. This
|
||||
// will result in your application obtaining a refresh token
|
||||
// the first time your application exchanges an authorization
|
||||
// code for a user.
|
||||
//
|
||||
// Approval prompt indicates whether the user should be
|
||||
// re-prompted for consent. If set to "auto" (default) the
|
||||
// user will be prompted only if they haven't previously
|
||||
// granted consent and the code can only be exchanged for an
|
||||
// access token. If set to "force" the user will always be prompted,
|
||||
// and the code can be exchanged for a refresh token.
|
||||
func (o *Options) AuthCodeURL(state, accessType, prompt string) string {
|
||||
u := *o.AuthURL
|
||||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||
// as ApprovalForce.
|
||||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(c.Endpoint.AuthURL)
|
||||
v := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {o.ClientID},
|
||||
"redirect_uri": condVal(o.RedirectURL),
|
||||
"scope": condVal(strings.Join(o.Scopes, " ")),
|
||||
"state": condVal(state),
|
||||
"access_type": condVal(accessType),
|
||||
"approval_prompt": condVal(prompt),
|
||||
"response_type": {"code"},
|
||||
"client_id": {c.ClientID},
|
||||
"redirect_uri": condVal(c.RedirectURL),
|
||||
"scope": condVal(strings.Join(c.Scopes, " ")),
|
||||
"state": condVal(state),
|
||||
}
|
||||
q := v.Encode()
|
||||
if u.RawQuery == "" {
|
||||
u.RawQuery = q
|
||||
for _, opt := range opts {
|
||||
opt.setValue(v)
|
||||
}
|
||||
if strings.Contains(c.Endpoint.AuthURL, "?") {
|
||||
buf.WriteByte('&')
|
||||
} else {
|
||||
u.RawQuery += "&" + q
|
||||
buf.WriteByte('?')
|
||||
}
|
||||
return u.String()
|
||||
buf.WriteString(v.Encode())
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// exchange exchanges the authorization code with the OAuth 2.0 provider
|
||||
// to retrieve a new access token.
|
||||
func (o *Options) exchange(code string) (*Token, error) {
|
||||
return retrieveToken(o, url.Values{
|
||||
// Exchange converts an authorization code into a token.
|
||||
//
|
||||
// It is used after a resource provider redirects the user back
|
||||
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
||||
//
|
||||
// The HTTP client to use is derived from the context. If nil,
|
||||
// http.DefaultClient is used. See the Context type's documentation.
|
||||
//
|
||||
// The code will be in the *http.Request.FormValue("code"). Before
|
||||
// calling Exchange, be sure to validate FormValue("state").
|
||||
func (c *Config) Exchange(ctx Context, code string) (*Token, error) {
|
||||
return retrieveToken(ctx, c, url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"redirect_uri": condVal(o.RedirectURL),
|
||||
"scope": condVal(strings.Join(o.Scopes, " ")),
|
||||
"redirect_uri": condVal(c.RedirectURL),
|
||||
"scope": condVal(strings.Join(c.Scopes, " ")),
|
||||
})
|
||||
}
|
||||
|
||||
// NewTransportFromTokenStore reads the token from the store and returns
|
||||
// a Transport that is authorized and the authenticated
|
||||
// by the returned token.
|
||||
func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) {
|
||||
tok, err := store.ReadToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.TokenStore = store
|
||||
if tok == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return o.newTransportFromToken(tok), nil
|
||||
// contextClientFunc is a func which tries to return an *http.Client
|
||||
// given a Context value. If it returns an error, the search stops
|
||||
// with that error. If it returns (nil, nil), the search continues
|
||||
// down the list of registered funcs.
|
||||
type contextClientFunc func(Context) (*http.Client, error)
|
||||
|
||||
var contextClientFuncs []contextClientFunc
|
||||
|
||||
func registerContextClietnFunc(fn contextClientFunc) {
|
||||
contextClientFuncs = append(contextClientFuncs, fn)
|
||||
}
|
||||
|
||||
// NewTransportFromCode exchanges the code to retrieve a new access token
|
||||
// and returns an authorized and authenticated Transport.
|
||||
func (o *Options) NewTransportFromCode(code string) (*Transport, error) {
|
||||
token, err := o.exchange(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.newTransportFromToken(token), nil
|
||||
}
|
||||
|
||||
// NewTransport returns a Transport.
|
||||
func (o *Options) NewTransport() *Transport {
|
||||
return o.newTransportFromToken(nil)
|
||||
}
|
||||
|
||||
// newTransportFromToken returns a new Transport that is authorized
|
||||
// and authenticated with the provided token.
|
||||
func (o *Options) newTransportFromToken(t *Token) *Transport {
|
||||
// TODO(jbd): App Engine options initiate an http.Client that
|
||||
// depends on the urlfetcher, but it breaks the promise we made
|
||||
// that the options object should be working finely with nil-values
|
||||
// for the http.Client.
|
||||
tr := http.DefaultTransport
|
||||
if o.Client != nil && o.Client.Transport != nil {
|
||||
tr = o.Client.Transport
|
||||
}
|
||||
return newTransport(tr, o, t)
|
||||
}
|
||||
|
||||
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||
return func(t *Token) (*Token, error) {
|
||||
if t == nil || t.RefreshToken == "" {
|
||||
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
||||
func contextClient(ctx Context) (*http.Client, error) {
|
||||
for _, fn := range contextClientFuncs {
|
||||
c, err := fn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return retrieveToken(o, url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {t.RefreshToken},
|
||||
})
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
if xc, ok := ctx.(context.Context); ok {
|
||||
if hc, ok := xc.Value(HTTPClient).(*http.Client); ok {
|
||||
return hc, nil
|
||||
}
|
||||
}
|
||||
return http.DefaultClient, nil
|
||||
}
|
||||
|
||||
func contextTransport(ctx Context) http.RoundTripper {
|
||||
hc, err := contextClient(ctx)
|
||||
if err != nil {
|
||||
// This is a rare error case (somebody using nil on App Engine),
|
||||
// so I'd rather not everybody do an error check on this Client
|
||||
// method. They can get the error that they're doing it wrong
|
||||
// later, at client.Get/PostForm time.
|
||||
return errorTransport{err}
|
||||
}
|
||||
return hc.Transport
|
||||
}
|
||||
|
||||
// Client returns an HTTP client using the provided token.
|
||||
// The token will auto-refresh as necessary. The underlying
|
||||
// HTTP transport will be obtained using the provided context.
|
||||
// The returned client and its Transport should not be modified.
|
||||
func (c *Config) Client(ctx Context, t *Token) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &Transport{
|
||||
Source: c.TokenSource(ctx, t),
|
||||
Base: contextTransport(ctx),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Options represents an object to keep the state of the OAuth 2.0 flow.
|
||||
type Options struct {
|
||||
// ClientID is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientID string
|
||||
|
||||
// ClientSecret is the OAuth client secret used when communicating with
|
||||
// the configured OAuth provider.
|
||||
ClientSecret string
|
||||
|
||||
// RedirectURL is the URL to which the user will be returned after
|
||||
// granting (or denying) access.
|
||||
RedirectURL string
|
||||
|
||||
// Email is the OAuth client identifier used when communicating with
|
||||
// the configured OAuth provider.
|
||||
Email string
|
||||
|
||||
// PrivateKey contains the contents of an RSA private key or the
|
||||
// contents of a PEM file that contains a private key. The provided
|
||||
// private key is used to sign JWT payloads.
|
||||
// PEM containers with a passphrase are not supported.
|
||||
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||
//
|
||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||
//
|
||||
PrivateKey *rsa.PrivateKey
|
||||
|
||||
// Scopes identify the level of access being requested.
|
||||
Subject string
|
||||
|
||||
// Scopes optionally specifies a list of requested permission scopes.
|
||||
Scopes []string
|
||||
|
||||
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
|
||||
AuthURL *url.URL
|
||||
|
||||
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
|
||||
TokenURL *url.URL
|
||||
|
||||
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
||||
AUD *url.URL
|
||||
|
||||
// TokenStore reads a token from the store and writes it back to the store
|
||||
// if a token refresh occurs.
|
||||
// Optional.
|
||||
TokenStore TokenStore
|
||||
|
||||
TokenFetcherFunc func(t *Token) (*Token, error)
|
||||
|
||||
Client *http.Client
|
||||
// TokenSource returns a TokenSource that returns t until t expires,
|
||||
// automatically refreshing it as necessary using the provided context.
|
||||
// See the the Context documentation.
|
||||
//
|
||||
// Most users will use Config.Client instead.
|
||||
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
|
||||
nwn := &newWhenNeededSource{t: t}
|
||||
nwn.new = tokenRefresher{
|
||||
ctx: ctx,
|
||||
conf: c,
|
||||
oldToken: &nwn.t,
|
||||
}
|
||||
return nwn
|
||||
}
|
||||
|
||||
func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
||||
v.Set("client_id", o.ClientID)
|
||||
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
|
||||
if bustedAuth && o.ClientSecret != "" {
|
||||
v.Set("client_secret", o.ClientSecret)
|
||||
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
||||
// HTTP requests to renew a token using a RefreshToken.
|
||||
type tokenRefresher struct {
|
||||
ctx Context // used to get HTTP requests
|
||||
conf *Config
|
||||
oldToken **Token // pointer to old *Token w/ RefreshToken
|
||||
}
|
||||
|
||||
func (tf tokenRefresher) Token() (*Token, error) {
|
||||
t := *tf.oldToken
|
||||
if t == nil {
|
||||
return nil, errors.New("oauth2: attempted use of nil Token")
|
||||
}
|
||||
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
|
||||
if t.RefreshToken == "" {
|
||||
return nil, errors.New("oauth2: token expired and refresh token is not set")
|
||||
}
|
||||
return retrieveToken(tf.ctx, tf.conf, url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {t.RefreshToken},
|
||||
})
|
||||
}
|
||||
|
||||
// newWhenNeededSource is a TokenSource that holds a single token in memory
|
||||
// and validates its expiry before each call to retrieve it with
|
||||
// Token. If it's expired, it will be auto-refreshed using the
|
||||
// new TokenSource.
|
||||
//
|
||||
// The first call to TokenRefresher must be SetToken.
|
||||
type newWhenNeededSource struct {
|
||||
new TokenSource // called when t is expired.
|
||||
|
||||
mu sync.Mutex // guards t
|
||||
t *Token
|
||||
}
|
||||
|
||||
// Token returns the current token if it's still valid, else will
|
||||
// refresh the current token (using r.Context for HTTP client
|
||||
// information) and return the new one.
|
||||
func (s *newWhenNeededSource) Token() (*Token, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.t != nil && !s.t.Expired() {
|
||||
return s.t, nil
|
||||
}
|
||||
t, err := s.new.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.t = t
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
|
||||
hc, err := contextClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.Set("client_id", c.ClientID)
|
||||
bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL)
|
||||
if bustedAuth && c.ClientSecret != "" {
|
||||
v.Set("client_secret", c.ClientSecret)
|
||||
}
|
||||
req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if !bustedAuth && o.ClientSecret != "" {
|
||||
req.SetBasicAuth(o.ClientID, o.ClientSecret)
|
||||
if !bustedAuth && c.ClientSecret != "" {
|
||||
req.SetBasicAuth(c.ClientID, c.ClientSecret)
|
||||
}
|
||||
c := o.Client
|
||||
if c == nil {
|
||||
c = &http.Client{}
|
||||
}
|
||||
r, err := c.Do(req)
|
||||
r, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Body.Close()
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
|
@ -314,7 +377,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
|||
}
|
||||
|
||||
token := &Token{}
|
||||
expires := int(0)
|
||||
expires := 0
|
||||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
switch content {
|
||||
case "application/x-www-form-urlencoded", "text/plain":
|
||||
|
@ -335,7 +398,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
|||
}
|
||||
expires, _ = strconv.Atoi(e)
|
||||
default:
|
||||
b := make(map[string]interface{})
|
||||
b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
|
||||
if err = json.Unmarshal(body, &b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -397,3 +460,12 @@ func providerAuthHeaderWorks(tokenURL string) bool {
|
|||
// to this package to let users select server bugs.
|
||||
return true
|
||||
}
|
||||
|
||||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||
// WithValue function to associate an *http.Client value with a context.
|
||||
var HTTPClient contextKey
|
||||
|
||||
// contextKey is just an empty struct. It exists so HTTPClient can be
|
||||
// an immutable public variable with a unique type. It's immutable
|
||||
// because nobody else can create a contextKey, being unexported.
|
||||
type contextKey struct{}
|
||||
|
|
202
transport.go
202
transport.go
|
@ -5,136 +5,90 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTokenType = "Bearer"
|
||||
)
|
||||
|
||||
// Token represents the crendentials used to authorize
|
||||
// the requests to access protected resources on the OAuth 2.0
|
||||
// provider's backend.
|
||||
type Token struct {
|
||||
// AccessToken is the token that authorizes and authenticates the requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// TokenType identifies the type of token returned.
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
|
||||
// RefreshToken is a token that may be used to obtain a new access token.
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
// Expiry is the expiration datetime of the access token.
|
||||
Expiry time.Time `json:"expiry,omitempty"`
|
||||
|
||||
// raw optionally contains extra metadata from the server
|
||||
// when updating a token.
|
||||
raw interface{}
|
||||
}
|
||||
|
||||
// Extra returns an extra field returned from the server during token retrieval.
|
||||
// E.g.
|
||||
// idToken := token.Extra("id_token")
|
||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
||||
// wrapping a base RoundTripper and adding an Authorization header
|
||||
// with a token from the supplied Sources.
|
||||
//
|
||||
func (t *Token) Extra(key string) string {
|
||||
if vals, ok := t.raw.(url.Values); ok {
|
||||
return vals.Get(key)
|
||||
}
|
||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||
if val, ok := raw[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Expired returns true if there is no access token or the
|
||||
// access token is expired.
|
||||
func (t *Token) Expired() bool {
|
||||
if t.AccessToken == "" {
|
||||
return true
|
||||
}
|
||||
if t.Expiry.IsZero() {
|
||||
return false
|
||||
}
|
||||
return t.Expiry.Before(time.Now())
|
||||
}
|
||||
|
||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
||||
// Transport is a low-level mechanism. Most code will use the
|
||||
// higher-level Config.Client method instead.
|
||||
type Transport struct {
|
||||
opts *Options
|
||||
base http.RoundTripper
|
||||
// Source supplies the token to add to outgoing requests'
|
||||
// Authorization headers.
|
||||
Source TokenSource
|
||||
|
||||
mu sync.RWMutex
|
||||
token *Token
|
||||
}
|
||||
// Base is the base RoundTripper used to make HTTP requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Base http.RoundTripper
|
||||
|
||||
// NewTransport creates a new Transport that uses the provided
|
||||
// token fetcher as token retrieving strategy. It authenticates
|
||||
// the requests and delegates origTransport to make the actual requests.
|
||||
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
|
||||
return &Transport{
|
||||
base: base,
|
||||
opts: opts,
|
||||
token: token,
|
||||
}
|
||||
mu sync.Mutex // guards modReq
|
||||
modReq map[*http.Request]*http.Request // original -> modified
|
||||
}
|
||||
|
||||
// RoundTrip authorizes and authenticates the request with an
|
||||
// access token. If no token exists or token is expired,
|
||||
// tries to refresh/fetch a new token.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := t.token
|
||||
|
||||
if token == nil || token.Expired() {
|
||||
// Check if the token is refreshable.
|
||||
// If token is refreshable, don't return an error,
|
||||
// rather refresh.
|
||||
if err := t.refreshToken(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token = t.token
|
||||
if t.opts.TokenStore != nil {
|
||||
t.opts.TokenStore.WriteToken(token)
|
||||
}
|
||||
if t.Source == nil {
|
||||
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||
}
|
||||
token, err := t.Source.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// To set the Authorization header, we must make a copy of the Request
|
||||
// so that we don't modify the Request we were given.
|
||||
// This is required by the specification of http.RoundTripper.
|
||||
req = cloneRequest(req)
|
||||
typ := token.TokenType
|
||||
if typ == "" {
|
||||
typ = defaultTokenType
|
||||
req2 := cloneRequest(req) // per RoundTripper contract
|
||||
token.SetAuthHeader(req2)
|
||||
t.setModReq(req, req2)
|
||||
res, err := t.base().RoundTrip(req2)
|
||||
if err != nil {
|
||||
t.setModReq(req, nil)
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", typ+" "+token.AccessToken)
|
||||
return t.base.RoundTrip(req)
|
||||
res.Body = &onEOFReader{
|
||||
rc: res.Body,
|
||||
fn: func() { t.setModReq(req, nil) },
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Token returns the token that authorizes and
|
||||
// authenticates the transport.
|
||||
func (t *Transport) Token() *Token {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.token
|
||||
// CancelRequest cancels an in-flight request by closing its connection.
|
||||
func (t *Transport) CancelRequest(req *http.Request) {
|
||||
type canceler interface {
|
||||
CancelRequest(*http.Request)
|
||||
}
|
||||
if cr, ok := t.Base.(canceler); ok {
|
||||
t.mu.Lock()
|
||||
modReq := t.modReq[req]
|
||||
delete(t.modReq, req)
|
||||
t.mu.Unlock()
|
||||
cr.CancelRequest(modReq)
|
||||
}
|
||||
}
|
||||
|
||||
// refreshToken retrieves a new token, if a refreshing/fetching
|
||||
// method is known and required credentials are presented
|
||||
// (such as a refresh token).
|
||||
func (t *Transport) refreshToken() error {
|
||||
func (t *Transport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
func (t *Transport) setModReq(orig, mod *http.Request) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
token, err := t.opts.TokenFetcherFunc(t.token)
|
||||
if err != nil {
|
||||
return err
|
||||
if t.modReq == nil {
|
||||
t.modReq = make(map[*http.Request]*http.Request)
|
||||
}
|
||||
if mod == nil {
|
||||
delete(t.modReq, orig)
|
||||
} else {
|
||||
t.modReq[orig] = mod
|
||||
}
|
||||
t.token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRequest returns a clone of the provided *http.Request.
|
||||
|
@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
|
|||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header)
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = s
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
return r2
|
||||
}
|
||||
|
||||
type onEOFReader struct {
|
||||
rc io.ReadCloser
|
||||
fn func()
|
||||
}
|
||||
|
||||
func (r *onEOFReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.rc.Read(p)
|
||||
if err == io.EOF {
|
||||
r.runFunc()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *onEOFReader) Close() error {
|
||||
err := r.rc.Close()
|
||||
r.runFunc()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *onEOFReader) runFunc() {
|
||||
if fn := r.fn; fn != nil {
|
||||
fn()
|
||||
r.fn = nil
|
||||
}
|
||||
}
|
||||
|
||||
type errorTransport struct{ err error }
|
||||
|
||||
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, t.err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue