forked from Mirrors/oauth2
oauth2: use a JSON struct types instead of empty interface maps
Change-Id: Ifd66ea35c15dbd14acca0c945b533ec755de12e4 Reviewed-on: https://go-review.googlesource.com/1872 Reviewed-by: Burcu Dogan <jbd@google.com>
This commit is contained in:
parent
9abe144dd5
commit
f5b40b26f1
32
jwt.go
32
jwt.go
|
@ -123,25 +123,33 @@ func (js jwtSource) Token() (*Token, error) {
|
||||||
if c := resp.StatusCode; c < 200 || c > 299 {
|
if c := resp.StatusCode; c < 200 || c > 299 {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
||||||
}
|
}
|
||||||
b := make(map[string]interface{})
|
// tokenRes is the JSON response body.
|
||||||
if err := json.Unmarshal(body, &b); err != nil {
|
var tokenRes struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
token := &Token{}
|
token := &Token{
|
||||||
token.AccessToken, _ = b["access_token"].(string)
|
AccessToken: tokenRes.AccessToken,
|
||||||
token.TokenType, _ = b["token_type"].(string)
|
TokenType: tokenRes.TokenType,
|
||||||
token.raw = b
|
raw: make(map[string]interface{}),
|
||||||
if e, ok := b["expires_in"].(float64); ok {
|
|
||||||
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
|
||||||
}
|
}
|
||||||
if idtoken, ok := b["id_token"].(string); ok {
|
json.Unmarshal(body, &token.raw) // no error checks for optional fields
|
||||||
|
|
||||||
|
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||||
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||||
|
}
|
||||||
|
if v := tokenRes.IDToken; v != "" {
|
||||||
// decode returned id token to get expiry
|
// decode returned id token to get expiry
|
||||||
claimSet, err := jws.Decode(idtoken)
|
claimSet, err := jws.Decode(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err)
|
||||||
}
|
}
|
||||||
token.Expiry = time.Unix(claimSet.Exp, 0)
|
token.Expiry = time.Unix(claimSet.Exp, 0)
|
||||||
return token, nil
|
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,10 +117,10 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
|
||||||
TokenURL: ts.URL,
|
TokenURL: ts.URL,
|
||||||
}
|
}
|
||||||
tok, err := conf.TokenSource(NoContext, nil).Token()
|
tok, err := conf.TokenSource(NoContext, nil).Token()
|
||||||
if err != nil {
|
if err == nil {
|
||||||
t.Fatal(err)
|
t.Error("got a token; expected error")
|
||||||
}
|
|
||||||
if tok.AccessToken != "" {
|
if tok.AccessToken != "" {
|
||||||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
65
oauth2.go
65
oauth2.go
|
@ -300,8 +300,7 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
token := &Token{}
|
var token *Token
|
||||||
expires := 0
|
|
||||||
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||||
switch content {
|
switch content {
|
||||||
case "application/x-www-form-urlencoded", "text/plain":
|
case "application/x-www-form-urlencoded", "text/plain":
|
||||||
|
@ -309,10 +308,12 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token.AccessToken = vals.Get("access_token")
|
token = &Token{
|
||||||
token.TokenType = vals.Get("token_type")
|
AccessToken: vals.Get("access_token"),
|
||||||
token.RefreshToken = vals.Get("refresh_token")
|
TokenType: vals.Get("token_type"),
|
||||||
token.raw = vals
|
RefreshToken: vals.Get("refresh_token"),
|
||||||
|
raw: vals,
|
||||||
|
}
|
||||||
e := vals.Get("expires_in")
|
e := vals.Get("expires_in")
|
||||||
if e == "" {
|
if e == "" {
|
||||||
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
||||||
|
@ -320,38 +321,52 @@ func retrieveToken(ctx Context, c *Config, v url.Values) (*Token, error) {
|
||||||
// when Facebook fixes their implementation.
|
// when Facebook fixes their implementation.
|
||||||
e = vals.Get("expires")
|
e = vals.Get("expires")
|
||||||
}
|
}
|
||||||
expires, _ = strconv.Atoi(e)
|
expires, _ := strconv.Atoi(e)
|
||||||
|
if expires != 0 {
|
||||||
|
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
b := make(map[string]interface{}) // TODO: don't use a map[string]interface{}; make a type
|
var tj tokenJSON
|
||||||
if err = json.Unmarshal(body, &b); err != nil {
|
if err = json.Unmarshal(body, &tj); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token.AccessToken, _ = b["access_token"].(string)
|
token = &Token{
|
||||||
token.TokenType, _ = b["token_type"].(string)
|
AccessToken: tj.AccessToken,
|
||||||
token.RefreshToken, _ = b["refresh_token"].(string)
|
TokenType: tj.TokenType,
|
||||||
token.raw = b
|
RefreshToken: tj.RefreshToken,
|
||||||
e, ok := b["expires_in"].(float64)
|
Expiry: tj.expiry(),
|
||||||
if !ok {
|
raw: make(map[string]interface{}),
|
||||||
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
|
||||||
// returns expires_in field in expires. Remove the fallback to expires,
|
|
||||||
// when Facebook fixes their implementation.
|
|
||||||
e, _ = b["expires"].(float64)
|
|
||||||
}
|
}
|
||||||
expires = int(e)
|
json.Unmarshal(body, &token.raw) // no error checks for optional fields
|
||||||
}
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
// if this was a token refreshing request.
|
// if this was a token refreshing request.
|
||||||
if token.RefreshToken == "" {
|
if token.RefreshToken == "" {
|
||||||
token.RefreshToken = v.Get("refresh_token")
|
token.RefreshToken = v.Get("refresh_token")
|
||||||
}
|
}
|
||||||
if expires == 0 {
|
|
||||||
token.Expiry = time.Time{}
|
|
||||||
} else {
|
|
||||||
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
|
|
||||||
}
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tokenJSON is the struct representing the HTTP response from OAuth2
|
||||||
|
// providers returning a token in JSON form.
|
||||||
|
type tokenJSON struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int32 `json:"expires_in"`
|
||||||
|
Expires int32 `json:"expires"` // broken Facebook spelling of expires_in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *tokenJSON) expiry() (t time.Time) {
|
||||||
|
if v := e.ExpiresIn; v != 0 {
|
||||||
|
return time.Now().Add(time.Duration(v) * time.Second)
|
||||||
|
}
|
||||||
|
if v := e.Expires; v != 0 {
|
||||||
|
return time.Now().Add(time.Duration(v) * time.Second)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func condVal(v string) []string {
|
func condVal(v string) []string {
|
||||||
if v == "" {
|
if v == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -181,12 +181,9 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
conf := newConf(ts.URL)
|
conf := newConf(ts.URL)
|
||||||
tok, err := conf.Exchange(NoContext, "exchange-code")
|
_, err := conf.Exchange(NoContext, "exchange-code")
|
||||||
if err != nil {
|
if err == nil {
|
||||||
t.Error(err)
|
t.Error("expected error from invalid access_token type")
|
||||||
}
|
|
||||||
if tok.AccessToken != "" {
|
|
||||||
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue