clientcredentials: allow override of grant_type

This commit is contained in:
Tom Payne 2019-01-18 13:23:58 +01:00
parent 5dab4167f3
commit 0e6f85e31e
No known key found for this signature in database
GPG Key ID: 6EBF9E59539462BB
2 changed files with 40 additions and 1 deletions

View File

@ -90,7 +90,9 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v.Set("scope", strings.Join(c.conf.Scopes, " "))
}
for k, p := range c.conf.EndpointParams {
if _, ok := v[k]; ok {
// Allow grant_type to be overridden to allow interoperability with
// non-compliant implementations.
if _, ok := v[k]; ok && k != "grant_type" {
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k)
}
v[k] = p

View File

@ -31,6 +31,43 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
return t.rt(req)
}
func TestTokenSourceGrantTypeOverride(t *testing.T) {
wantGrantType := "password"
var gotGrantType string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
}
if err := r.Body.Close(); err != nil {
t.Errorf("r.Body.Close() == %v, want <nil>", err)
}
values, err := url.ParseQuery(string(body))
if err != nil {
t.Errorf("url.ParseQuery(%q) == %v, %v, want _, <nil>", body, values, err)
}
gotGrantType = values.Get("grant_type")
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer"))
}))
config := &Config{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
Scopes: []string{"scope"},
TokenURL: ts.URL + "/token",
EndpointParams: url.Values{
"grant_type": {wantGrantType},
},
}
token, err := config.TokenSource(context.Background()).Token()
if err != nil {
t.Errorf("config.TokenSource(_).Token() == %v, %v, want !<nil>, <nil>", token, err)
}
if gotGrantType != wantGrantType {
t.Errorf("grant_type == %q, want %q", gotGrantType, wantGrantType)
}
}
func TestTokenRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {