diff --git a/google/appengine.go b/google/appengine.go index 09f327c..433bcf8 100644 --- a/google/appengine.go +++ b/google/appengine.go @@ -27,16 +27,16 @@ func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineCo // NewTransport returns a transport that authorizes // the requests with the application's service account. -func (c *AppEngineConfig) NewTransport() oauth2.Transport { +func (c *AppEngineConfig) NewTransport() *oauth2.Transport { if c.Transport != nil { - return oauth2.NewAuthorizedTransport(c.Transport, c, nil) + return oauth2.NewTransport(c.Transport, c, nil) } transport := &urlfetch.Transport{ Context: c.context, Deadline: 0, AllowInvalidServerCertificate: false, } - return oauth2.NewAuthorizedTransport(transport, c, nil) + return oauth2.NewTransport(transport, c, nil) } // FetchToken fetches a new access token for the provided scopes. diff --git a/google/appenginevm.go b/google/appenginevm.go index d2e0626..848185f 100644 --- a/google/appenginevm.go +++ b/google/appenginevm.go @@ -24,8 +24,8 @@ func NewAppEngineConfig(context appengine.Context, scopes []string) *AppEngineCo // NewTransport returns a transport that authorizes // the requests with the application's service account. -func (c *AppEngineConfig) NewTransport() oauth2.Transport { - return oauth2.NewAuthorizedTransport(http.DefaultTransport, c, nil) +func (c *AppEngineConfig) NewTransport() *oauth2.Transport { + return oauth2.NewTransport(http.DefaultTransport, c, nil) } // FetchToken fetches a new access token for the provided scopes. diff --git a/google/google.go b/google/google.go index f83d246..db95b92 100644 --- a/google/google.go +++ b/google/google.go @@ -60,8 +60,8 @@ func NewComputeEngineConfig(account string) *ComputeEngineConfig { } // NewTransport creates an authorized transport. -func (c *ComputeEngineConfig) NewTransport() oauth2.Transport { - return oauth2.NewAuthorizedTransport(http.DefaultTransport, c, nil) +func (c *ComputeEngineConfig) NewTransport() *oauth2.Transport { + return oauth2.NewTransport(http.DefaultTransport, c, nil) } // FetchToken retrieves a new access token via metadata server. diff --git a/jwt.go b/jwt.go index 82f41c7..f69f7cc 100644 --- a/jwt.go +++ b/jwt.go @@ -78,14 +78,14 @@ type JWTConfig struct { // NewTransport creates a transport that is authorize with the // parent JWT configuration. -func (c *JWTConfig) NewTransport() Transport { - return NewAuthorizedTransport(http.DefaultTransport, c, &Token{}) +func (c *JWTConfig) NewTransport() *Transport { + return NewTransport(http.DefaultTransport, c, &Token{}) } // NewTransportWithUser creates a transport that is authorized by // the client and impersonates the specified user. -func (c *JWTConfig) NewTransportWithUser(user string) Transport { - return NewAuthorizedTransport(http.DefaultTransport, c, &Token{Subject: user}) +func (c *JWTConfig) NewTransportWithUser(user string) *Transport { + return NewTransport(http.DefaultTransport, c, &Token{Subject: user}) } // fetchToken retrieves a new access token and updates the existing token diff --git a/oauth2.go b/oauth2.go index 31cae9a..7e179ed 100644 --- a/oauth2.go +++ b/oauth2.go @@ -33,7 +33,7 @@ type tokenRespBody struct { type TokenFetcher interface { // FetchToken retrieves a new access token for the provider. // If the implementation doesn't know how to retrieve a new token, - // it returns an error. Existing token could be nil. + // it returns an error. The existing token may be nil. FetchToken(existing *Token) (*Token, error) } @@ -150,20 +150,20 @@ func (c *Config) AuthCodeURL(state string) (authURL string) { // t, _ := c.NewTransport() // t.SetToken(validToken) // -func (c *Config) NewTransport() Transport { - return NewAuthorizedTransport(http.DefaultTransport, c, nil) +func (c *Config) NewTransport() *Transport { + return NewTransport(http.DefaultTransport, c, nil) } // NewTransportWithCode exchanges the OAuth 2.0 exchange code with // the provider to fetch a new access token (and refresh token). Once // it succesffully retrieves a new token, creates a new transport // authorized with it. -func (c *Config) NewTransportWithCode(exchangeCode string) (Transport, error) { +func (c *Config) NewTransportWithCode(exchangeCode string) (*Transport, error) { token, err := c.Exchange(exchangeCode) if err != nil { return nil, err } - return NewAuthorizedTransport(http.DefaultTransport, c, token), nil + return NewTransport(http.DefaultTransport, c, token), nil } // FetchToken retrieves a new access token and updates the existing token diff --git a/transport.go b/transport.go index 276d2fe..e1367e7 100644 --- a/transport.go +++ b/transport.go @@ -52,52 +52,26 @@ func (t *Token) Expired() bool { return t.Expiry.Before(time.Now()) } -// Transport represents an authorized transport. -// Provides currently in-use user token and allows to set a token to -// be used. If token expires, it tries to fetch a new token, -// if possible. Token fetching is thread-safe. If two or more -// concurrent requests are being made with the same expired token, -// one of the requests will wait for the other to refresh -// the existing token. -type Transport interface { - // Authenticates the request with the existing token. If token is - // expired, tries to refresh/fetch a new token. - // Makes the request by delegating it to the default transport. - RoundTrip(*http.Request) (*http.Response, error) - - // Returns the token authenticates the transport. - // This operation is thread-safe. - Token() *Token - - // Sets a new token to authenticate the transport. - // This operation is thread-safe. - SetToken(token *Token) - - // Refreshes the token if refresh is possible (such as in the - // presense of a refresh token). Returns an error if refresh is - // not possible. Refresh is thread-safe. - RefreshToken() error -} - -type authorizedTransport struct { +// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests. +type Transport struct { fetcher TokenFetcher - token *Token origTransport http.RoundTripper - // Mutex to protect token during auto refreshments. - mu sync.RWMutex + mu sync.RWMutex + token *Token } -// NewAuthorizedTransport creates a transport that uses the provided -// token fetcher to retrieve new tokens if there is no access token -// provided or it is expired. -func NewAuthorizedTransport(origTransport http.RoundTripper, fetcher TokenFetcher, token *Token) Transport { - return &authorizedTransport{origTransport: origTransport, fetcher: fetcher, token: token} +// NewTransport creates a new Transport that uses the provided +// token fetcher as token retrieving strategy. It authenticates +// the requests and delegates origTransport to make the actual requests. +func NewTransport(origTransport http.RoundTripper, fetcher TokenFetcher, token *Token) *Transport { + return &Transport{origTransport: origTransport, fetcher: fetcher, token: token} } -// RoundTrip authorizes the request with the existing token. -// If token is expired, tries to refresh/fetch a new token. -func (t *authorizedTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { +// RoundTrip authorizes and authenticates the request with an +// access token. If no token exists or token is expired, +// tries to refresh/fetch a new token. +func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { token := t.Token() if token == nil || token.Expired() { @@ -126,14 +100,15 @@ func (t *authorizedTransport) RoundTrip(req *http.Request) (resp *http.Response, return t.origTransport.RoundTrip(req) } -// Token returns the existing token that authorizes the Transport. -func (t *authorizedTransport) Token() *Token { +// Token returns the token that authorizes and +// authenticates the transport. +func (t *Transport) Token() *Token { t.mu.RLock() defer t.mu.RUnlock() if t.token == nil { return nil } - token := &Token{ + return &Token{ AccessToken: t.token.AccessToken, TokenType: t.token.TokenType, RefreshToken: t.token.RefreshToken, @@ -141,21 +116,20 @@ func (t *authorizedTransport) Token() *Token { Extra: t.token.Extra, Subject: t.token.Subject, } - return token } // SetToken sets a token to the transport in a thread-safe way. -func (t *authorizedTransport) SetToken(token *Token) { +func (t *Transport) SetToken(v *Token) { t.mu.Lock() defer t.mu.Unlock() - t.token = token + t.token = v } // RefreshToken retrieves a new token, if a refreshing/fetching // method is known and required credentials are presented // (such as a refresh token). -func (t *authorizedTransport) RefreshToken() error { +func (t *Transport) RefreshToken() error { t.mu.Lock() defer t.mu.Unlock() diff --git a/transport_test.go b/transport_test.go index e5bfe6f..7128d45 100644 --- a/transport_test.go +++ b/transport_test.go @@ -14,7 +14,7 @@ func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) { } func TestInitialTokenRead(t *testing.T) { - tr := NewAuthorizedTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) + tr := NewTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the initial token") @@ -31,7 +31,7 @@ func TestTokenFetch(t *testing.T) { AccessToken: "abc", }, } - tr := NewAuthorizedTransport(http.DefaultTransport, fetcher, nil) + tr := NewTransport(http.DefaultTransport, fetcher, nil) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the fetched token")