oauth2: Add support for custom params in Exchange

Allows implementation of PKCE https://www.oauth.com/oauth2-servers/pkce/
for secure code exchange.

Fixes golang/oauth2#286

Signed-off-by: Guillaume J. Charmes <gcharmes@magicleap.com>
This commit is contained in:
Guillaume J. Charmes 2018-05-04 11:32:07 -04:00
parent 2f32c3ac0f
commit 31c5ccbed3
No known key found for this signature in database
GPG Key ID: AA04E63AF5CD9AAA
2 changed files with 55 additions and 1 deletions

View File

@ -123,6 +123,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// //
// Opts may include AccessTypeOnline or AccessTypeOffline, as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce. // as ApprovalForce.
// It can also be used to pass the PKCE challange.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL) buf.WriteString(c.Endpoint.AuthURL)
@ -185,7 +187,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state").
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { //
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{ v := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"code": {code}, "code": {code},
@ -193,6 +198,9 @@ func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
if c.RedirectURL != "" { if c.RedirectURL != "" {
v.Set("redirect_uri", c.RedirectURL) v.Set("redirect_uri", c.RedirectURL)
} }
for _, opt := range opts {
opt.setValue(v)
}
return retrieveToken(ctx, c, v) return retrieveToken(ctx, c, v)
} }

View File

@ -135,6 +135,52 @@ func TestExchangeRequest(t *testing.T) {
} }
} }
func TestExchangeRequest_CustomParam(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
}
headerAuth := r.Header.Get("Authorization")
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
}
headerContentType := r.Header.Get("Content-Type")
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
t.Errorf("Unexpected exchange payload, %v is found.", string(body))
}
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
}))
defer ts.Close()
conf := newConf(ts.URL)
param := SetAuthURLParam("foo", "bar")
tok, err := conf.Exchange(context.Background(), "exchange-code", param)
if err != nil {
t.Error(err)
}
if !tok.Valid() {
t.Fatalf("Token invalid. Got: %#v", tok)
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
}
}
func TestExchangeRequest_JSONResponse(t *testing.T) { func TestExchangeRequest_JSONResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" { if r.URL.String() != "/token" {