forked from Mirrors/oauth2
jwt: add Config.Audience field
Add an Audience field to jwt.Config which, if set, is used instead of TokenURL as the 'aud' claim in the generated JWT. This allows the jwt package to work with authorization servers that require the 'aud' claim and token endpoint URL to be different values.
This commit is contained in:
parent
3e8b2be136
commit
fd73e4d50c
|
@ -61,6 +61,11 @@ type Config struct {
|
||||||
|
|
||||||
// Expires optionally specifies how long the token is valid for.
|
// Expires optionally specifies how long the token is valid for.
|
||||||
Expires time.Duration
|
Expires time.Duration
|
||||||
|
|
||||||
|
// Audience optionally specifies the intended audience of the
|
||||||
|
// request. If empty, the value of TokenURL is used as the
|
||||||
|
// intended audience.
|
||||||
|
Audience string
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenSource returns a JWT TokenSource using the configuration
|
// TokenSource returns a JWT TokenSource using the configuration
|
||||||
|
@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
||||||
if t := js.conf.Expires; t > 0 {
|
if t := js.conf.Expires; t > 0 {
|
||||||
claimSet.Exp = time.Now().Add(t).Unix()
|
claimSet.Exp = time.Now().Add(t).Unix()
|
||||||
}
|
}
|
||||||
|
if aud := js.conf.Audience; aud != "" {
|
||||||
|
claimSet.Aud = aud
|
||||||
|
}
|
||||||
h := *defaultHeader
|
h := *defaultHeader
|
||||||
h.KeyID = js.conf.PrivateKeyID
|
h.KeyID = js.conf.PrivateKeyID
|
||||||
payload, err := jws.Encode(&h, claimSet, pk)
|
payload, err := jws.Encode(&h, claimSet, pk)
|
||||||
|
|
|
@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJWTFetch_AssertionPayload(t *testing.T) {
|
||||||
|
var assertion string
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseForm()
|
||||||
|
assertion = r.Form.Get("assertion")
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{
|
||||||
|
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
|
||||||
|
"scope": "user",
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": 3600
|
||||||
|
}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
for _, conf := range []*Config{
|
||||||
|
{
|
||||||
|
Email: "aaa1@xxx.com",
|
||||||
|
PrivateKey: dummyPrivateKey,
|
||||||
|
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||||
|
TokenURL: ts.URL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Email: "aaa2@xxx.com",
|
||||||
|
PrivateKey: dummyPrivateKey,
|
||||||
|
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||||
|
TokenURL: ts.URL,
|
||||||
|
Audience: "https://example.com",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(conf.Email, func(t *testing.T) {
|
||||||
|
_, err := conf.TokenSource(context.Background()).Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to fetch token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(assertion, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("assertion = %q; want 3 parts", assertion)
|
||||||
|
}
|
||||||
|
gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("invalid token payload; err = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claimSet := jws.ClaimSet{}
|
||||||
|
if err := json.Unmarshal(gotjson, &claimSet); err != nil {
|
||||||
|
t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := claimSet.Iss, conf.Email; got != want {
|
||||||
|
t.Errorf("payload email = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want {
|
||||||
|
t.Errorf("payload scope = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
aud := conf.TokenURL
|
||||||
|
if conf.Audience != "" {
|
||||||
|
aud = conf.Audience
|
||||||
|
}
|
||||||
|
if got, want := claimSet.Aud, aud; got != want {
|
||||||
|
t.Errorf("payload audience = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := claimSet.Sub, conf.Subject; got != want {
|
||||||
|
t.Errorf("payload subject = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := claimSet.Prn, conf.Subject; got != want {
|
||||||
|
t.Errorf("payload prn = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTokenRetrieveError(t *testing.T) {
|
func TestTokenRetrieveError(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-type", "application/json")
|
w.Header().Set("Content-type", "application/json")
|
||||||
|
|
Loading…
Reference in New Issue