forked from Mirrors/oauth2
clientcredentials: allow override of grant_type
This commit is contained in:
parent
5dab4167f3
commit
0e6f85e31e
|
@ -90,7 +90,9 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
|
|||
v.Set("scope", strings.Join(c.conf.Scopes, " "))
|
||||
}
|
||||
for k, p := range c.conf.EndpointParams {
|
||||
if _, ok := v[k]; ok {
|
||||
// Allow grant_type to be overridden to allow interoperability with
|
||||
// non-compliant implementations.
|
||||
if _, ok := v[k]; ok && k != "grant_type" {
|
||||
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k)
|
||||
}
|
||||
v[k] = p
|
||||
|
|
|
@ -31,6 +31,43 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
|
|||
return t.rt(req)
|
||||
}
|
||||
|
||||
func TestTokenSourceGrantTypeOverride(t *testing.T) {
|
||||
wantGrantType := "password"
|
||||
var gotGrantType string
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
|
||||
}
|
||||
if err := r.Body.Close(); err != nil {
|
||||
t.Errorf("r.Body.Close() == %v, want <nil>", err)
|
||||
}
|
||||
values, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
t.Errorf("url.ParseQuery(%q) == %v, %v, want _, <nil>", body, values, err)
|
||||
}
|
||||
gotGrantType = values.Get("grant_type")
|
||||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer"))
|
||||
}))
|
||||
config := &Config{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
Scopes: []string{"scope"},
|
||||
TokenURL: ts.URL + "/token",
|
||||
EndpointParams: url.Values{
|
||||
"grant_type": {wantGrantType},
|
||||
},
|
||||
}
|
||||
token, err := config.TokenSource(context.Background()).Token()
|
||||
if err != nil {
|
||||
t.Errorf("config.TokenSource(_).Token() == %v, %v, want !<nil>, <nil>", token, err)
|
||||
}
|
||||
if gotGrantType != wantGrantType {
|
||||
t.Errorf("grant_type == %q, want %q", gotGrantType, wantGrantType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRequest(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.String() != "/token" {
|
||||
|
|
Loading…
Reference in New Issue