token: extra numeric values + test TokenType case

+ Added tests for TokenType by checking case.
+ Added numeric conversion for float and integer like
  values from token.Extra.

Change-Id: I0909a4458ed58e33428afbf40478a668d150dda7
Reviewed-on: https://go-review.googlesource.com/15156
Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Emmanuel Odeke 2015-09-29 18:24:17 -06:00 committed by Andrew Gerrand
parent 3cab960fb9
commit 2fbf3d7329
3 changed files with 91 additions and 5 deletions

View File

@ -11,6 +11,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"reflect" "reflect"
"strconv" "strconv"
"testing" "testing"
@ -170,6 +171,54 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
if scope != "user" { if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope) t.Errorf("Unexpected value for scope: %v", scope)
} }
expiresIn := tok.Extra("expires_in")
if expiresIn != float64(86400) {
t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn)
}
}
func TestExtraValueRetrieval(t *testing.T) {
values := url.Values{}
kvmap := map[string]string{
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
"untrimmed": " untrimmed ",
}
for key, value := range kvmap {
values.Set(key, value)
}
tok := Token{
raw: values,
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected scope %v wanted \"user\"", scope)
}
serverTime := tok.Extra("server_time")
if serverTime != 1443571905.5606415 {
t.Errorf("Unexpected non-float64 value for server_time: %v", serverTime)
}
refererIp := tok.Extra("referer_ip")
if refererIp != "10.0.0.1" {
t.Errorf("Unexpected non-string value for referer_ip: %v", refererIp)
}
expires_in := tok.Extra("expires_in")
if expires_in != 86400.92 {
t.Errorf("Unexpected value for expires_in, wanted 86400 got %v", expires_in)
}
requestId := tok.Extra("request_id")
if requestId != int64(86400) {
t.Errorf("Unexpected non-int64 value for request_id: %v", requestId)
}
untrimmed := tok.Extra("untrimmed")
if untrimmed != " untrimmed " {
t.Errorf("Unexpected value for untrimmed, got %q expected \" untrimmed \"", untrimmed)
}
} }
const day = 24 * time.Hour const day = 24 * time.Hour

View File

@ -7,6 +7,7 @@ package oauth2
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -92,16 +93,30 @@ func (t *Token) WithExtra(extra interface{}) *Token {
// Extra fields are key-value pairs returned by the server as a // Extra fields are key-value pairs returned by the server as a
// part of the token retrieval response. // part of the token retrieval response.
func (t *Token) Extra(key string) interface{} { func (t *Token) Extra(key string) interface{} {
if vals, ok := t.raw.(url.Values); ok {
// TODO(jbd): Cast numeric values to int64 or float64.
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok { if raw, ok := t.raw.(map[string]interface{}); ok {
return raw[key] return raw[key]
} }
vals, ok := t.raw.(url.Values)
if !ok {
return nil return nil
} }
v := vals.Get(key)
switch s := strings.TrimSpace(v); strings.Count(s, ".") {
case 0: // Contains no "."; try to parse as int
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i
}
case 1: // Contains a single "."; try to parse as float
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f
}
}
return v
}
// expired reports whether the token is expired. // expired reports whether the token is expired.
// t must be non-nil. // t must be non-nil.
func (t *Token) expired() bool { func (t *Token) expired() bool {

View File

@ -41,6 +41,7 @@ func TestTokenExpiry(t *testing.T) {
}{ }{
{name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false}, {name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false},
{name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true}, {name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true},
{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true},
} }
for _, tc := range cases { for _, tc := range cases {
if got, want := tc.tok.expired(), tc.want; got != want { if got, want := tc.tok.expired(), tc.want; got != want {
@ -48,3 +49,24 @@ func TestTokenExpiry(t *testing.T) {
} }
} }
} }
func TestTokenTypeMethod(t *testing.T) {
cases := []struct {
name string
tok *Token
want string
}{
{name: "bearer-mixed_case", tok: &Token{TokenType: "beAREr"}, want: "Bearer"},
{name: "default-bearer", tok: &Token{}, want: "Bearer"},
{name: "basic", tok: &Token{TokenType: "basic"}, want: "Basic"},
{name: "basic-capitalized", tok: &Token{TokenType: "Basic"}, want: "Basic"},
{name: "mac", tok: &Token{TokenType: "mac"}, want: "MAC"},
{name: "mac-caps", tok: &Token{TokenType: "MAC"}, want: "MAC"},
{name: "mac-mixed_case", tok: &Token{TokenType: "mAc"}, want: "MAC"},
}
for _, tc := range cases {
if got, want := tc.tok.Type(), tc.want; got != want {
t.Errorf("TokenType(%q) = %v; want %v", tc.name, got, want)
}
}
}