TokenFetcher should not mutate existing tokens.

This commit is contained in:
Burcu Dogan 2014-08-13 16:47:57 -07:00
parent b2a1756f98
commit b314823c0b
1 changed files with 23 additions and 26 deletions

View File

@ -33,7 +33,7 @@ type tokenRespBody struct {
type TokenFetcher interface { type TokenFetcher interface {
// FetchToken retrieves a new access token for the provider. // FetchToken retrieves a new access token for the provider.
// If the implementation doesn't know how to retrieve a new token, // If the implementation doesn't know how to retrieve a new token,
// it returns an error. // it returns an error. Existing token could be nil.
FetchToken(existing *Token) (*Token, error) FetchToken(existing *Token) (*Token, error)
} }
@ -173,12 +173,11 @@ func (c *Config) FetchToken(existing *Token) (*Token, error) {
if existing == nil || existing.RefreshToken == "" { if existing == nil || existing.RefreshToken == "" {
return nil, errors.New("cannot fetch access token without refresh token.") return nil, errors.New("cannot fetch access token without refresh token.")
} }
err := c.updateToken(existing, url.Values{ return c.retrieveToken(url.Values{
"grant_type": {"refresh_token"}, "grant_type": {"refresh_token"},
"client_secret": {c.opts.ClientSecret}, "client_secret": {c.opts.ClientSecret},
"refresh_token": {existing.RefreshToken}, "refresh_token": {existing.RefreshToken},
}) })
return existing, err
} }
// Checks if all required configuration fields have non-zero values. // Checks if all required configuration fields have non-zero values.
@ -195,7 +194,6 @@ func (c *Config) validate() error {
// Exchange exchanges the exchange code with the OAuth 2.0 provider // Exchange exchanges the exchange code with the OAuth 2.0 provider
// to retrieve a new access token. // to retrieve a new access token.
func (c *Config) Exchange(exchangeCode string) (*Token, error) { func (c *Config) Exchange(exchangeCode string) (*Token, error) {
token := &Token{}
vals := url.Values{ vals := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"client_secret": {c.opts.ClientSecret}, "client_secret": {c.opts.ClientSecret},
@ -207,15 +205,10 @@ func (c *Config) Exchange(exchangeCode string) (*Token, error) {
if c.opts.RedirectURL != "" { if c.opts.RedirectURL != "" {
vals.Set("redirect_uri", c.opts.RedirectURL) vals.Set("redirect_uri", c.opts.RedirectURL)
} }
err := c.updateToken(token, vals) return c.retrieveToken(vals)
if err != nil {
return nil, err
}
return token, nil
} }
// updateToken mutates both tok and v. func (c *Config) retrieveToken(v url.Values) (*Token, error) {
func (c *Config) updateToken(tok *Token, v url.Values) error {
v.Set("client_id", c.opts.ClientID) v.Set("client_id", c.opts.ClientID)
// Note that we're not setting v's client_secret to t.ClientSecret, due // Note that we're not setting v's client_secret to t.ClientSecret, due
// to https://code.google.com/p/goauth2/issues/detail?id=31 // to https://code.google.com/p/goauth2/issues/detail?id=31
@ -225,12 +218,12 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
// so that's all we use. // so that's all we use.
r, err := http.DefaultClient.PostForm(c.tokenURL.String(), v) r, err := http.DefaultClient.PostForm(c.tokenURL.String(), v)
if err != nil { if err != nil {
return err return nil, err
} }
defer r.Body.Close() defer r.Body.Close()
if r.StatusCode != 200 { if r.StatusCode != 200 {
// TODO(jbd): Add status code or error message // TODO(jbd): Add status code or error message
return errors.New("Error during updating token.") return nil, errors.New("error during updating token")
} }
resp := &tokenRespBody{} resp := &tokenRespBody{}
@ -239,11 +232,11 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
case "application/x-www-form-urlencoded", "text/plain": case "application/x-www-form-urlencoded", "text/plain":
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
return err return nil, err
} }
vals, err := url.ParseQuery(string(body)) vals, err := url.ParseQuery(string(body))
if err != nil { if err != nil {
return err return nil, err
} }
resp.AccessToken = vals.Get("access_token") resp.AccessToken = vals.Get("access_token")
resp.TokenType = vals.Get("token_type") resp.TokenType = vals.Get("token_type")
@ -252,25 +245,29 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
resp.IdToken = vals.Get("id_token") resp.IdToken = vals.Get("id_token")
default: default:
if err = json.NewDecoder(r.Body).Decode(&resp); err != nil { if err = json.NewDecoder(r.Body).Decode(&resp); err != nil {
return err return nil, err
} }
} }
tok.AccessToken = resp.AccessToken token := &Token{
tok.TokenType = resp.TokenType AccessToken: resp.AccessToken,
TokenType: resp.TokenType,
RefreshToken: resp.RefreshToken,
}
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
if resp.RefreshToken != "" { // if this was a token refreshing request.
tok.RefreshToken = resp.RefreshToken if resp.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token")
} }
if resp.ExpiresIn == 0 { if resp.ExpiresIn == 0 {
tok.Expiry = time.Time{} token.Expiry = time.Time{}
} else { } else {
tok.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
} }
if resp.IdToken != "" { if resp.IdToken != "" {
if tok.Extra == nil { if token.Extra == nil {
tok.Extra = make(map[string]string) token.Extra = make(map[string]string)
} }
tok.Extra["id_token"] = resp.IdToken token.Extra["id_token"] = resp.IdToken
} }
return nil return token, nil
} }