oauth2: use a JSON struct types instead of empty interface maps

Change-Id: Ifd66ea35c15dbd14acca0c945b533ec755de12e4
Reviewed-on: https://go-review.googlesource.com/1872
Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
Brad Fitzpatrick 2014-12-19 16:20:30 +11:00 committed by Brad Fitzpatrick
parent 9abe144dd5
commit f5b40b26f1
4 changed files with 68 additions and 48 deletions

32
jwt.go
View File

@ -123,25 +123,33 @@ func (js jwtSource) Token() (*Token, error) {
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
} }
b := make(map[string]interface{}) // tokenRes is the JSON response body.
if err := json.Unmarshal(body, &b); err != nil { var tokenRes struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
}
if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
token := &Token{} token := &Token{
token.AccessToken, _ = b["access_token"].(string) AccessToken: tokenRes.AccessToken,
token.TokenType, _ = b["token_type"].(string) TokenType: tokenRes.TokenType,
token.raw = b raw: make(map[string]interface{}),
if e, ok := b["expires_in"].(float64); ok {
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
} }
if idtoken, ok := b["id_token"].(string); ok { json.Unmarshal(body, &token.raw) // no error checks for optional fields
if secs := tokenRes.ExpiresIn; secs > 0 {
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}
if v := tokenRes.IDToken; v != "" {
// decode returned id token to get expiry // decode returned id token to get expiry
claimSet, err := jws.Decode(idtoken) claimSet, err := jws.Decode(v)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
} }
token.Expiry = time.Unix(claimSet.Exp, 0) token.Expiry = time.Unix(claimSet.Exp, 0)
return token, nil
} }
return token, nil return token, nil
} }

View File

@ -117,10 +117,10 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
TokenURL: ts.URL, TokenURL: ts.URL,
} }
tok, err := conf.TokenSource(NoContext, nil).Token() tok, err := conf.TokenSource(NoContext, nil).Token()
if err != nil { if err == nil {
t.Fatal(err) t.Error("got a token; expected error")
}
if tok.AccessToken != "" { if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken) t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
} }
}
} }

View File

@ -300,8 +300,7 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
} }
token := &Token{} var token *Token
expires := 0
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content { switch content {
case "application/x-www-form-urlencoded", "text/plain": case "application/x-www-form-urlencoded", "text/plain":
@ -309,10 +308,12 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
token.AccessToken = vals.Get("access_token") token = &Token{
token.TokenType = vals.Get("token_type") AccessToken: vals.Get("access_token"),
token.RefreshToken = vals.Get("refresh_token") TokenType: vals.Get("token_type"),
token.raw = vals RefreshToken: vals.Get("refresh_token"),
raw: vals,
}
e := vals.Get("expires_in") e := vals.Get("expires_in")
if e == "" { if e == "" {
// TODO(jbd): Facebook's OAuth2 implementation is broken and // TODO(jbd): Facebook's OAuth2 implementation is broken and
@ -320,38 +321,52 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
// when Facebook fixes their implementation. // when Facebook fixes their implementation.
e = vals.Get("expires") e = vals.Get("expires")
} }
expires, _ = strconv.Atoi(e) expires, _ := strconv.Atoi(e)
if expires != 0 {
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
}
default: default:
b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type var tj tokenJSON
if err = json.Unmarshal(body, &b); err != nil { if err = json.Unmarshal(body, &tj); err != nil {
return nil, err return nil, err
} }
token.AccessToken, _ = b["access_token"].(string) token = &Token{
token.TokenType, _ = b["token_type"].(string) AccessToken: tj.AccessToken,
token.RefreshToken, _ = b["refresh_token"].(string) TokenType: tj.TokenType,
token.raw = b RefreshToken: tj.RefreshToken,
e, ok := b["expires_in"].(float64) Expiry: tj.expiry(),
if !ok { raw: make(map[string]interface{}),
// TODO(jbd): Facebook's OAuth2 implementation is broken and
// returns expires_in field in expires. Remove the fallback to expires,
// when Facebook fixes their implementation.
e, _ = b["expires"].(float64)
} }
expires = int(e) json.Unmarshal(body, &token.raw) // no error checks for optional fields
} }
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request. // if this was a token refreshing request.
if token.RefreshToken == "" { if token.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token") token.RefreshToken = v.Get("refresh_token")
} }
if expires == 0 {
token.Expiry = time.Time{}
} else {
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
}
return token, nil return token, nil
} }
// tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token in JSON form.
type tokenJSON struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int32 `json:"expires_in"`
Expires int32 `json:"expires"` // broken Facebook spelling of expires_in
}
func (e *tokenJSON) expiry() (t time.Time) {
if v := e.ExpiresIn; v != 0 {
return time.Now().Add(time.Duration(v) * time.Second)
}
if v := e.Expires; v != 0 {
return time.Now().Add(time.Duration(v) * time.Second)
}
return
}
func condVal(v string) []string { func condVal(v string) []string {
if v == "" { if v == "" {
return nil return nil

View File

@ -181,12 +181,9 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
conf := newConf(ts.URL) conf := newConf(ts.URL)
tok, err := conf.Exchange(NoContext, "exchange-code") _, err := conf.Exchange(NoContext, "exchange-code")
if err != nil { if err == nil {
t.Error(err) t.Error("expected error from invalid access_token type")
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
} }
} }