golang.org/x/oauth2/jwt: Set kid to KeyID of private key

Set the KeyID hint in the token header. This allows remote servers to
identify the key used to sign the message.

Fixes #18307

Change-Id: Ib95398079833aad6b390650b465d7b09b5f53fda
Reviewed-on: https://go-review.googlesource.com/34320
Reviewed-by: Jaana Burcu Dogan <jbd@google.com>
This commit is contained in:
Tristan Colgate 2016-12-14 09:25:55 +00:00 committed by Jaana Burcu Dogan
parent 96382aa079
commit 314dd2c0bf
2 changed files with 60 additions and 1 deletions

View File

@ -105,7 +105,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
if t := js.conf.Expires; t > 0 { if t := js.conf.Expires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix() claimSet.Exp = time.Now().Add(t).Unix()
} }
payload, err := jws.Encode(defaultHeader, claimSet, pk) h := *defaultHeader
h.KeyID = js.conf.PrivateKeyID
payload, err := jws.Encode(&h, claimSet, pk)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,9 +6,14 @@ package jwt
import ( import (
"context" "context"
"encoding/base64"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"golang.org/x/oauth2/jws"
) )
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
@ -131,3 +136,55 @@ func TestJWTFetch_BadResponseType(t *testing.T) {
} }
} }
} }
func TestJWTFetch_Assertion(t *testing.T) {
var assertion string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
assertion = r.Form.Get("assertion")
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"scope": "user",
"token_type": "bearer",
"expires_in": 3600
}`))
}))
defer ts.Close()
conf := &Config{
Email: "aaa@xxx.com",
PrivateKey: dummyPrivateKey,
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
TokenURL: ts.URL,
}
_, err := conf.TokenSource(context.Background()).Token()
if err != nil {
t.Fatalf("Failed to fetch token: %v", err)
}
parts := strings.Split(assertion, ".")
if len(parts) != 3 {
t.Fatalf("assertion = %q; want 3 parts", assertion)
}
gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
t.Fatalf("invalid token header; err = %v", err)
}
got := jws.Header{}
if err := json.Unmarshal(gotjson, &got); err != nil {
t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
}
want := jws.Header{
Algorithm: "RS256",
Typ: "JWT",
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
}
if got != want {
t.Errorf("access token header = %q; want %q", got, want)
}
}