diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 3c816bb..0812964 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -90,7 +90,9 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { v.Set("scope", strings.Join(c.conf.Scopes, " ")) } for k, p := range c.conf.EndpointParams { - if _, ok := v[k]; ok { + // Allow grant_type to be overridden to allow interoperability with + // non-compliant implementations. + if _, ok := v[k]; ok && k != "grant_type" { return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k) } v[k] = p diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 108520c..35cbd54 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -31,6 +31,43 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e return t.rt(req) } +func TestTokenSourceGrantTypeOverride(t *testing.T) { + wantGrantType := "password" + var gotGrantType string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, ", body, err) + } + if err := r.Body.Close(); err != nil { + t.Errorf("r.Body.Close() == %v, want ", err) + } + values, err := url.ParseQuery(string(body)) + if err != nil { + t.Errorf("url.ParseQuery(%q) == %v, %v, want _, ", body, values, err) + } + gotGrantType = values.Get("grant_type") + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer")) + })) + config := &Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + Scopes: []string{"scope"}, + TokenURL: ts.URL + "/token", + EndpointParams: url.Values{ + "grant_type": {wantGrantType}, + }, + } + token, err := config.TokenSource(context.Background()).Token() + if err != nil { + t.Errorf("config.TokenSource(_).Token() == %v, %v, want !, ", token, err) + } + if gotGrantType != wantGrantType { + t.Errorf("grant_type == %q, want %q", gotGrantType, wantGrantType) + } +} + func TestTokenRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() != "/token" {