From 4b83411ed2b36bd4e3302e9b1d3c973fb1ba24db Mon Sep 17 00:00:00 2001 From: Niels Widger Date: Sat, 16 Feb 2019 14:30:17 +0000 Subject: [PATCH] jwt: add Config.Audience field Add an Audience field to jwt.Config which, if set, is used instead of TokenURL as the 'aud' claim in the generated JWT. This allows the jwt package to work with authorization servers that require the 'aud' claim and token endpoint URL to be different values. Fixes #369. Change-Id: I883aabece7f9b16ec726d5bfa98c1ec91876b651 GitHub-Last-Rev: fd73e4d50cfe0450fd59ffc6d4c5db7a3f660b60 GitHub-Pull-Request: golang/oauth2#370 Reviewed-on: https://go-review.googlesource.com/c/162937 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- jwt/jwt.go | 8 ++++++ jwt/jwt_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/jwt/jwt.go b/jwt/jwt.go index 0783a94..99f3e0a 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -61,6 +61,11 @@ type Config struct { // Expires optionally specifies how long the token is valid for. Expires time.Duration + + // Audience optionally specifies the intended audience of the + // request. If empty, the value of TokenURL is used as the + // intended audience. + Audience string } // TokenSource returns a JWT TokenSource using the configuration @@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) { if t := js.conf.Expires; t > 0 { claimSet.Exp = time.Now().Add(t).Unix() } + if aud := js.conf.Audience; aud != "" { + claimSet.Aud = aud + } h := *defaultHeader h.KeyID = js.conf.PrivateKeyID payload, err := jws.Encode(&h, claimSet, pk) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 1fbb9aa..9dfa3b3 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) { } } +func TestJWTFetch_AssertionPayload(t *testing.T) { + var assertion string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertion = r.Form.Get("assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + for _, conf := range []*Config{ + { + Email: "aaa1@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + }, + { + Email: "aaa2@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + Audience: "https://example.com", + }, + } { + t.Run(conf.Email, func(t *testing.T) { + _, err := conf.TokenSource(context.Background()).Token() + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotjson, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("invalid token payload; err = %v", err) + } + + claimSet := jws.ClaimSet{} + if err := json.Unmarshal(gotjson, &claimSet); err != nil { + t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err) + } + + if got, want := claimSet.Iss, conf.Email; got != want { + t.Errorf("payload email = %q; want %q", got, want) + } + if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want { + t.Errorf("payload scope = %q; want %q", got, want) + } + aud := conf.TokenURL + if conf.Audience != "" { + aud = conf.Audience + } + if got, want := claimSet.Aud, aud; got != want { + t.Errorf("payload audience = %q; want %q", got, want) + } + if got, want := claimSet.Sub, conf.Subject; got != want { + t.Errorf("payload subject = %q; want %q", got, want) + } + if got, want := claimSet.Prn, conf.Subject; got != want { + t.Errorf("payload prn = %q; want %q", got, want) + } + }) + } +} + func TestTokenRetrieveError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-type", "application/json")