From f5b40b26f1d35712ba5784a1b4e4adae57230a4d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 19 Dec 2014 16:20:30 +1100 Subject: [PATCH] oauth2: use a JSON struct types instead of empty interface maps Change-Id: Ifd66ea35c15dbd14acca0c945b533ec755de12e4 Reviewed-on: https://go-review.googlesource.com/1872 Reviewed-by: Burcu Dogan --- jwt.go | 32 +++++++++++++++---------- jwt_test.go | 10 ++++---- oauth2.go | 65 +++++++++++++++++++++++++++++++------------------- oauth2_test.go | 9 +++---- 4 files changed, 68 insertions(+), 48 deletions(-) diff --git a/jwt.go b/jwt.go index 7c02e3d..d7ec0d3 100644 --- a/jwt.go +++ b/jwt.go @@ -123,25 +123,33 @@ func (js jwtSource) Token() (*Token, error) { if c := resp.StatusCode; c < 200 || c > 299 { return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) } - b := make(map[string]interface{}) - if err := json.Unmarshal(body, &b); err != nil { + // tokenRes is the JSON response body. + var tokenRes struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + IDToken string `json:"id_token"` + ExpiresIn int64 `json:"expires_in"` // relative seconds from now + } + if err := json.Unmarshal(body, &tokenRes); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } - token := &Token{} - token.AccessToken, _ = b["access_token"].(string) - token.TokenType, _ = b["token_type"].(string) - token.raw = b - if e, ok := b["expires_in"].(float64); ok { - token.Expiry = time.Now().Add(time.Duration(e) * time.Second) + token := &Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + raw: make(map[string]interface{}), } - if idtoken, ok := b["id_token"].(string); ok { + json.Unmarshal(body, &token.raw) // no error checks for optional fields + + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + if v := tokenRes.IDToken; v != "" { // decode returned id token to get expiry - claimSet, err := jws.Decode(idtoken) + claimSet, err := jws.Decode(v) if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err) } token.Expiry = time.Unix(claimSet.Exp, 0) - return token, nil } return token, nil } diff --git a/jwt_test.go b/jwt_test.go index 3cc5671..8c2e62e 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -117,10 +117,10 @@ func TestJWTFetch_BadResponseType(t *testing.T) { TokenURL: ts.URL, } tok, err := conf.TokenSource(NoContext, nil).Token() - if err != nil { - t.Fatal(err) - } - if tok.AccessToken != "" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + if err == nil { + t.Error("got a token; expected error") + if tok.AccessToken != "" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + } } } diff --git a/oauth2.go b/oauth2.go index cdb836d..3644683 100644 --- a/oauth2.go +++ b/oauth2.go @@ -300,8 +300,7 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) } - token := &Token{} - expires := 0 + var token *Token content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch content { case "application/x-www-form-urlencoded", "text/plain": @@ -309,10 +308,12 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) { if err != nil { return nil, err } - token.AccessToken = vals.Get("access_token") - token.TokenType = vals.Get("token_type") - token.RefreshToken = vals.Get("refresh_token") - token.raw = vals + token = &Token{ + AccessToken: vals.Get("access_token"), + TokenType: vals.Get("token_type"), + RefreshToken: vals.Get("refresh_token"), + raw: vals, + } e := vals.Get("expires_in") if e == "" { // TODO(jbd): Facebook's OAuth2 implementation is broken and @@ -320,38 +321,52 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) { // when Facebook fixes their implementation. e = vals.Get("expires") } - expires, _ = strconv.Atoi(e) + expires, _ := strconv.Atoi(e) + if expires != 0 { + token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) + } default: - b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type - if err = json.Unmarshal(body, &b); err != nil { + var tj tokenJSON + if err = json.Unmarshal(body, &tj); err != nil { return nil, err } - token.AccessToken, _ = b["access_token"].(string) - token.TokenType, _ = b["token_type"].(string) - token.RefreshToken, _ = b["refresh_token"].(string) - token.raw = b - e, ok := b["expires_in"].(float64) - if !ok { - // TODO(jbd): Facebook's OAuth2 implementation is broken and - // returns expires_in field in expires. Remove the fallback to expires, - // when Facebook fixes their implementation. - e, _ = b["expires"].(float64) + token = &Token{ + AccessToken: tj.AccessToken, + TokenType: tj.TokenType, + RefreshToken: tj.RefreshToken, + Expiry: tj.expiry(), + raw: make(map[string]interface{}), } - expires = int(e) + json.Unmarshal(body, &token.raw) // no error checks for optional fields } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. if token.RefreshToken == "" { token.RefreshToken = v.Get("refresh_token") } - if expires == 0 { - token.Expiry = time.Time{} - } else { - token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) - } return token, nil } +// tokenJSON is the struct representing the HTTP response from OAuth2 +// providers returning a token in JSON form. +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int32 `json:"expires_in"` + Expires int32 `json:"expires"` // broken Facebook spelling of expires_in +} + +func (e *tokenJSON) expiry() (t time.Time) { + if v := e.ExpiresIn; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + if v := e.Expires; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + return +} + func condVal(v string) []string { if v == "" { return nil diff --git a/oauth2_test.go b/oauth2_test.go index 8159b86..c567c3a 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -181,12 +181,9 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { })) defer ts.Close() conf := newConf(ts.URL) - tok, err := conf.Exchange(NoContext, "exchange-code") - if err != nil { - t.Error(err) - } - if tok.AccessToken != "" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + _, err := conf.Exchange(NoContext, "exchange-code") + if err == nil { + t.Error("expected error from invalid access_token type") } }