forked from Mirrors/oauth2
jwt: add Config.Audience field
Add an Audience field to jwt.Config which, if set, is used instead of TokenURL as the 'aud' claim in the generated JWT. This allows the jwt package to work with authorization servers that require the 'aud' claim and token endpoint URL to be different values.
Fixes #369.
Change-Id: I883aabece7f9b16ec726d5bfa98c1ec91876b651
GitHub-Last-Rev: fd73e4d50c
GitHub-Pull-Request: golang/oauth2#370
Reviewed-on: https://go-review.googlesource.com/c/162937
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
3e8b2be136
commit
4b83411ed2
|
@ -61,6 +61,11 @@ type Config struct {
|
|||
|
||||
// Expires optionally specifies how long the token is valid for.
|
||||
Expires time.Duration
|
||||
|
||||
// Audience optionally specifies the intended audience of the
|
||||
// request. If empty, the value of TokenURL is used as the
|
||||
// intended audience.
|
||||
Audience string
|
||||
}
|
||||
|
||||
// TokenSource returns a JWT TokenSource using the configuration
|
||||
|
@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||
if t := js.conf.Expires; t > 0 {
|
||||
claimSet.Exp = time.Now().Add(t).Unix()
|
||||
}
|
||||
if aud := js.conf.Audience; aud != "" {
|
||||
claimSet.Aud = aud
|
||||
}
|
||||
h := *defaultHeader
|
||||
h.KeyID = js.conf.PrivateKeyID
|
||||
payload, err := jws.Encode(&h, claimSet, pk)
|
||||
|
|
|
@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestJWTFetch_AssertionPayload(t *testing.T) {
|
||||
var assertion string
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
assertion = r.Form.Get("assertion")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
|
||||
"scope": "user",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600
|
||||
}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
for _, conf := range []*Config{
|
||||
{
|
||||
Email: "aaa1@xxx.com",
|
||||
PrivateKey: dummyPrivateKey,
|
||||
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
TokenURL: ts.URL,
|
||||
},
|
||||
{
|
||||
Email: "aaa2@xxx.com",
|
||||
PrivateKey: dummyPrivateKey,
|
||||
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||
TokenURL: ts.URL,
|
||||
Audience: "https://example.com",
|
||||
},
|
||||
} {
|
||||
t.Run(conf.Email, func(t *testing.T) {
|
||||
_, err := conf.TokenSource(context.Background()).Token()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch token: %v", err)
|
||||
}
|
||||
|
||||
parts := strings.Split(assertion, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("assertion = %q; want 3 parts", assertion)
|
||||
}
|
||||
gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("invalid token payload; err = %v", err)
|
||||
}
|
||||
|
||||
claimSet := jws.ClaimSet{}
|
||||
if err := json.Unmarshal(gotjson, &claimSet); err != nil {
|
||||
t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
|
||||
}
|
||||
|
||||
if got, want := claimSet.Iss, conf.Email; got != want {
|
||||
t.Errorf("payload email = %q; want %q", got, want)
|
||||
}
|
||||
if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want {
|
||||
t.Errorf("payload scope = %q; want %q", got, want)
|
||||
}
|
||||
aud := conf.TokenURL
|
||||
if conf.Audience != "" {
|
||||
aud = conf.Audience
|
||||
}
|
||||
if got, want := claimSet.Aud, aud; got != want {
|
||||
t.Errorf("payload audience = %q; want %q", got, want)
|
||||
}
|
||||
if got, want := claimSet.Sub, conf.Subject; got != want {
|
||||
t.Errorf("payload subject = %q; want %q", got, want)
|
||||
}
|
||||
if got, want := claimSet.Prn, conf.Subject; got != want {
|
||||
t.Errorf("payload prn = %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRetrieveError(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
|
|
Loading…
Reference in New Issue