forked from Mirrors/oauth2
token: extra numeric values + test TokenType case
+ Added tests for TokenType by checking case. + Added numeric conversion for float and integer like values from token.Extra. Change-Id: I0909a4458ed58e33428afbf40478a668d150dda7 Reviewed-on: https://go-review.googlesource.com/15156 Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
parent
3cab960fb9
commit
2fbf3d7329
|
@ -11,6 +11,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
@ -170,6 +171,54 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
|
|||
if scope != "user" {
|
||||
t.Errorf("Unexpected value for scope: %v", scope)
|
||||
}
|
||||
expiresIn := tok.Extra("expires_in")
|
||||
if expiresIn != float64(86400) {
|
||||
t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraValueRetrieval(t *testing.T) {
|
||||
values := url.Values{}
|
||||
|
||||
kvmap := map[string]string{
|
||||
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
|
||||
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
|
||||
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
|
||||
"untrimmed": " untrimmed ",
|
||||
}
|
||||
|
||||
for key, value := range kvmap {
|
||||
values.Set(key, value)
|
||||
}
|
||||
|
||||
tok := Token{
|
||||
raw: values,
|
||||
}
|
||||
|
||||
scope := tok.Extra("scope")
|
||||
if scope != "user" {
|
||||
t.Errorf("Unexpected scope %v wanted \"user\"", scope)
|
||||
}
|
||||
serverTime := tok.Extra("server_time")
|
||||
if serverTime != 1443571905.5606415 {
|
||||
t.Errorf("Unexpected non-float64 value for server_time: %v", serverTime)
|
||||
}
|
||||
refererIp := tok.Extra("referer_ip")
|
||||
if refererIp != "10.0.0.1" {
|
||||
t.Errorf("Unexpected non-string value for referer_ip: %v", refererIp)
|
||||
}
|
||||
expires_in := tok.Extra("expires_in")
|
||||
if expires_in != 86400.92 {
|
||||
t.Errorf("Unexpected value for expires_in, wanted 86400 got %v", expires_in)
|
||||
}
|
||||
requestId := tok.Extra("request_id")
|
||||
if requestId != int64(86400) {
|
||||
t.Errorf("Unexpected non-int64 value for request_id: %v", requestId)
|
||||
}
|
||||
untrimmed := tok.Extra("untrimmed")
|
||||
if untrimmed != " untrimmed " {
|
||||
t.Errorf("Unexpected value for untrimmed, got %q expected \" untrimmed \"", untrimmed)
|
||||
}
|
||||
}
|
||||
|
||||
const day = 24 * time.Hour
|
||||
|
|
25
token.go
25
token.go
|
@ -7,6 +7,7 @@ package oauth2
|
|||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -92,14 +93,28 @@ func (t *Token) WithExtra(extra interface{}) *Token {
|
|||
// Extra fields are key-value pairs returned by the server as a
|
||||
// part of the token retrieval response.
|
||||
func (t *Token) Extra(key string) interface{} {
|
||||
if vals, ok := t.raw.(url.Values); ok {
|
||||
// TODO(jbd): Cast numeric values to int64 or float64.
|
||||
return vals.Get(key)
|
||||
}
|
||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||
return raw[key]
|
||||
}
|
||||
return nil
|
||||
|
||||
vals, ok := t.raw.(url.Values)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := vals.Get(key)
|
||||
switch s := strings.TrimSpace(v); strings.Count(s, ".") {
|
||||
case 0: // Contains no "."; try to parse as int
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return i
|
||||
}
|
||||
case 1: // Contains a single "."; try to parse as float
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// expired reports whether the token is expired.
|
||||
|
|
|
@ -41,6 +41,7 @@ func TestTokenExpiry(t *testing.T) {
|
|||
}{
|
||||
{name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, want: false},
|
||||
{name: "10 seconds", tok: &Token{Expiry: now.Add(expiryDelta)}, want: true},
|
||||
{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, want: true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got, want := tc.tok.expired(), tc.want; got != want {
|
||||
|
@ -48,3 +49,24 @@ func TestTokenExpiry(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenTypeMethod(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
tok *Token
|
||||
want string
|
||||
}{
|
||||
{name: "bearer-mixed_case", tok: &Token{TokenType: "beAREr"}, want: "Bearer"},
|
||||
{name: "default-bearer", tok: &Token{}, want: "Bearer"},
|
||||
{name: "basic", tok: &Token{TokenType: "basic"}, want: "Basic"},
|
||||
{name: "basic-capitalized", tok: &Token{TokenType: "Basic"}, want: "Basic"},
|
||||
{name: "mac", tok: &Token{TokenType: "mac"}, want: "MAC"},
|
||||
{name: "mac-caps", tok: &Token{TokenType: "MAC"}, want: "MAC"},
|
||||
{name: "mac-mixed_case", tok: &Token{TokenType: "mAc"}, want: "MAC"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got, want := tc.tok.Type(), tc.want; got != want {
|
||||
t.Errorf("TokenType(%q) = %v; want %v", tc.name, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue