forked from Mirrors/oauth2
oauth2: redesign the API
Tests and examples aren't updated yet. The tree will be broken after this, but nobody should be using this yet anyway. Change-Id: I0004c738f40919ab46d107c71c011c510fbc748f Reviewed-on: https://go-review.googlesource.com/1246 Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
parent
b3f9a68f05
commit
a568078818
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2014 The oauth2 Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build appengine,!appenginevm
|
||||||
|
|
||||||
|
// App Engine hooks.
|
||||||
|
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"appengine"
|
||||||
|
"appengine/urlfetch"
|
||||||
|
)
|
||||||
|
|
||||||
|
var warnOnce sync.Once
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registerContextClientFunc(contextClientAppEngine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClientAppEngine(ctx Context) (*http.Client, error) {
|
||||||
|
if actx, ok := ctx.(appengine.Context); ok {
|
||||||
|
return urlfetch.NewClient(actx), nil
|
||||||
|
}
|
||||||
|
// The user did it wrong. We'll log once (and hope they see it
|
||||||
|
// in dev_appserver), but stil return (nil, nil) in case some
|
||||||
|
// other contextClientFunc hook finds a way to proceed.
|
||||||
|
warnOnce.Do(gaeDoingItWrongHelp)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func gaeDoingItWrongHelp() {
|
||||||
|
log.Printf("WARNING: you attempted to use the oauth2 package without passing a valid appengine.Context or *http.Request as the oauth2.Context. App Engine requires that all service RPCs (including urlfetch) be associated with an *http.Request/appengine.Context.")
|
||||||
|
}
|
133
jwt.go
133
jwt.go
|
@ -5,15 +5,16 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2/internal"
|
|
||||||
"golang.org/x/oauth2/jws"
|
"golang.org/x/oauth2/jws"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,74 +23,101 @@ var (
|
||||||
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWTClient requires OAuth 2.0 JWT credentials.
|
// JWTConfig is the configuration for using JWT to fetch tokens,
|
||||||
// Required for the 2-legged JWT flow.
|
// commonly known as "two-legged OAuth".
|
||||||
func JWTClient(email string, key []byte) Option {
|
type JWTConfig struct {
|
||||||
return func(o *Options) error {
|
// Email is the OAuth client identifier used when communicating with
|
||||||
pk, err := internal.ParseKey(key)
|
// the configured OAuth provider.
|
||||||
|
Email string
|
||||||
|
|
||||||
|
// PrivateKey contains the contents of an RSA private key or the
|
||||||
|
// contents of a PEM file that contains a private key. The provided
|
||||||
|
// private key is used to sign JWT payloads.
|
||||||
|
// PEM containers with a passphrase are not supported.
|
||||||
|
// Use the following command to convert a PKCS 12 file into a PEM.
|
||||||
|
//
|
||||||
|
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
||||||
|
//
|
||||||
|
PrivateKey *rsa.PrivateKey
|
||||||
|
|
||||||
|
// Subject is the optional user to impersonate.
|
||||||
|
Subject string
|
||||||
|
|
||||||
|
// Scopes optionally specifies a list of requested permission scopes.
|
||||||
|
Scopes []string
|
||||||
|
|
||||||
|
// TokenURL is the endpoint required to complete the 2-legged JWT flow.
|
||||||
|
TokenURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenSource returns a JWT TokenSource using the configuration
|
||||||
|
// in c and the HTTP client from the provided context.
|
||||||
|
//
|
||||||
|
// The returned TokenSource only does JWT requests when necessary but
|
||||||
|
// otherwise returns the same token repeatedly until it expires.
|
||||||
|
//
|
||||||
|
// The provided initialToken may be nil, in which case the first
|
||||||
|
// call to TokenSource will do a new JWT request.
|
||||||
|
func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource {
|
||||||
|
return &newWhenNeededSource{
|
||||||
|
t: initialToken,
|
||||||
|
new: jwtSource{ctx, c},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns an HTTP client wrapping the context's
|
||||||
|
// HTTP transport and adding Authorization headers with tokens
|
||||||
|
// obtained from c.
|
||||||
|
//
|
||||||
|
// The provided initialToken may be nil, in which case the first
|
||||||
|
// call to TokenSource will do a new JWT request.
|
||||||
|
//
|
||||||
|
// The returned client and its Transport should not be modified.
|
||||||
|
func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &Transport{
|
||||||
|
Source: c.TokenSource(ctx, initialToken),
|
||||||
|
Base: contextTransport(ctx),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtSource is a source that always does a signed JWT request for a token.
|
||||||
|
// It should typically be wrapped with a newWhenNeededSource.
|
||||||
|
type jwtSource struct {
|
||||||
|
ctx Context
|
||||||
|
conf *JWTConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js jwtSource) Token() (*Token, error) {
|
||||||
|
hc, err := contextClient(js.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
|
||||||
o.Email = email
|
|
||||||
o.PrivateKey = pk
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
|
|
||||||
func JWTEndpoint(aud string) Option {
|
|
||||||
return func(o *Options) error {
|
|
||||||
au, err := url.Parse(aud)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
o.AUD = au
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subject requires a user to impersonate.
|
|
||||||
// Optional.
|
|
||||||
func Subject(user string) Option {
|
|
||||||
return func(o *Options) error {
|
|
||||||
o.Subject = user
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
|
||||||
return func(t *Token) (*Token, error) {
|
|
||||||
if t == nil {
|
|
||||||
t = &Token{}
|
|
||||||
}
|
}
|
||||||
claimSet := &jws.ClaimSet{
|
claimSet := &jws.ClaimSet{
|
||||||
Iss: o.Email,
|
Iss: js.conf.Email,
|
||||||
Scope: strings.Join(o.Scopes, " "),
|
Scope: strings.Join(js.conf.Scopes, " "),
|
||||||
Aud: o.AUD.String(),
|
Aud: js.conf.TokenURL,
|
||||||
}
|
}
|
||||||
if o.Subject != "" {
|
if subject := js.conf.Subject; subject != "" {
|
||||||
claimSet.Sub = o.Subject
|
claimSet.Sub = subject
|
||||||
// prn is the old name of sub. Keep setting it
|
// prn is the old name of sub. Keep setting it
|
||||||
// to be compatible with legacy OAuth 2.0 providers.
|
// to be compatible with legacy OAuth 2.0 providers.
|
||||||
claimSet.Prn = o.Subject
|
claimSet.Prn = subject
|
||||||
}
|
}
|
||||||
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
|
payload, err := jws.Encode(defaultHeader, claimSet, js.conf.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Set("grant_type", defaultGrantType)
|
v.Set("grant_type", defaultGrantType)
|
||||||
v.Set("assertion", payload)
|
v.Set("assertion", payload)
|
||||||
c := o.Client
|
resp, err := hc.PostForm(js.conf.TokenURL, v)
|
||||||
if c == nil {
|
|
||||||
c = &http.Client{}
|
|
||||||
}
|
|
||||||
resp, err := c.PostForm(o.AUD.String(), v)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -118,4 +146,3 @@ func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
502
oauth2.go
502
oauth2.go
|
@ -8,115 +8,168 @@
|
||||||
package oauth2 // import "golang.org/x/oauth2"
|
package oauth2 // import "golang.org/x/oauth2"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rsa"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"mime"
|
"mime"
|
||||||
"time"
|
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer.
|
// Context can be an golang.org/x/net.Context, or an App Engine Context.
|
||||||
type TokenStore interface {
|
// In the future these will be unified.
|
||||||
// ReadToken reads the token from the store.
|
// If you don't care and aren't running on App Engine, you may use nil.
|
||||||
// If the read is successful, it should return the token and a nil error.
|
type Context interface{}
|
||||||
// The returned tokens may be expired tokens.
|
|
||||||
// If there is no token in the store, it should return a nil token and a nil error.
|
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||||
// It should return a non-nil error when an unrecoverable failure occurs.
|
// client application information and the server's URLs.
|
||||||
ReadToken() (*Token, error)
|
type Config struct {
|
||||||
// WriteToken writes the token to the cache.
|
// ClientID is the application's ID.
|
||||||
WriteToken(*Token)
|
ClientID string
|
||||||
|
|
||||||
|
// ClientSecret is the application's secret.
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// Endpoint contains the resource server's token endpoint
|
||||||
|
// URLs. These are supplied by the server and are often
|
||||||
|
// available via site-specific packages (for example,
|
||||||
|
// google.Endpoint or github.Endpoint)
|
||||||
|
Endpoint Endpoint
|
||||||
|
|
||||||
|
// RedirectURL is the URL to redirect users going through
|
||||||
|
// the OAuth flow, after the resource owner's URLs.
|
||||||
|
RedirectURL string
|
||||||
|
|
||||||
|
// Scope specifies optional requested permissions.
|
||||||
|
Scopes []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option represents a function that applies some state to
|
// A TokenSource is anything that can return a token.
|
||||||
// an Options object.
|
type TokenSource interface {
|
||||||
type Option func(*Options) error
|
// Token returns a token or an error.
|
||||||
|
Token() (*Token, error)
|
||||||
// Client requires the OAuth 2.0 client credentials. You need to provide
|
|
||||||
// the client identifier and optionally the client secret that are
|
|
||||||
// assigned to your application by the OAuth 2.0 provider.
|
|
||||||
func Client(id, secret string) Option {
|
|
||||||
return func(opts *Options) error {
|
|
||||||
opts.ClientID = id
|
|
||||||
opts.ClientSecret = secret
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedirectURL requires the URL to which the user will be returned after
|
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
||||||
// granting (or denying) access.
|
// endpoint URLs.
|
||||||
func RedirectURL(url string) Option {
|
type Endpoint struct {
|
||||||
return func(opts *Options) error {
|
AuthURL string
|
||||||
opts.RedirectURL = url
|
TokenURL string
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scope requires a list of requested permission scopes.
|
// Token represents the crendentials used to authorize
|
||||||
// It is optinal to specify scopes.
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
func Scope(scopes ...string) Option {
|
// provider's backend.
|
||||||
return func(o *Options) error {
|
//
|
||||||
o.Scopes = scopes
|
// Most users of this package should not access fields of Token
|
||||||
return nil
|
// directly. They're exported mostly for use by related packages
|
||||||
}
|
// implementing derivate OAuth2 flows.
|
||||||
|
type Token struct {
|
||||||
|
// AccessToken is the token that authorizes and authenticates
|
||||||
|
// the requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
|
||||||
|
// TokenType is the type of token.
|
||||||
|
// The Type method returns either this or "Bearer", the default.
|
||||||
|
TokenType string `json:"token_type,omitempty"`
|
||||||
|
|
||||||
|
// RefreshToken is a token that's used by the application
|
||||||
|
// (as opposed to the user) to refresh the access token
|
||||||
|
// if it expires.
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
|
||||||
|
// Expiry is the optional expiration time of the access token.
|
||||||
|
//
|
||||||
|
// If zero, TokenSource implementations will reuse the same
|
||||||
|
// token forever and RefreshToken or equivalent
|
||||||
|
// mechanisms for that TokenSource will not be used.
|
||||||
|
Expiry time.Time `json:"expiry,omitempty"`
|
||||||
|
|
||||||
|
// raw optionally contains extra metadata from the server
|
||||||
|
// when updating a token.
|
||||||
|
raw interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
|
// Type returns t.TokenType if non-empty, else "Bearer".
|
||||||
func Endpoint(authURL, tokenURL string) Option {
|
func (t *Token) Type() string {
|
||||||
return func(o *Options) error {
|
if t.TokenType != "" {
|
||||||
au, err := url.Parse(authURL)
|
return t.TokenType
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
tu, err := url.Parse(tokenURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
o.AuthURL = au
|
|
||||||
o.TokenURL = tu
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return "Bearer"
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient allows you to provide a custom http.Client to be
|
// SetAuthHeader sets the Authorization header to r using the access
|
||||||
// used to retrieve tokens from the OAuth 2.0 provider.
|
// token in t.
|
||||||
func HTTPClient(c *http.Client) Option {
|
//
|
||||||
return func(o *Options) error {
|
// This method is unnecessary when using Transport or an HTTP Client
|
||||||
o.Client = c
|
// returned by this package.
|
||||||
return nil
|
func (t *Token) SetAuthHeader(r *http.Request) {
|
||||||
}
|
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New builds a new options object and determines the type of the OAuth 2.0
|
// Extra returns an extra field returned from the server during token
|
||||||
// (2-legged, 3-legged or custom) by looking at the provided options.
|
// retrieval.
|
||||||
// If the flow type cannot determined automatically, an error is returned.
|
func (t *Token) Extra(key string) string {
|
||||||
func New(option ...Option) (*Options, error) {
|
if vals, ok := t.raw.(url.Values); ok {
|
||||||
opts := &Options{}
|
return vals.Get(key)
|
||||||
for _, fn := range option {
|
}
|
||||||
if err := fn(opts); err != nil {
|
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||||
return nil, err
|
if val, ok := raw[key].(string); ok {
|
||||||
|
return val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch {
|
return ""
|
||||||
case opts.TokenFetcherFunc != nil:
|
|
||||||
return opts, nil
|
|
||||||
case opts.AUD != nil:
|
|
||||||
// TODO(jbd): Assert the required JWT params.
|
|
||||||
opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts)
|
|
||||||
return opts, nil
|
|
||||||
case opts.AuthURL != nil && opts.TokenURL != nil:
|
|
||||||
// TODO(jbd): Assert the required OAuth2 params.
|
|
||||||
opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts)
|
|
||||||
return opts, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expired returns true if there is no access token or the
|
||||||
|
// access token is expired.
|
||||||
|
func (t *Token) Expired() bool {
|
||||||
|
if t.AccessToken == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if t.Expiry.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.Expiry.Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// AccessTypeOnline and AccessTypeOffline are options passed
|
||||||
|
// to the Options.AuthCodeURL method. They modify the
|
||||||
|
// "access_type" field that gets sent in the URL returned by
|
||||||
|
// AuthCodeURL.
|
||||||
|
//
|
||||||
|
// Online (the default if neither is specified) is the default.
|
||||||
|
// If your application needs to refresh access tokens when the
|
||||||
|
// user is not present at the browser, then use offline. This
|
||||||
|
// will result in your application obtaining a refresh token
|
||||||
|
// the first time your application exchanges an authorization
|
||||||
|
// code for a user.
|
||||||
|
AccessTypeOnline AuthCodeOption = setParam{"access_type", "online"}
|
||||||
|
AccessTypeOffline AuthCodeOption = setParam{"access_type", "offline"}
|
||||||
|
|
||||||
|
// ApprovalForce forces the users to view the consent dialog
|
||||||
|
// and confirm the permissions request at the URL returned
|
||||||
|
// from AuthCodeURL, even if they've already done so.
|
||||||
|
ApprovalForce AuthCodeOption = setParam{"approval_prompt", "force"}
|
||||||
|
)
|
||||||
|
|
||||||
|
type setParam struct{ k, v string }
|
||||||
|
|
||||||
|
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
|
||||||
|
|
||||||
|
// An AuthCodeOption is passed to Config.AuthCodeURL.
|
||||||
|
type AuthCodeOption interface {
|
||||||
|
setValue(url.Values)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||||
|
@ -127,185 +180,195 @@ func New(option ...Option) (*Options, error) {
|
||||||
// the state query parameter on your redirect callback.
|
// the state query parameter on your redirect callback.
|
||||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||||
//
|
//
|
||||||
// Access type is an OAuth extension that gets sent as the
|
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||||
// "access_type" field in the URL from AuthCodeURL.
|
// as ApprovalForce.
|
||||||
// It may be "online" (default) or "offline".
|
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
// If your application needs to refresh access tokens when the
|
var buf bytes.Buffer
|
||||||
// user is not present at the browser, then use offline. This
|
buf.WriteString(c.Endpoint.AuthURL)
|
||||||
// will result in your application obtaining a refresh token
|
|
||||||
// the first time your application exchanges an authorization
|
|
||||||
// code for a user.
|
|
||||||
//
|
|
||||||
// Approval prompt indicates whether the user should be
|
|
||||||
// re-prompted for consent. If set to "auto" (default) the
|
|
||||||
// user will be prompted only if they haven't previously
|
|
||||||
// granted consent and the code can only be exchanged for an
|
|
||||||
// access token. If set to "force" the user will always be prompted,
|
|
||||||
// and the code can be exchanged for a refresh token.
|
|
||||||
func (o *Options) AuthCodeURL(state, accessType, prompt string) string {
|
|
||||||
u := *o.AuthURL
|
|
||||||
v := url.Values{
|
v := url.Values{
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"client_id": {o.ClientID},
|
"client_id": {c.ClientID},
|
||||||
"redirect_uri": condVal(o.RedirectURL),
|
"redirect_uri": condVal(c.RedirectURL),
|
||||||
"scope": condVal(strings.Join(o.Scopes, " ")),
|
"scope": condVal(strings.Join(c.Scopes, " ")),
|
||||||
"state": condVal(state),
|
"state": condVal(state),
|
||||||
"access_type": condVal(accessType),
|
|
||||||
"approval_prompt": condVal(prompt),
|
|
||||||
}
|
}
|
||||||
q := v.Encode()
|
for _, opt := range opts {
|
||||||
if u.RawQuery == "" {
|
opt.setValue(v)
|
||||||
u.RawQuery = q
|
}
|
||||||
|
if strings.Contains(c.Endpoint.AuthURL, "?") {
|
||||||
|
buf.WriteByte('&')
|
||||||
} else {
|
} else {
|
||||||
u.RawQuery += "&" + q
|
buf.WriteByte('?')
|
||||||
}
|
}
|
||||||
return u.String()
|
buf.WriteString(v.Encode())
|
||||||
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// exchange exchanges the authorization code with the OAuth 2.0 provider
|
// Exchange converts an authorization code into a token.
|
||||||
// to retrieve a new access token.
|
//
|
||||||
func (o *Options) exchange(code string) (*Token, error) {
|
// It is used after a resource provider redirects the user back
|
||||||
return retrieveToken(o, url.Values{
|
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
||||||
|
//
|
||||||
|
// The HTTP client to use is derived from the context. If nil,
|
||||||
|
// http.DefaultClient is used. See the Context type's documentation.
|
||||||
|
//
|
||||||
|
// The code will be in the *http.Request.FormValue("code"). Before
|
||||||
|
// calling Exchange, be sure to validate FormValue("state").
|
||||||
|
func (c *Config) Exchange(ctx Context, code string) (*Token, error) {
|
||||||
|
return retrieveToken(ctx, c, url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": condVal(o.RedirectURL),
|
"redirect_uri": condVal(c.RedirectURL),
|
||||||
"scope": condVal(strings.Join(o.Scopes, " ")),
|
"scope": condVal(strings.Join(c.Scopes, " ")),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransportFromTokenStore reads the token from the store and returns
|
// contextClientFunc is a func which tries to return an *http.Client
|
||||||
// a Transport that is authorized and the authenticated
|
// given a Context value. If it returns an error, the search stops
|
||||||
// by the returned token.
|
// with that error. If it returns (nil, nil), the search continues
|
||||||
func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) {
|
// down the list of registered funcs.
|
||||||
tok, err := store.ReadToken()
|
type contextClientFunc func(Context) (*http.Client, error)
|
||||||
|
|
||||||
|
var contextClientFuncs []contextClientFunc
|
||||||
|
|
||||||
|
func registerContextClietnFunc(fn contextClientFunc) {
|
||||||
|
contextClientFuncs = append(contextClientFuncs, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClient(ctx Context) (*http.Client, error) {
|
||||||
|
for _, fn := range contextClientFuncs {
|
||||||
|
c, err := fn(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
o.TokenStore = store
|
if c != nil {
|
||||||
if tok == nil {
|
return c, nil
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
return o.newTransportFromToken(tok), nil
|
}
|
||||||
|
if xc, ok := ctx.(context.Context); ok {
|
||||||
|
if hc, ok := xc.Value(HTTPClient).(*http.Client); ok {
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return http.DefaultClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransportFromCode exchanges the code to retrieve a new access token
|
func contextTransport(ctx Context) http.RoundTripper {
|
||||||
// and returns an authorized and authenticated Transport.
|
hc, err := contextClient(ctx)
|
||||||
func (o *Options) NewTransportFromCode(code string) (*Transport, error) {
|
|
||||||
token, err := o.exchange(code)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
// This is a rare error case (somebody using nil on App Engine),
|
||||||
|
// so I'd rather not everybody do an error check on this Client
|
||||||
|
// method. They can get the error that they're doing it wrong
|
||||||
|
// later, at client.Get/PostForm time.
|
||||||
|
return errorTransport{err}
|
||||||
}
|
}
|
||||||
return o.newTransportFromToken(token), nil
|
return hc.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransport returns a Transport.
|
// Client returns an HTTP client using the provided token.
|
||||||
func (o *Options) NewTransport() *Transport {
|
// The token will auto-refresh as necessary. The underlying
|
||||||
return o.newTransportFromToken(nil)
|
// HTTP transport will be obtained using the provided context.
|
||||||
|
// The returned client and its Transport should not be modified.
|
||||||
|
func (c *Config) Client(ctx Context, t *Token) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &Transport{
|
||||||
|
Source: c.TokenSource(ctx, t),
|
||||||
|
Base: contextTransport(ctx),
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTransportFromToken returns a new Transport that is authorized
|
// TokenSource returns a TokenSource that returns t until t expires,
|
||||||
// and authenticated with the provided token.
|
// automatically refreshing it as necessary using the provided context.
|
||||||
func (o *Options) newTransportFromToken(t *Token) *Transport {
|
// See the the Context documentation.
|
||||||
// TODO(jbd): App Engine options initiate an http.Client that
|
//
|
||||||
// depends on the urlfetcher, but it breaks the promise we made
|
// Most users will use Config.Client instead.
|
||||||
// that the options object should be working finely with nil-values
|
func (c *Config) TokenSource(ctx Context, t *Token) TokenSource {
|
||||||
// for the http.Client.
|
nwn := &newWhenNeededSource{t: t}
|
||||||
tr := http.DefaultTransport
|
nwn.new = tokenRefresher{
|
||||||
if o.Client != nil && o.Client.Transport != nil {
|
ctx: ctx,
|
||||||
tr = o.Client.Transport
|
conf: c,
|
||||||
|
oldToken: &nwn.t,
|
||||||
}
|
}
|
||||||
return newTransport(tr, o, t)
|
return nwn
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
||||||
return func(t *Token) (*Token, error) {
|
// HTTP requests to renew a token using a RefreshToken.
|
||||||
if t == nil || t.RefreshToken == "" {
|
type tokenRefresher struct {
|
||||||
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
|
ctx Context // used to get HTTP requests
|
||||||
|
conf *Config
|
||||||
|
oldToken **Token // pointer to old *Token w/ RefreshToken
|
||||||
}
|
}
|
||||||
return retrieveToken(o, url.Values{
|
|
||||||
|
func (tf tokenRefresher) Token() (*Token, error) {
|
||||||
|
t := *tf.oldToken
|
||||||
|
if t == nil {
|
||||||
|
return nil, errors.New("oauth2: attempted use of nil Token")
|
||||||
|
}
|
||||||
|
if t.RefreshToken == "" {
|
||||||
|
return nil, errors.New("oauth2: token expired and refresh token is not set")
|
||||||
|
}
|
||||||
|
return retrieveToken(tf.ctx, tf.conf, url.Values{
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"refresh_token": {t.RefreshToken},
|
"refresh_token": {t.RefreshToken},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Options represents an object to keep the state of the OAuth 2.0 flow.
|
// newWhenNeededSource is a TokenSource that holds a single token in memory
|
||||||
type Options struct {
|
// and validates its expiry before each call to retrieve it with
|
||||||
// ClientID is the OAuth client identifier used when communicating with
|
// Token. If it's expired, it will be auto-refreshed using the
|
||||||
// the configured OAuth provider.
|
// new TokenSource.
|
||||||
ClientID string
|
|
||||||
|
|
||||||
// ClientSecret is the OAuth client secret used when communicating with
|
|
||||||
// the configured OAuth provider.
|
|
||||||
ClientSecret string
|
|
||||||
|
|
||||||
// RedirectURL is the URL to which the user will be returned after
|
|
||||||
// granting (or denying) access.
|
|
||||||
RedirectURL string
|
|
||||||
|
|
||||||
// Email is the OAuth client identifier used when communicating with
|
|
||||||
// the configured OAuth provider.
|
|
||||||
Email string
|
|
||||||
|
|
||||||
// PrivateKey contains the contents of an RSA private key or the
|
|
||||||
// contents of a PEM file that contains a private key. The provided
|
|
||||||
// private key is used to sign JWT payloads.
|
|
||||||
// PEM containers with a passphrase are not supported.
|
|
||||||
// Use the following command to convert a PKCS 12 file into a PEM.
|
|
||||||
//
|
//
|
||||||
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
|
// The first call to TokenRefresher must be SetToken.
|
||||||
//
|
type newWhenNeededSource struct {
|
||||||
PrivateKey *rsa.PrivateKey
|
new TokenSource // called when t is expired.
|
||||||
|
|
||||||
// Scopes identify the level of access being requested.
|
mu sync.Mutex // guards t
|
||||||
Subject string
|
t *Token
|
||||||
|
|
||||||
// Scopes optionally specifies a list of requested permission scopes.
|
|
||||||
Scopes []string
|
|
||||||
|
|
||||||
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
|
|
||||||
AuthURL *url.URL
|
|
||||||
|
|
||||||
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
|
|
||||||
TokenURL *url.URL
|
|
||||||
|
|
||||||
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
|
||||||
AUD *url.URL
|
|
||||||
|
|
||||||
// TokenStore reads a token from the store and writes it back to the store
|
|
||||||
// if a token refresh occurs.
|
|
||||||
// Optional.
|
|
||||||
TokenStore TokenStore
|
|
||||||
|
|
||||||
TokenFetcherFunc func(t *Token) (*Token, error)
|
|
||||||
|
|
||||||
Client *http.Client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
// Token returns the current token if it's still valid, else will
|
||||||
v.Set("client_id", o.ClientID)
|
// refresh the current token (using r.Context for HTTP client
|
||||||
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
|
// information) and return the new one.
|
||||||
if bustedAuth && o.ClientSecret != "" {
|
func (s *newWhenNeededSource) Token() (*Token, error) {
|
||||||
v.Set("client_secret", o.ClientSecret)
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.t != nil && !s.t.Expired() {
|
||||||
|
return s.t, nil
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
|
t, err := s.new.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.t = t
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
|
||||||
|
hc, err := contextClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v.Set("client_id", c.ClientID)
|
||||||
|
bustedAuth := !providerAuthHeaderWorks(c.Endpoint.TokenURL)
|
||||||
|
if bustedAuth && c.ClientSecret != "" {
|
||||||
|
v.Set("client_secret", c.ClientSecret)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", c.Endpoint.TokenURL, strings.NewReader(v.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
if !bustedAuth && o.ClientSecret != "" {
|
if !bustedAuth && c.ClientSecret != "" {
|
||||||
req.SetBasicAuth(o.ClientID, o.ClientSecret)
|
req.SetBasicAuth(c.ClientID, c.ClientSecret)
|
||||||
}
|
}
|
||||||
c := o.Client
|
r, err := hc.Do(req)
|
||||||
if c == nil {
|
|
||||||
c = &http.Client{}
|
|
||||||
}
|
|
||||||
r, err := c.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -314,7 +377,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
token := &Token{}
|
token := &Token{}
|
||||||
expires := int(0)
|
expires := 0
|
||||||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||||
switch content {
|
switch content {
|
||||||
case "application/x-www-form-urlencoded", "text/plain":
|
case "application/x-www-form-urlencoded", "text/plain":
|
||||||
|
@ -335,7 +398,7 @@ func retrieveToken(o *Options, v url.Values) (*Token, error) {
|
||||||
}
|
}
|
||||||
expires, _ = strconv.Atoi(e)
|
expires, _ = strconv.Atoi(e)
|
||||||
default:
|
default:
|
||||||
b := make(map[string]interface{})
|
b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
|
||||||
if err = json.Unmarshal(body, &b); err != nil {
|
if err = json.Unmarshal(body, &b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -397,3 +460,12 @@ func providerAuthHeaderWorks(tokenURL string) bool {
|
||||||
// to this package to let users select server bugs.
|
// to this package to let users select server bugs.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||||
|
// WithValue function to associate an *http.Client value with a context.
|
||||||
|
var HTTPClient contextKey
|
||||||
|
|
||||||
|
// contextKey is just an empty struct. It exists so HTTPClient can be
|
||||||
|
// an immutable public variable with a unique type. It's immutable
|
||||||
|
// because nobody else can create a contextKey, being unexported.
|
||||||
|
type contextKey struct{}
|
||||||
|
|
200
transport.go
200
transport.go
|
@ -5,136 +5,90 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
||||||
defaultTokenType = "Bearer"
|
// wrapping a base RoundTripper and adding an Authorization header
|
||||||
)
|
// with a token from the supplied Sources.
|
||||||
|
|
||||||
// Token represents the crendentials used to authorize
|
|
||||||
// the requests to access protected resources on the OAuth 2.0
|
|
||||||
// provider's backend.
|
|
||||||
type Token struct {
|
|
||||||
// AccessToken is the token that authorizes and authenticates the requests.
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
|
|
||||||
// TokenType identifies the type of token returned.
|
|
||||||
TokenType string `json:"token_type,omitempty"`
|
|
||||||
|
|
||||||
// RefreshToken is a token that may be used to obtain a new access token.
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
|
|
||||||
// Expiry is the expiration datetime of the access token.
|
|
||||||
Expiry time.Time `json:"expiry,omitempty"`
|
|
||||||
|
|
||||||
// raw optionally contains extra metadata from the server
|
|
||||||
// when updating a token.
|
|
||||||
raw interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extra returns an extra field returned from the server during token retrieval.
|
|
||||||
// E.g.
|
|
||||||
// idToken := token.Extra("id_token")
|
|
||||||
//
|
//
|
||||||
func (t *Token) Extra(key string) string {
|
// Transport is a low-level mechanism. Most code will use the
|
||||||
if vals, ok := t.raw.(url.Values); ok {
|
// higher-level Config.Client method instead.
|
||||||
return vals.Get(key)
|
|
||||||
}
|
|
||||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
|
||||||
if val, ok := raw[key].(string); ok {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expired returns true if there is no access token or the
|
|
||||||
// access token is expired.
|
|
||||||
func (t *Token) Expired() bool {
|
|
||||||
if t.AccessToken == "" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if t.Expiry.IsZero() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return t.Expiry.Before(time.Now())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
opts *Options
|
// Source supplies the token to add to outgoing requests'
|
||||||
base http.RoundTripper
|
// Authorization headers.
|
||||||
|
Source TokenSource
|
||||||
|
|
||||||
mu sync.RWMutex
|
// Base is the base RoundTripper used to make HTTP requests.
|
||||||
token *Token
|
// If nil, http.DefaultTransport is used.
|
||||||
}
|
Base http.RoundTripper
|
||||||
|
|
||||||
// NewTransport creates a new Transport that uses the provided
|
mu sync.Mutex // guards modReq
|
||||||
// token fetcher as token retrieving strategy. It authenticates
|
modReq map[*http.Request]*http.Request // original -> modified
|
||||||
// the requests and delegates origTransport to make the actual requests.
|
|
||||||
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
|
|
||||||
return &Transport{
|
|
||||||
base: base,
|
|
||||||
opts: opts,
|
|
||||||
token: token,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip authorizes and authenticates the request with an
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
// access token. If no token exists or token is expired,
|
// access token. If no token exists or token is expired,
|
||||||
// tries to refresh/fetch a new token.
|
// tries to refresh/fetch a new token.
|
||||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
token := t.token
|
if t.Source == nil {
|
||||||
|
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||||
if token == nil || token.Expired() {
|
}
|
||||||
// Check if the token is refreshable.
|
token, err := t.Source.Token()
|
||||||
// If token is refreshable, don't return an error,
|
if err != nil {
|
||||||
// rather refresh.
|
|
||||||
if err := t.refreshToken(); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token = t.token
|
|
||||||
if t.opts.TokenStore != nil {
|
req2 := cloneRequest(req) // per RoundTripper contract
|
||||||
t.opts.TokenStore.WriteToken(token)
|
token.SetAuthHeader(req2)
|
||||||
|
t.setModReq(req, req2)
|
||||||
|
res, err := t.base().RoundTrip(req2)
|
||||||
|
if err != nil {
|
||||||
|
t.setModReq(req, nil)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.Body = &onEOFReader{
|
||||||
|
rc: res.Body,
|
||||||
|
fn: func() { t.setModReq(req, nil) },
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRequest cancels an in-flight request by closing its connection.
|
||||||
|
func (t *Transport) CancelRequest(req *http.Request) {
|
||||||
|
type canceler interface {
|
||||||
|
CancelRequest(*http.Request)
|
||||||
|
}
|
||||||
|
if cr, ok := t.Base.(canceler); ok {
|
||||||
|
t.mu.Lock()
|
||||||
|
modReq := t.modReq[req]
|
||||||
|
delete(t.modReq, req)
|
||||||
|
t.mu.Unlock()
|
||||||
|
cr.CancelRequest(modReq)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// To set the Authorization header, we must make a copy of the Request
|
func (t *Transport) base() http.RoundTripper {
|
||||||
// so that we don't modify the Request we were given.
|
if t.Base != nil {
|
||||||
// This is required by the specification of http.RoundTripper.
|
return t.Base
|
||||||
req = cloneRequest(req)
|
|
||||||
typ := token.TokenType
|
|
||||||
if typ == "" {
|
|
||||||
typ = defaultTokenType
|
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", typ+" "+token.AccessToken)
|
return http.DefaultTransport
|
||||||
return t.base.RoundTrip(req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token returns the token that authorizes and
|
func (t *Transport) setModReq(orig, mod *http.Request) {
|
||||||
// authenticates the transport.
|
|
||||||
func (t *Transport) Token() *Token {
|
|
||||||
t.mu.RLock()
|
|
||||||
defer t.mu.RUnlock()
|
|
||||||
return t.token
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshToken retrieves a new token, if a refreshing/fetching
|
|
||||||
// method is known and required credentials are presented
|
|
||||||
// (such as a refresh token).
|
|
||||||
func (t *Transport) refreshToken() error {
|
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
token, err := t.opts.TokenFetcherFunc(t.token)
|
if t.modReq == nil {
|
||||||
if err != nil {
|
t.modReq = make(map[*http.Request]*http.Request)
|
||||||
return err
|
}
|
||||||
|
if mod == nil {
|
||||||
|
delete(t.modReq, orig)
|
||||||
|
} else {
|
||||||
|
t.modReq[orig] = mod
|
||||||
}
|
}
|
||||||
t.token = token
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneRequest returns a clone of the provided *http.Request.
|
// cloneRequest returns a clone of the provided *http.Request.
|
||||||
|
@ -144,9 +98,41 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||||
r2 := new(http.Request)
|
r2 := new(http.Request)
|
||||||
*r2 = *r
|
*r2 = *r
|
||||||
// deep copy of the Header
|
// deep copy of the Header
|
||||||
r2.Header = make(http.Header)
|
r2.Header = make(http.Header, len(r.Header))
|
||||||
for k, s := range r.Header {
|
for k, s := range r.Header {
|
||||||
r2.Header[k] = s
|
r2.Header[k] = append([]string(nil), s...)
|
||||||
}
|
}
|
||||||
return r2
|
return r2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type onEOFReader struct {
|
||||||
|
rc io.ReadCloser
|
||||||
|
fn func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = r.rc.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
r.runFunc()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) Close() error {
|
||||||
|
err := r.rc.Close()
|
||||||
|
r.runFunc()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) runFunc() {
|
||||||
|
if fn := r.fn; fn != nil {
|
||||||
|
fn()
|
||||||
|
r.fn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorTransport struct{ err error }
|
||||||
|
|
||||||
|
func (t errorTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
|
return nil, t.err
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue