diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 9dfa3b3..8bfa908 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -221,6 +222,16 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { TokenURL: ts.URL, Audience: "https://example.com", }, + { + Email: "aaa2@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + PrivateClaims: map[string]interface{}{ + "private0": "claim0", + "private1": "claim1", + }, + }, } { t.Run(conf.Email, func(t *testing.T) { _, err := conf.TokenSource(context.Background()).Token() @@ -261,6 +272,18 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { if got, want := claimSet.Prn, conf.Subject; got != want { t.Errorf("payload prn = %q; want %q", got, want) } + if len(conf.PrivateClaims) > 0 { + var got interface{} + if err := json.Unmarshal(gotjson, &got); err != nil { + t.Errorf("failed to parse payload; err = %q", err) + } + m := got.(map[string]interface{}) + for v, k := range conf.PrivateClaims { + if !reflect.DeepEqual(m[v], k) { + t.Errorf("payload private claims key = %q: got %q; want %q", v, m[v], k) + } + } + } }) } }