From fd73e4d50cfe0450fd59ffc6d4c5db7a3f660b60 Mon Sep 17 00:00:00 2001 From: Niels Widger Date: Thu, 14 Feb 2019 15:44:02 -0500 Subject: [PATCH] jwt: add Config.Audience field Add an Audience field to jwt.Config which, if set, is used instead of TokenURL as the 'aud' claim in the generated JWT. This allows the jwt package to work with authorization servers that require the 'aud' claim and token endpoint URL to be different values. --- jwt/jwt.go | 8 ++++++ jwt/jwt_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/jwt/jwt.go b/jwt/jwt.go index 0783a94..99f3e0a 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -61,6 +61,11 @@ type Config struct { // Expires optionally specifies how long the token is valid for. Expires time.Duration + + // Audience optionally specifies the intended audience of the + // request. If empty, the value of TokenURL is used as the + // intended audience. + Audience string } // TokenSource returns a JWT TokenSource using the configuration @@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) { if t := js.conf.Expires; t > 0 { claimSet.Exp = time.Now().Add(t).Unix() } + if aud := js.conf.Audience; aud != "" { + claimSet.Aud = aud + } h := *defaultHeader h.KeyID = js.conf.PrivateKeyID payload, err := jws.Encode(&h, claimSet, pk) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 1fbb9aa..9dfa3b3 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) { } } +func TestJWTFetch_AssertionPayload(t *testing.T) { + var assertion string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertion = r.Form.Get("assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + for _, conf := range []*Config{ + { + Email: "aaa1@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + }, + { + Email: "aaa2@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + Audience: "https://example.com", + }, + } { + t.Run(conf.Email, func(t *testing.T) { + _, err := conf.TokenSource(context.Background()).Token() + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotjson, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("invalid token payload; err = %v", err) + } + + claimSet := jws.ClaimSet{} + if err := json.Unmarshal(gotjson, &claimSet); err != nil { + t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err) + } + + if got, want := claimSet.Iss, conf.Email; got != want { + t.Errorf("payload email = %q; want %q", got, want) + } + if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want { + t.Errorf("payload scope = %q; want %q", got, want) + } + aud := conf.TokenURL + if conf.Audience != "" { + aud = conf.Audience + } + if got, want := claimSet.Aud, aud; got != want { + t.Errorf("payload audience = %q; want %q", got, want) + } + if got, want := claimSet.Sub, conf.Subject; got != want { + t.Errorf("payload subject = %q; want %q", got, want) + } + if got, want := claimSet.Prn, conf.Subject; got != want { + t.Errorf("payload prn = %q; want %q", got, want) + } + }) + } +} + func TestTokenRetrieveError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-type", "application/json")