forked from Mirrors/oauth2
clientcredentials: allow override of grant_type
This commit is contained in:
parent
5dab4167f3
commit
0e6f85e31e
|
@ -90,7 +90,9 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
|
||||||
v.Set("scope", strings.Join(c.conf.Scopes, " "))
|
v.Set("scope", strings.Join(c.conf.Scopes, " "))
|
||||||
}
|
}
|
||||||
for k, p := range c.conf.EndpointParams {
|
for k, p := range c.conf.EndpointParams {
|
||||||
if _, ok := v[k]; ok {
|
// Allow grant_type to be overridden to allow interoperability with
|
||||||
|
// non-compliant implementations.
|
||||||
|
if _, ok := v[k]; ok && k != "grant_type" {
|
||||||
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k)
|
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k)
|
||||||
}
|
}
|
||||||
v[k] = p
|
v[k] = p
|
||||||
|
|
|
@ -31,6 +31,43 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
|
||||||
return t.rt(req)
|
return t.rt(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenSourceGrantTypeOverride(t *testing.T) {
|
||||||
|
wantGrantType := "password"
|
||||||
|
var gotGrantType string
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
|
||||||
|
}
|
||||||
|
if err := r.Body.Close(); err != nil {
|
||||||
|
t.Errorf("r.Body.Close() == %v, want <nil>", err)
|
||||||
|
}
|
||||||
|
values, err := url.ParseQuery(string(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("url.ParseQuery(%q) == %v, %v, want _, <nil>", body, values, err)
|
||||||
|
}
|
||||||
|
gotGrantType = values.Get("grant_type")
|
||||||
|
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer"))
|
||||||
|
}))
|
||||||
|
config := &Config{
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
Scopes: []string{"scope"},
|
||||||
|
TokenURL: ts.URL + "/token",
|
||||||
|
EndpointParams: url.Values{
|
||||||
|
"grant_type": {wantGrantType},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
token, err := config.TokenSource(context.Background()).Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("config.TokenSource(_).Token() == %v, %v, want !<nil>, <nil>", token, err)
|
||||||
|
}
|
||||||
|
if gotGrantType != wantGrantType {
|
||||||
|
t.Errorf("grant_type == %q, want %q", gotGrantType, wantGrantType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTokenRequest(t *testing.T) {
|
func TestTokenRequest(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.String() != "/token" {
|
if r.URL.String() != "/token" {
|
||||||
|
|
Loading…
Reference in New Issue