diff --git a/google/appengine.go b/google/appengine.go index ea18871..31e89f5 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -7,6 +7,8 @@ package google import ( + "net/http" + "github.com/golang/oauth2" "appengine" @@ -16,10 +18,10 @@ import ( // AppEngineConfig represents a configuration for an // App Engine application's Google service account. type AppEngineConfig struct { - // Transport is the transport to be used + // Transport is the http.RoundTripper to be used // to construct new oauth2.Transport instances from // this configuration. - Transport *urlfetch.Transport + Transport http.RoundTripper context appengine.Context scopes []string @@ -29,11 +31,6 @@ type AppEngineConfig struct { // provided auth scopes. func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineConfig { return &AppEngineConfig{ - Transport: &urlfetch.Transport{ - Context: context, - Deadline: 0, - AllowInvalidServerCertificate: false, - }, context: context, scopes: scopes, } @@ -42,7 +39,7 @@ func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineCo // NewTransport returns a transport that authorizes // the requests with the application's service account. func (c *AppEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.Transport, c, nil) + return oauth2.NewTransport(c.transport(), c, nil) } // FetchToken fetches a new access token for the provided scopes. @@ -56,3 +53,10 @@ func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, err Expiry: expiry, }, nil } + +func (c *AppEngineConfig) transport() http.RoundTripper { + if c.Transport != nil { + return c.Transport + } + return &urlfetch.Transport{Context: c.context} +} diff --git a/google/appenginevm.go b/google/appenginevm.go index 32c017a..d9d27ce 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -16,7 +16,7 @@ import ( // AppEngineConfig represents a configuration for an // App Engine application's Google service account. type AppEngineConfig struct { - // Transport is the round tripper to be used + // Transport is the http.RoundTripper to be used // to construct new oauth2.Transport instances from // this configuration. Transport http.RoundTripper @@ -29,16 +29,15 @@ type AppEngineConfig struct { // provided auth scopes. func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineConfig { return &AppEngineConfig{ - Transport: http.DefaultTransport, - context: context, - scopes: scopes, + context: context, + scopes: scopes, } } // NewTransport returns a transport that authorizes // the requests with the application's service account. func (c *AppEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.Transport, c, nil) + return oauth2.NewTransport(c.transport(), c, nil) } // FetchToken fetches a new access token for the provided scopes. @@ -52,3 +51,10 @@ func (c *AppEngineConfig) FetchToken(existing *oauth2.Token) (*oauth2.Token, err Expiry: expiry, }, nil } + +func (c *AppEngineConfig) transport() http.RoundTripper { + if c.Transport != nil { + return c.Transport + } + return http.DefaultTransport +} diff --git a/google/google.go b/google/google.go index e170415..b8031de 100644 --- a/google/google.go +++ b/google/google.go @@ -65,16 +65,12 @@ func NewServiceAccountConfig(opts *oauth2.JWTOptions) (*oauth2.JWTConfig, error) // from Google Compute Engine instance's metaserver. If no account is // provided, default is used. func NewComputeEngineConfig(account string) *ComputeEngineConfig { - return &ComputeEngineConfig{ - Client: http.DefaultClient, - Transport: http.DefaultTransport, - account: account, - } + return &ComputeEngineConfig{account: account} } // NewTransport creates an authorized transport. func (c *ComputeEngineConfig) NewTransport() *oauth2.Transport { - return oauth2.NewTransport(c.Transport, c, nil) + return oauth2.NewTransport(c.transport(), c, nil) } // FetchToken retrieves a new access token via metadata server. @@ -89,7 +85,7 @@ func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2. return } req.Header.Add("X-Google-Metadata-Request", "True") - resp, err := c.Client.Do(req) + resp, err := c.client().Do(req) if err != nil { return } @@ -106,3 +102,17 @@ func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2. } return } + +func (c *ComputeEngineConfig) transport() http.RoundTripper { + if c.Transport != nil { + return c.Transport + } + return http.DefaultTransport +} + +func (c *ComputeEngineConfig) client() *http.Client { + if c.Client != nil { + return c.Client + } + return http.DefaultClient +} diff --git a/jwt.go b/jwt.go index f28ff39..78d9553 100644 --- a/jwt.go +++ b/jwt.go @@ -58,11 +58,9 @@ func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) { return nil, err } return &JWTConfig{ - Client: http.DefaultClient, - Transport: http.DefaultTransport, - opts: opts, - aud: audURL, - key: parsedKey, + opts: opts, + aud: audURL, + key: parsedKey, }, nil } @@ -73,7 +71,7 @@ type JWTConfig struct { // tokens from the OAuth 2.0 provider. Client *http.Client - // Transport is the round tripper to be used + // Transport is the http.RoundTripper to be used // to construct new oauth2.Transport instances from // this configuration. Transport http.RoundTripper @@ -86,13 +84,13 @@ type JWTConfig struct { // NewTransport creates a transport that is authorize with the // parent JWT configuration. func (c *JWTConfig) NewTransport() *Transport { - return NewTransport(c.Transport, c, &Token{}) + return NewTransport(c.transport(), c, &Token{}) } // NewTransportWithUser creates a transport that is authorized by // the client and impersonates the specified user. func (c *JWTConfig) NewTransportWithUser(user string) *Transport { - return NewTransport(c.Transport, c, &Token{Subject: user}) + return NewTransport(c.transport(), c, &Token{Subject: user}) } // fetchToken retrieves a new access token and updates the existing token @@ -124,7 +122,7 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) { v.Set("assertion", payload) // Make a request with assertion to get a new token. - resp, err := c.Client.PostForm(c.aud.String(), v) + resp, err := c.client().PostForm(c.aud.String(), v) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -163,6 +161,20 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) { return token, nil } +func (c *JWTConfig) transport() http.RoundTripper { + if c.Transport != nil { + return c.Transport + } + return http.DefaultTransport +} + +func (c *JWTConfig) client() *http.Client { + if c.Client != nil { + return c.Client + } + return http.DefaultClient +} + // parseKey converts the binary contents of a private key file // to an *rsa.PrivateKey. It detects whether the private key is in a // PEM container or not. If so, it extracts the the private key diff --git a/oauth2.go b/oauth2.go index e41ed59..f616780 100644 --- a/oauth2.go +++ b/oauth2.go @@ -115,7 +115,7 @@ type Config struct { // tokens from the OAuth 2.0 provider. Client *http.Client - // Transport is the round tripper to be used + // Transport is the http.RoundTripper to be used // to construct new oauth2.Transport instances from // this configuration. Transport http.RoundTripper @@ -161,7 +161,7 @@ func (c *Config) AuthCodeURL(state string) (authURL string) { // you need to set a valid token (or an expired token with a valid // refresh token) in order to be able to do authorized requests. func (c *Config) NewTransport() *Transport { - return NewTransport(c.Transport, c, nil) + return NewTransport(c.transport(), c, nil) } // NewTransportWithCode exchanges the OAuth 2.0 authorization code with @@ -173,7 +173,7 @@ func (c *Config) NewTransportWithCode(code string) (*Transport, error) { if err != nil { return nil, err } - return NewTransport(c.Transport, c, token), nil + return NewTransport(c.transport(), c, token), nil } // FetchToken retrieves a new access token and updates the existing token @@ -226,7 +226,7 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { // Dropbox accepts either, but not both. // The spec requires servers to always support the Authorization header, // so that's all we use. - r, err := c.Client.PostForm(c.tokenURL.String(), v) + r, err := c.client().PostForm(c.tokenURL.String(), v) if err != nil { return nil, err } @@ -281,3 +281,17 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { } return token, nil } + +func (c *Config) transport() http.RoundTripper { + if c.Transport != nil { + return c.Transport + } + return http.DefaultTransport +} + +func (c *Config) client() *http.Client { + if c.Client != nil { + return c.Client + } + return http.DefaultClient +}