diff --git a/oauth2.go b/oauth2.go index 6a92d70..396f11b 100644 --- a/oauth2.go +++ b/oauth2.go @@ -229,35 +229,46 @@ func (c *Config) Client(ctx Context, t *Token) *http.Client { // // Most users will use Config.Client instead. func (c *Config) TokenSource(ctx Context, t *Token) TokenSource { - nwn := &reuseTokenSource{t: t} - nwn.new = tokenRefresher{ - ctx: ctx, - conf: c, - oldToken: &nwn.t, + tkr := &tokenRefresher{ + ctx: ctx, + conf: c, + } + if t != nil { + tkr.refreshToken = t.RefreshToken + } + return &reuseTokenSource{ + t: t, + new: tkr, } - return nwn } // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" // HTTP requests to renew a token using a RefreshToken. type tokenRefresher struct { - ctx Context // used to get HTTP requests - conf *Config - oldToken **Token // pointer to old *Token w/ RefreshToken + ctx Context // used to get HTTP requests + conf *Config + refreshToken string } -func (tf tokenRefresher) Token() (*Token, error) { - t := *tf.oldToken - if t == nil { - return nil, errors.New("oauth2: attempted use of nil Token") - } - if t.RefreshToken == "" { +func (tf *tokenRefresher) Token() (*Token, error) { + if tf.refreshToken == "" { return nil, errors.New("oauth2: token expired and refresh token is not set") } - return retrieveToken(tf.ctx, tf.conf, url.Values{ + + tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ "grant_type": {"refresh_token"}, - "refresh_token": {t.RefreshToken}, + "refresh_token": {tf.refreshToken}, }) + + if err != nil { + return nil, err + } + if tk.RefreshToken != tf.refreshToken { + // possible race condition avoided because tokenRefresher + // should be protected by reuseTokenSource.mu + tf.refreshToken = tk.RefreshToken + } + return tk, err } // reuseTokenSource is a TokenSource that holds a single token in memory diff --git a/oauth2_test.go b/oauth2_test.go index f7b5ca6..edbe03a 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -305,3 +305,26 @@ func TestFetchWithNoRefreshToken(t *testing.T) { t.Errorf("Fetch should return an error if no refresh token is set") } } + +func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"ACCESS TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW REFRESH TOKEN"}`)) + return + })) + defer ts.Close() + conf := newConf(ts.URL) + tkr := tokenRefresher{ + conf: conf, + ctx: NoContext, + refreshToken: "OLD REFRESH TOKEN", + } + tk, err := tkr.Token() + if err != nil { + t.Errorf("Unexpected refreshToken error returned: %v", err) + return + } + if tk.RefreshToken != tkr.refreshToken { + t.Errorf("tokenRefresher.refresh_token = %s; want %s", tkr.refreshToken, tk.RefreshToken) + } +}