diff --git a/oauth2_test.go b/oauth2_test.go index 2f7d731..a99d7a3 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "reflect" "strconv" "testing" @@ -170,6 +171,54 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { if scope != "user" { t.Errorf("Unexpected value for scope: %v", scope) } + expiresIn := tok.Extra("expires_in") + if expiresIn != float64(86400) { + t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) + } +} + +func TestExtraValueRetrieval(t *testing.T) { + values := url.Values{} + + kvmap := map[string]string{ + "scope": "user", "token_type": "bearer", "expires_in": "86400.92", + "server_time": "1443571905.5606415", "referer_ip": "10.0.0.1", + "etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400", + "untrimmed": " untrimmed ", + } + + for key, value := range kvmap { + values.Set(key, value) + } + + tok := Token{ + raw: values, + } + + scope := tok.Extra("scope") + if scope != "user" { + t.Errorf("Unexpected scope %v wanted \"user\"", scope) + } + serverTime := tok.Extra("server_time") + if serverTime != 1443571905.5606415 { + t.Errorf("Unexpected non-float64 value for server_time: %v", serverTime) + } + refererIp := tok.Extra("referer_ip") + if refererIp != "10.0.0.1" { + t.Errorf("Unexpected non-string value for referer_ip: %v", refererIp) + } + expires_in := tok.Extra("expires_in") + if expires_in != 86400.92 { + t.Errorf("Unexpected value for expires_in, wanted 86400 got %v", expires_in) + } + requestId := tok.Extra("request_id") + if requestId != int64(86400) { + t.Errorf("Unexpected non-int64 value for request_id: %v", requestId) + } + untrimmed := tok.Extra("untrimmed") + if untrimmed != " untrimmed " { + t.Errorf("Unexpected value for untrimmed, got %q expected \" untrimmed \"", untrimmed) + } } const day = 24 * time.Hour diff --git a/token.go b/token.go index ebbdddb..4e596f0 100644 --- a/token.go +++ b/token.go @@ -7,6 +7,7 @@ package oauth2 import ( "net/http" "net/url" + "strconv" "strings" "time" @@ -92,14 +93,28 @@ func (t *Token) WithExtra(extra interface{}) *Token { // Extra fields are key-value pairs returned by the server as a // part of the token retrieval response. func (t *Token) Extra(key string) interface{} { - if vals, ok := t.raw.(url.Values); ok { - // TODO(jbd): Cast numeric values to int64 or float64. - return vals.Get(key) - } if raw, ok := t.raw.(map[string]interface{}); ok { return raw[key] } - return nil + + vals, ok := t.raw.(url.Values) + if !ok { + return nil + } + + v := vals.Get(key) + switch s := strings.TrimSpace(v); strings.Count(s, ".") { + case 0: // Contains no "."; try to parse as int + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + return i + } + case 1: // Contains a single "."; try to parse as float + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + + return v } // expired reports whether the token is expired. diff --git a/token_test.go b/token_test.go index 739eeb2..8344329 100644 --- a/token_test.go +++ b/token_test.go @@ -41,6 +41,7 @@ func TestTokenExpiry(t *testing.T) { }{ {name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false}, {name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true}, + {name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true}, } for _, tc := range cases { if got, want := tc.tok.expired(), tc.want; got != want { @@ -48,3 +49,24 @@ func TestTokenExpiry(t *testing.T) { } } } + +func TestTokenTypeMethod(t *testing.T) { + cases := []struct { + name string + tok *Token + want string + }{ + {name: "bearer-mixed_case", tok: &Token{TokenType: "beAREr"}, want: "Bearer"}, + {name: "default-bearer", tok: &Token{}, want: "Bearer"}, + {name: "basic", tok: &Token{TokenType: "basic"}, want: "Basic"}, + {name: "basic-capitalized", tok: &Token{TokenType: "Basic"}, want: "Basic"}, + {name: "mac", tok: &Token{TokenType: "mac"}, want: "MAC"}, + {name: "mac-caps", tok: &Token{TokenType: "MAC"}, want: "MAC"}, + {name: "mac-mixed_case", tok: &Token{TokenType: "mAc"}, want: "MAC"}, + } + for _, tc := range cases { + if got, want := tc.tok.Type(), tc.want; got != want { + t.Errorf("TokenType(%q) = %v; want %v", tc.name, got, want) + } + } +}