forked from Mirrors/oauth2
TokenFetcher should not mutate existing tokens.
This commit is contained in:
parent
b2a1756f98
commit
b314823c0b
49
oauth2.go
49
oauth2.go
|
@ -33,7 +33,7 @@ type tokenRespBody struct {
|
||||||
type TokenFetcher interface {
|
type TokenFetcher interface {
|
||||||
// FetchToken retrieves a new access token for the provider.
|
// FetchToken retrieves a new access token for the provider.
|
||||||
// If the implementation doesn't know how to retrieve a new token,
|
// If the implementation doesn't know how to retrieve a new token,
|
||||||
// it returns an error.
|
// it returns an error. Existing token could be nil.
|
||||||
FetchToken(existing *Token) (*Token, error)
|
FetchToken(existing *Token) (*Token, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,12 +173,11 @@ func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
||||||
if existing == nil || existing.RefreshToken == "" {
|
if existing == nil || existing.RefreshToken == "" {
|
||||||
return nil, errors.New("cannot fetch access token without refresh token.")
|
return nil, errors.New("cannot fetch access token without refresh token.")
|
||||||
}
|
}
|
||||||
err := c.updateToken(existing, url.Values{
|
return c.retrieveToken(url.Values{
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"client_secret": {c.opts.ClientSecret},
|
"client_secret": {c.opts.ClientSecret},
|
||||||
"refresh_token": {existing.RefreshToken},
|
"refresh_token": {existing.RefreshToken},
|
||||||
})
|
})
|
||||||
return existing, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks if all required configuration fields have non-zero values.
|
// Checks if all required configuration fields have non-zero values.
|
||||||
|
@ -195,7 +194,6 @@ func (c *Config) validate() error {
|
||||||
// Exchange exchanges the exchange code with the OAuth 2.0 provider
|
// Exchange exchanges the exchange code with the OAuth 2.0 provider
|
||||||
// to retrieve a new access token.
|
// to retrieve a new access token.
|
||||||
func (c *Config) Exchange(exchangeCode string) (*Token, error) {
|
func (c *Config) Exchange(exchangeCode string) (*Token, error) {
|
||||||
token := &Token{}
|
|
||||||
vals := url.Values{
|
vals := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_secret": {c.opts.ClientSecret},
|
"client_secret": {c.opts.ClientSecret},
|
||||||
|
@ -207,15 +205,10 @@ func (c *Config) Exchange(exchangeCode string) (*Token, error) {
|
||||||
if c.opts.RedirectURL != "" {
|
if c.opts.RedirectURL != "" {
|
||||||
vals.Set("redirect_uri", c.opts.RedirectURL)
|
vals.Set("redirect_uri", c.opts.RedirectURL)
|
||||||
}
|
}
|
||||||
err := c.updateToken(token, vals)
|
return c.retrieveToken(vals)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateToken mutates both tok and v.
|
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||||
func (c *Config) updateToken(tok *Token, v url.Values) error {
|
|
||||||
v.Set("client_id", c.opts.ClientID)
|
v.Set("client_id", c.opts.ClientID)
|
||||||
// Note that we're not setting v's client_secret to t.ClientSecret, due
|
// Note that we're not setting v's client_secret to t.ClientSecret, due
|
||||||
// to https://code.google.com/p/goauth2/issues/detail?id=31
|
// to https://code.google.com/p/goauth2/issues/detail?id=31
|
||||||
|
@ -225,12 +218,12 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
|
||||||
// so that's all we use.
|
// so that's all we use.
|
||||||
r, err := http.DefaultClient.PostForm(c.tokenURL.String(), v)
|
r, err := http.DefaultClient.PostForm(c.tokenURL.String(), v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
if r.StatusCode != 200 {
|
if r.StatusCode != 200 {
|
||||||
// TODO(jbd): Add status code or error message
|
// TODO(jbd): Add status code or error message
|
||||||
return errors.New("Error during updating token.")
|
return nil, errors.New("error during updating token")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := &tokenRespBody{}
|
resp := &tokenRespBody{}
|
||||||
|
@ -239,11 +232,11 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
|
||||||
case "application/x-www-form-urlencoded", "text/plain":
|
case "application/x-www-form-urlencoded", "text/plain":
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
vals, err := url.ParseQuery(string(body))
|
vals, err := url.ParseQuery(string(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp.AccessToken = vals.Get("access_token")
|
resp.AccessToken = vals.Get("access_token")
|
||||||
resp.TokenType = vals.Get("token_type")
|
resp.TokenType = vals.Get("token_type")
|
||||||
|
@ -252,25 +245,29 @@ func (c *Config) updateToken(tok *Token, v url.Values) error {
|
||||||
resp.IdToken = vals.Get("id_token")
|
resp.IdToken = vals.Get("id_token")
|
||||||
default:
|
default:
|
||||||
if err = json.NewDecoder(r.Body).Decode(&resp); err != nil {
|
if err = json.NewDecoder(r.Body).Decode(&resp); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tok.AccessToken = resp.AccessToken
|
token := &Token{
|
||||||
tok.TokenType = resp.TokenType
|
AccessToken: resp.AccessToken,
|
||||||
|
TokenType: resp.TokenType,
|
||||||
|
RefreshToken: resp.RefreshToken,
|
||||||
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
if resp.RefreshToken != "" {
|
// if this was a token refreshing request.
|
||||||
tok.RefreshToken = resp.RefreshToken
|
if resp.RefreshToken == "" {
|
||||||
|
token.RefreshToken = v.Get("refresh_token")
|
||||||
}
|
}
|
||||||
if resp.ExpiresIn == 0 {
|
if resp.ExpiresIn == 0 {
|
||||||
tok.Expiry = time.Time{}
|
token.Expiry = time.Time{}
|
||||||
} else {
|
} else {
|
||||||
tok.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
|
token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
|
||||||
}
|
}
|
||||||
if resp.IdToken != "" {
|
if resp.IdToken != "" {
|
||||||
if tok.Extra == nil {
|
if token.Extra == nil {
|
||||||
tok.Extra = make(map[string]string)
|
token.Extra = make(map[string]string)
|
||||||
}
|
}
|
||||||
tok.Extra["id_token"] = resp.IdToken
|
token.Extra["id_token"] = resp.IdToken
|
||||||
}
|
}
|
||||||
return nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue