diff --git a/jwt/jwt.go b/jwt/jwt.go index 0783a94..99f3e0a 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -61,6 +61,11 @@ type Config struct { // Expires optionally specifies how long the token is valid for. Expires time.Duration + + // Audience optionally specifies the intended audience of the + // request. If empty, the value of TokenURL is used as the + // intended audience. + Audience string } // TokenSource returns a JWT TokenSource using the configuration @@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) { if t := js.conf.Expires; t > 0 { claimSet.Exp = time.Now().Add(t).Unix() } + if aud := js.conf.Audience; aud != "" { + claimSet.Aud = aud + } h := *defaultHeader h.KeyID = js.conf.PrivateKeyID payload, err := jws.Encode(&h, claimSet, pk) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 1fbb9aa..9dfa3b3 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) { } } +func TestJWTFetch_AssertionPayload(t *testing.T) { + var assertion string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertion = r.Form.Get("assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + for _, conf := range []*Config{ + { + Email: "aaa1@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + }, + { + Email: "aaa2@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + Audience: "https://example.com", + }, + } { + t.Run(conf.Email, func(t *testing.T) { + _, err := conf.TokenSource(context.Background()).Token() + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotjson, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("invalid token payload; err = %v", err) + } + + claimSet := jws.ClaimSet{} + if err := json.Unmarshal(gotjson, &claimSet); err != nil { + t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err) + } + + if got, want := claimSet.Iss, conf.Email; got != want { + t.Errorf("payload email = %q; want %q", got, want) + } + if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want { + t.Errorf("payload scope = %q; want %q", got, want) + } + aud := conf.TokenURL + if conf.Audience != "" { + aud = conf.Audience + } + if got, want := claimSet.Aud, aud; got != want { + t.Errorf("payload audience = %q; want %q", got, want) + } + if got, want := claimSet.Sub, conf.Subject; got != want { + t.Errorf("payload subject = %q; want %q", got, want) + } + if got, want := claimSet.Prn, conf.Subject; got != want { + t.Errorf("payload prn = %q; want %q", got, want) + } + }) + } +} + func TestTokenRetrieveError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-type", "application/json")