diff --git a/jwt.go b/jwt.go index 78d9553..28161ca 100644 --- a/jwt.go +++ b/jwt.go @@ -135,29 +135,28 @@ func (c *JWTConfig) FetchToken(existing *Token) (*Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body) } - b := &tokenRespBody{} - if err := json.Unmarshal(body, b); err != nil { + b := make(map[string]interface{}) + if err := json.Unmarshal(body, &b); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } - - token := &Token{ - AccessToken: b.AccessToken, - TokenType: b.TokenType, - Subject: existing.Subject, + token := &Token{} + token.AccessToken, _ = b["access_token"].(string) + token.TokenType, _ = b["token_type"].(string) + token.Subject = existing.Subject + token.raw = b + if e, ok := b["expires_in"].(int); ok { + token.Expiry = time.Now().Add(time.Duration(e) * time.Second) } - - if b.IdToken != "" { + if idtoken, ok := b["id_token"].(string); ok { // decode returned id token to get expiry claimSet := &jws.ClaimSet{} - claimSet, err = jws.Decode(b.IdToken) + claimSet, err = jws.Decode(idtoken) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } token.Expiry = time.Unix(claimSet.Exp, 0) return token, nil } - - token.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second) return token, nil } diff --git a/jwt_test.go b/jwt_test.go new file mode 100644 index 0000000..7d4711f --- /dev/null +++ b/jwt_test.go @@ -0,0 +1,113 @@ +// Copyright 2014 The oauth2 Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package oauth2 + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE +DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY +fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK +1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr +k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9 +/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt +3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn +2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3 +nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK +6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf +5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e +DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1 +M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g +z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y +1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK +J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U +f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx +QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA +cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr +Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw +5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg +KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84 +OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd +mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ +5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg== +-----END RSA PRIVATE KEY-----`) + +func TestJWTFetch_JSONResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf, err := NewJWTConfig(&JWTOptions{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + }, ts.URL) + tok, err := conf.FetchToken(nil) + if err != nil { + t.Errorf("Failed retrieving token: %s.", err) + } + if tok.Expired() { + t.Errorf("Token shouldn't be expired.") + } + if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + } + if tok.TokenType != "bearer" { + t.Errorf("Unexpected token type, %#v.", tok.TokenType) + } + scope := tok.Extra("scope") + if scope != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } +} + +func TestJWTFetch_BadResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf, err := NewJWTConfig(&JWTOptions{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + }, ts.URL) + tok, err := conf.FetchToken(nil) + if err != nil { + t.Errorf("Failed retrieving token: %s.", err) + } + if tok.AccessToken != "" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + } + if tok.TokenType != "bearer" { + t.Errorf("Unexpected token type, %#v.", tok.TokenType) + } + scope := tok.Extra("scope") + if scope != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } +} + +func TestJWTFetch_BadResponseType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf, err := NewJWTConfig(&JWTOptions{ + Email: "aaa@xxx.com", + PrivateKey: dummyPrivateKey, + }, ts.URL) + tok, err := conf.FetchToken(nil) + if err != nil { + t.Errorf("Failed retrieving token: %s.", err) + } + if tok.AccessToken != "" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + } +} diff --git a/oauth2.go b/oauth2.go index 7f0571e..ac3072f 100644 --- a/oauth2.go +++ b/oauth2.go @@ -20,15 +20,6 @@ import ( "time" ) -type tokenRespBody struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` // in seconds - Expires int64 `json:"expires"` // in seconds. Facebook returns expires_in as expires. - IdToken string `json:"id_token"` -} - // TokenFetcher refreshes or fetches a new access token from the // provider. It should return an error if it's not capable of // retrieving a token. @@ -212,7 +203,8 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) } - resp := &tokenRespBody{} + token := &Token{} + expires := int(0) content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch content { case "application/x-www-form-urlencoded", "text/plain": @@ -220,44 +212,45 @@ func (c *Config) retrieveToken(v url.Values) (*Token, error) { if err != nil { return nil, err } - resp.AccessToken = vals.Get("access_token") - resp.TokenType = vals.Get("token_type") - resp.RefreshToken = vals.Get("refresh_token") - resp.ExpiresIn, _ = strconv.ParseInt(vals.Get("expires_in"), 10, 64) - resp.Expires, _ = strconv.ParseInt(vals.Get("expires"), 10, 64) - resp.IdToken = vals.Get("id_token") + token.AccessToken = vals.Get("access_token") + token.TokenType = vals.Get("token_type") + token.RefreshToken = vals.Get("refresh_token") + token.raw = vals + e := vals.Get("expires_in") + if e == "" { + // TODO(jbd): Facebook's OAuth2 implementation is broken and + // returns expires_in field in expires. Remove the fallback to expires, + // when Facebook fixes their implementation. + e = vals.Get("expires") + } + expires, _ = strconv.Atoi(e) default: - if err = json.Unmarshal(body, &resp); err != nil { + b := make(map[string]interface{}) + if err = json.Unmarshal(body, &b); err != nil { return nil, err } - } - token := &Token{ - AccessToken: resp.AccessToken, - TokenType: resp.TokenType, - RefreshToken: resp.RefreshToken, + token.AccessToken, _ = b["access_token"].(string) + token.TokenType, _ = b["token_type"].(string) + token.RefreshToken, _ = b["refresh_token"].(string) + token.raw = b + e, ok := b["expires_in"].(int) + if !ok { + // TODO(jbd): Facebook's OAuth2 implementation is broken and + // returns expires_in field in expires. Remove the fallback to expires, + // when Facebook fixes their implementation. + e, _ = b["expires"].(int) + } + expires = e } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. - if resp.RefreshToken == "" { + if token.RefreshToken == "" { token.RefreshToken = v.Get("refresh_token") } - if resp.ExpiresIn == 0 || resp.Expires == 0 { + if expires == 0 { token.Expiry = time.Time{} - } - if resp.ExpiresIn > 0 { - token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) - } - if resp.Expires > 0 { - // TODO(jbd): Facebook's OAuth2 implementation is broken and - // returns expires_in field in expires. Remove the fallback to expires, - // when Facebook fixes their implementation. - token.Expiry = time.Now().Add(time.Duration(resp.Expires) * time.Second) - } - if resp.IdToken != "" { - if token.Extra == nil { - token.Extra = make(map[string]string) - } - token.Extra["id_token"] = resp.IdToken + } else { + token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) } return token, nil } diff --git a/oauth2_test.go b/oauth2_test.go index c231ff9..d633f86 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -84,14 +84,18 @@ func TestExchangeRequest(t *testing.T) { t.Errorf("Token shouldn't be expired.") } if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { - t.Errorf("Wrong access token, %#v.", tok.AccessToken) + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } if tok.TokenType != "bearer" { - t.Errorf("Wrong token type, %#v.", tok.TokenType) + t.Errorf("Unexpected token type, %#v.", tok.TokenType) + } + scope := tok.Extra("scope") + if scope != "user" { + t.Errorf("Unexpected value for scope: %v", scope) } } -func TestExchangeRequest_JsonResponse(t *testing.T) { +func TestExchangeRequest_JSONResponse(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() != "/token" { t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) @@ -124,10 +128,46 @@ func TestExchangeRequest_JsonResponse(t *testing.T) { t.Errorf("Token shouldn't be expired.") } if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { - t.Errorf("Wrong access token, %#v.", tok.AccessToken) + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } if tok.TokenType != "bearer" { - t.Errorf("Wrong token type, %#v.", tok.TokenType) + t.Errorf("Unexpected token type, %#v.", tok.TokenType) + } + scope := tok.Extra("scope") + if scope != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } +} + +func TestExchangeRequest_BadResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf := newTestConf(ts.URL) + tok, err := conf.Exchange("exchange-code") + if err != nil { + t.Errorf("Failed retrieving token: %s.", err) + } + if tok.AccessToken != "" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + } +} + +func TestExchangeRequest_BadResponseType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf := newTestConf(ts.URL) + tok, err := conf.Exchange("exchange-code") + if err != nil { + t.Errorf("Failed retrieving token: %s.", err) + } + if tok.AccessToken != "" { + t.Errorf("Unexpected access token, %#v.", tok.AccessToken) } } diff --git a/transport.go b/transport.go index 706947b..4483c9c 100644 --- a/transport.go +++ b/transport.go @@ -6,6 +6,7 @@ package oauth2 import ( "net/http" + "net/url" "sync" "time" ) @@ -30,14 +31,28 @@ type Token struct { // The remaining lifetime of the access token. Expiry time.Time `json:"expiry,omitempty"` - // Extra optionally contains extra metadata from the server - // when updating a token. The only current key that may be - // populated is "id_token". It may be nil and will be - // initialized as needed. - Extra map[string]string `json:"extra,omitempty"` - // Subject is the user to impersonate. Subject string `json:"subject,omitempty"` + + // raw optionally contains extra metadata from the server + // when updating a token. + raw interface{} +} + +// Extra returns an extra field returned from the server during token retrieval. +// E.g. +// idToken := token.Extra("id_token") +// +func (t *Token) Extra(key string) string { + if vals, ok := t.raw.(url.Values); ok { + return vals.Get(key) + } + if raw, ok := t.raw.(map[string]interface{}); ok { + if val, ok := raw[key].(string); ok { + return val + } + } + return "" } // Expired returns true if there is no access token or the @@ -105,17 +120,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error func (t *Transport) Token() *Token { t.mu.RLock() defer t.mu.RUnlock() - if t.token == nil { - return nil - } - return &Token{ - AccessToken: t.token.AccessToken, - TokenType: t.token.TokenType, - RefreshToken: t.token.RefreshToken, - Expiry: t.token.Expiry, - Extra: t.token.Extra, - Subject: t.token.Subject, - } + return t.token } // SetToken sets a token to the transport in a thread-safe way.