diff --git a/google/appengine.go b/google/appengine.go index b0c425d..ebdde07 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -8,6 +8,7 @@ import ( "github.com/golang/oauth2" "appengine" + "appengine/urlfetch" ) // AppEngineConfig represents a configuration for an @@ -26,7 +27,12 @@ func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineCo // NewTransport returns a transport that authorizes // the requests with the application's service account. func (c *AppEngineConfig) NewTransport() oauth2.Transport { - return oauth2.NewAuthorizedTransport(c, nil) + transport := &urlfetch.Transport{ + Context: c.context, + Deadline: 0, + AllowInvalidServerCertificate: false, + } + return oauth2.NewAuthorizedTransport(transport, c, nil) } // FetchToken fetches a new access token for the provided scopes. diff --git a/google/appenginevm.go b/google/appenginevm.go index 1885a47..96c65fe 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -3,6 +3,7 @@ package google import ( + "net/http" "strings" "github.com/golang/oauth2" @@ -25,7 +26,7 @@ func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineCo // NewTransport returns a transport that authorizes // the requests with the application's service account. func (c *AppEngineConfig) NewTransport() oauth2.Transport { - return oauth2.NewAuthorizedTransport(c, nil) + return oauth2.NewAuthorizedTransport(http.DefaultTransport, c, nil) } // FetchToken fetches a new access token for the provided scopes. diff --git a/google/google.go b/google/google.go index fec36de..292c315 100644 --- a/google/google.go +++ b/google/google.go @@ -61,7 +61,7 @@ func NewComputeEngineConfig(account string) *ComputeEngineConfig { // NewTransport creates an authorized transport. func (c *ComputeEngineConfig) NewTransport() oauth2.Transport { - return oauth2.NewAuthorizedTransport(c, nil) + return oauth2.NewAuthorizedTransport(http.DefaultTransport, c, nil) } // FetchToken retrieves a new access token via metadata server. @@ -76,7 +76,7 @@ func (c *ComputeEngineConfig) FetchToken(existing *oauth2.Token) (token *oauth2. return } req.Header.Add("X-Google-Metadata-Request", "True") - resp, err := (&http.Client{Transport: oauth2.DefaultTransport}).Do(req) + resp, err := (&http.Client{}).Do(req) if err != nil { return } diff --git a/jwt.go b/jwt.go index 5bda21b..d35be52 100644 --- a/jwt.go +++ b/jwt.go @@ -61,13 +61,13 @@ type JWTConfig struct { // NewTransport creates a transport that is authorize with the // parent JWT configuration. func (c *JWTConfig) NewTransport() Transport { - return NewAuthorizedTransport(c, &Token{}) + return NewAuthorizedTransport(http.DefaultTransport, c, &Token{}) } // NewTransportWithUser creates a transport that is authorized by // the client and impersonates the specified user. func (c *JWTConfig) NewTransportWithUser(user string) Transport { - return NewAuthorizedTransport(c, &Token{Subject: user}) + return NewAuthorizedTransport(http.DefaultTransport, c, &Token{Subject: user}) } // fetchToken retrieves a new access token and updates the existing token @@ -100,7 +100,7 @@ func (c *JWTConfig) FetchToken(existing *Token) (token *Token, err error) { v.Set("assertion", payload) // Make a request with assertion to get a new token. - client := http.Client{Transport: DefaultTransport} + client := http.Client{} resp, err := client.PostForm(c.aud, v) if err != nil { return nil, err diff --git a/oauth2.go b/oauth2.go index 4642627..2d3d778 100644 --- a/oauth2.go +++ b/oauth2.go @@ -18,10 +18,6 @@ import ( "time" ) -// The default transport implementation to be used while -// making the authorized requests. -var DefaultTransport = http.DefaultTransport - type tokenRespBody struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -138,7 +134,7 @@ func (c *Config) AuthCodeURL(state string) (authURL string, err error) { // t.SetToken(validToken) // func (c *Config) NewTransport() Transport { - return NewAuthorizedTransport(c, nil) + return NewAuthorizedTransport(http.DefaultTransport, c, nil) } // NewTransportWithCode exchanges the OAuth 2.0 exchange code with @@ -150,7 +146,7 @@ func (c *Config) NewTransportWithCode(exchangeCode string) (Transport, error) { if err != nil { return nil, err } - return NewAuthorizedTransport(c, token), nil + return NewAuthorizedTransport(http.DefaultTransport, c, token), nil } // FetchToken retrieves a new access token and updates the existing token @@ -210,7 +206,7 @@ func (c *Config) exchange(exchangeCode string) (*Token, error) { func (c *Config) updateToken(tok *Token, v url.Values) error { v.Set("client_id", c.opts.ClientID) v.Set("client_secret", c.opts.ClientSecret) - r, err := (&http.Client{Transport: DefaultTransport}).PostForm(c.tokenURL, v) + r, err := (&http.Client{}).PostForm(c.tokenURL, v) if err != nil { return err } diff --git a/oauth2_test.go b/oauth2_test.go index e734a83..b09a813 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -35,8 +35,6 @@ func newTestConf() *Config { } func TestAuthCodeURL(t *testing.T) { - DefaultTransport = http.DefaultTransport - conf := newTestConf() url, err := conf.AuthCodeURL("foo") if err != nil { @@ -48,8 +46,13 @@ func TestAuthCodeURL(t *testing.T) { } func TestExchangePayload(t *testing.T) { + oldDefaultTransport := http.DefaultTransport + defer func() { + http.DefaultTransport = oldDefaultTransport + }() + conf := newTestConf() - DefaultTransport = &mockTransport{ + http.DefaultTransport = &mockTransport{ rt: func(req *http.Request) (resp *http.Response, err error) { headerContentType := req.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { @@ -66,8 +69,13 @@ func TestExchangePayload(t *testing.T) { } func TestExchangingTransport(t *testing.T) { + oldDefaultTransport := http.DefaultTransport + defer func() { + http.DefaultTransport = oldDefaultTransport + }() + conf := newTestConf() - DefaultTransport = &mockTransport{ + http.DefaultTransport = &mockTransport{ rt: func(req *http.Request) (resp *http.Response, err error) { if req.URL.RequestURI() != "token-url" { t.Fatalf("NewTransportWithCode should have exchanged the code, but it didn't.") @@ -79,8 +87,6 @@ func TestExchangingTransport(t *testing.T) { } func TestFetchWithNoRedirect(t *testing.T) { - DefaultTransport = http.DefaultTransport - fetcher := newTestConf() _, err := fetcher.FetchToken(&Token{}) if err == nil { diff --git a/transport.go b/transport.go index 39726b0..cf72dab 100644 --- a/transport.go +++ b/transport.go @@ -76,8 +76,9 @@ type Transport interface { } type authorizedTransport struct { - fetcher TokenFetcher - token *Token + fetcher TokenFetcher + token *Token + origTransport http.RoundTripper // Mutex to protect token during auto refreshments. mu sync.RWMutex @@ -86,8 +87,8 @@ type authorizedTransport struct { // NewAuthorizedTransport creates a transport that uses the provided // token fetcher to retrieve new tokens if there is no access token // provided or it is expired. -func NewAuthorizedTransport(fetcher TokenFetcher, token *Token) Transport { - return &authorizedTransport{fetcher: fetcher, token: token} +func NewAuthorizedTransport(origTransport http.RoundTripper, fetcher TokenFetcher, token *Token) Transport { + return &authorizedTransport{origTransport: origTransport, fetcher: fetcher, token: token} } // RoundTrip authorizes the request with the existing token. @@ -118,7 +119,7 @@ func (t *authorizedTransport) RoundTrip(req *http.Request) (resp *http.Response, req.Header.Set("Authorization", typ+" "+token.AccessToken) // Make the HTTP request. - return DefaultTransport.RoundTrip(req) + return t.origTransport.RoundTrip(req) } // Token returns the existing token that authorizes the Transport. diff --git a/transport_test.go b/transport_test.go index 847266a..e5bfe6f 100644 --- a/transport_test.go +++ b/transport_test.go @@ -14,7 +14,7 @@ func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) { } func TestInitialTokenRead(t *testing.T) { - tr := NewAuthorizedTransport(nil, &Token{AccessToken: "abc"}) + tr := NewAuthorizedTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the initial token") @@ -31,7 +31,7 @@ func TestTokenFetch(t *testing.T) { AccessToken: "abc", }, } - tr := NewAuthorizedTransport(fetcher, nil) + tr := NewAuthorizedTransport(http.DefaultTransport, fetcher, nil) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the fetched token")