forked from Mirrors/oauth2
Don't assume optional fields are required and use Basic Auth if available.
See https://github.com/golang/oauth2/issues/33
This commit is contained in:
parent
32b8a902e6
commit
bb8496880f
89
oauth2.go
89
oauth2.go
|
@ -129,22 +129,16 @@ type Config struct {
|
||||||
// that asks for permissions for the required scopes explicitly.
|
// that asks for permissions for the required scopes explicitly.
|
||||||
func (c *Config) AuthCodeURL(state string) (authURL string) {
|
func (c *Config) AuthCodeURL(state string) (authURL string) {
|
||||||
u := *c.authURL
|
u := *c.authURL
|
||||||
vals := url.Values{
|
v := url.Values{
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"client_id": {c.opts.ClientID},
|
"client_id": {c.opts.ClientID},
|
||||||
"scope": {strings.Join(c.opts.Scopes, " ")},
|
"redirect_uri": condVal(c.opts.RedirectURL),
|
||||||
"state": {state},
|
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
||||||
|
"state": condVal(state),
|
||||||
|
"access_type": condVal(c.opts.AccessType),
|
||||||
|
"approval_prompt": condVal(c.opts.ApprovalPrompt),
|
||||||
}
|
}
|
||||||
if c.opts.AccessType != "" {
|
q := v.Encode()
|
||||||
vals.Set("access_type", c.opts.AccessType)
|
|
||||||
}
|
|
||||||
if c.opts.ApprovalPrompt != "" {
|
|
||||||
vals.Set("approval_prompt", c.opts.ApprovalPrompt)
|
|
||||||
}
|
|
||||||
if c.opts.RedirectURL != "" {
|
|
||||||
vals.Set("redirect_uri", c.opts.RedirectURL)
|
|
||||||
}
|
|
||||||
q := vals.Encode()
|
|
||||||
if u.RawQuery == "" {
|
if u.RawQuery == "" {
|
||||||
u.RawQuery = q
|
u.RawQuery = q
|
||||||
} else {
|
} else {
|
||||||
|
@ -178,11 +172,10 @@ func (c *Config) NewTransportWithCode(code string) (*Transport, error) {
|
||||||
// contain a refresh token, it returns an error.
|
// contain a refresh token, it returns an error.
|
||||||
func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
||||||
if existing == nil || existing.RefreshToken == "" {
|
if existing == nil || existing.RefreshToken == "" {
|
||||||
return nil, errors.New("cannot fetch access token without refresh token.")
|
return nil, errors.New("auth2: cannot fetch access token without refresh token")
|
||||||
}
|
}
|
||||||
return c.retrieveToken(url.Values{
|
return c.retrieveToken(url.Values{
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"client_secret": {c.opts.ClientSecret},
|
|
||||||
"refresh_token": {existing.RefreshToken},
|
"refresh_token": {existing.RefreshToken},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -190,36 +183,36 @@ func (c *Config) FetchToken(existing *Token) (*Token, error) {
|
||||||
// Exchange exchanges the authorization code with the OAuth 2.0 provider
|
// Exchange exchanges the authorization code with the OAuth 2.0 provider
|
||||||
// to retrieve a new access token.
|
// to retrieve a new access token.
|
||||||
func (c *Config) Exchange(code string) (*Token, error) {
|
func (c *Config) Exchange(code string) (*Token, error) {
|
||||||
vals := url.Values{
|
return c.retrieveToken(url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_secret": {c.opts.ClientSecret},
|
|
||||||
"code": {code},
|
"code": {code},
|
||||||
}
|
"redirect_uri": condVal(c.opts.RedirectURL),
|
||||||
if len(c.opts.Scopes) != 0 {
|
"scope": condVal(strings.Join(c.opts.Scopes, " ")),
|
||||||
vals.Set("scope", strings.Join(c.opts.Scopes, " "))
|
})
|
||||||
}
|
|
||||||
if c.opts.RedirectURL != "" {
|
|
||||||
vals.Set("redirect_uri", c.opts.RedirectURL)
|
|
||||||
}
|
|
||||||
return c.retrieveToken(vals)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
func (c *Config) retrieveToken(v url.Values) (*Token, error) {
|
||||||
v.Set("client_id", c.opts.ClientID)
|
v.Set("client_id", c.opts.ClientID)
|
||||||
// Note that we're not setting v's client_secret to t.ClientSecret, due
|
bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String())
|
||||||
// to https://code.google.com/p/goauth2/issues/detail?id=31
|
if bustedAuth && c.opts.ClientSecret != "" {
|
||||||
// Reddit only accepts client_secret in Authorization header.
|
v.Set("client_secret", c.opts.ClientSecret)
|
||||||
// Dropbox accepts either, but not both.
|
}
|
||||||
// The spec requires servers to always support the Authorization header,
|
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode()))
|
||||||
// so that's all we use.
|
if err != nil {
|
||||||
r, err := c.client().PostForm(c.tokenURL.String(), v)
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
if !bustedAuth && c.opts.ClientSecret != "" {
|
||||||
|
req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret)
|
||||||
|
}
|
||||||
|
r, err := c.client().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
if r.StatusCode != 200 {
|
if r.StatusCode != 200 {
|
||||||
// TODO(jbd): Add status code or error message
|
// TODO(jbd): Add status code or error message
|
||||||
return nil, errors.New("error during updating token")
|
return nil, errors.New("oauth2: can't retrieve a new token")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := &tokenRespBody{}
|
resp := &tokenRespBody{}
|
||||||
|
@ -281,3 +274,33 @@ func (c *Config) client() *http.Client {
|
||||||
}
|
}
|
||||||
return http.DefaultClient
|
return http.DefaultClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func condVal(v string) []string {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{v}
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
||||||
|
// implements the OAuth2 spec correctly
|
||||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||||
|
// In summary:
|
||||||
|
// - Reddit only accepts client secret in the Authorization header
|
||||||
|
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||||
|
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||||
|
func providerAuthHeaderWorks(tokenURL string) bool {
|
||||||
|
if strings.HasPrefix(tokenURL, "https://accounts.google.com/") ||
|
||||||
|
strings.HasPrefix(tokenURL, "https://github.com/") ||
|
||||||
|
strings.HasPrefix(tokenURL, "https://api.instagram.com/") ||
|
||||||
|
strings.HasPrefix(tokenURL, "https://www.douban.com/") {
|
||||||
|
// Some sites fail to implement the OAuth2 spec fully.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assume the provider implements the spec properly
|
||||||
|
// otherwise. We can add more exceptions as they're
|
||||||
|
// discovered. We will _not_ be adding configurable hooks
|
||||||
|
// to this package to let users select server bugs.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
127
oauth2_test.go
127
oauth2_test.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,7 +20,7 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
|
||||||
return t.rt(req)
|
return t.rt(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestConf() *Config {
|
func newTestConf(url string) *Config {
|
||||||
conf, _ := NewConfig(&Options{
|
conf, _ := NewConfig(&Options{
|
||||||
ClientID: "CLIENT_ID",
|
ClientID: "CLIENT_ID",
|
||||||
ClientSecret: "CLIENT_SECRET",
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
@ -30,83 +31,105 @@ func newTestConf() *Config {
|
||||||
},
|
},
|
||||||
AccessType: "offline",
|
AccessType: "offline",
|
||||||
ApprovalPrompt: "force",
|
ApprovalPrompt: "force",
|
||||||
}, "auth-url", "token-url")
|
}, url+"/auth", url+"/token")
|
||||||
return conf
|
return conf
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthCodeURL(t *testing.T) {
|
func TestAuthCodeURL(t *testing.T) {
|
||||||
conf := newTestConf()
|
conf := newTestConf("server")
|
||||||
url := conf.AuthCodeURL("foo")
|
url := conf.AuthCodeURL("foo")
|
||||||
if url != "auth-url?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
|
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
|
||||||
t.Fatalf("Generated auth URL is not the expected. Found %v.", url)
|
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangePayload(t *testing.T) {
|
func TestAuthCodeURL_Optional(t *testing.T) {
|
||||||
oldDefaultTransport := http.DefaultTransport
|
conf, _ := NewConfig(&Options{
|
||||||
defer func() {
|
ClientID: "CLIENT_ID",
|
||||||
http.DefaultTransport = oldDefaultTransport
|
}, "auth-url", "token-url")
|
||||||
}()
|
url := conf.AuthCodeURL("")
|
||||||
|
if url != "auth-url?client_id=CLIENT_ID&response_type=code" {
|
||||||
|
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
conf := newTestConf()
|
func TestExchangeRequest(t *testing.T) {
|
||||||
http.DefaultTransport = &mockTransport{
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
rt: func(req *http.Request) (resp *http.Response, err error) {
|
if r.URL.String() != "/token" {
|
||||||
headerContentType := req.Header.Get("Content-Type")
|
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
|
||||||
|
}
|
||||||
|
headerAuth := r.Header.Get("Authorization")
|
||||||
|
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
|
||||||
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
|
||||||
|
}
|
||||||
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Fatalf("Content-Type header is expected to be application/x-www-form-urlencoded, %v found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(req.Body)
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
if string(body) != "client_id=CLIENT_ID&client_secret=CLIENT_SECRET&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" {
|
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" {
|
||||||
t.Fatalf("Exchange payload is found to be %v", string(body))
|
t.Errorf("Unexpected exchange payload, %v is found.", string(body))
|
||||||
}
|
|
||||||
return nil, errors.New("no response")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
conf := newTestConf(ts.URL)
|
||||||
conf.Exchange("exchange-code")
|
conf.Exchange("exchange-code")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRefreshPayload(t *testing.T) {
|
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
|
||||||
oldDefaultTransport := http.DefaultTransport
|
conf, _ := NewConfig(&Options{
|
||||||
defer func() {
|
ClientID: "CLIENT_ID",
|
||||||
http.DefaultTransport = oldDefaultTransport
|
}, "https://accounts.google.com/auth",
|
||||||
}()
|
"https://accounts.google.com/token")
|
||||||
conf := newTestConf()
|
tr := &mockTransport{
|
||||||
http.DefaultTransport = &mockTransport{
|
rt: func(r *http.Request) (w *http.Response, err error) {
|
||||||
rt: func(req *http.Request) (resp *http.Response, err error) {
|
headerAuth := r.Header.Get("Authorization")
|
||||||
headerContentType := req.Header.Get("Content-Type")
|
if headerAuth != "" {
|
||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
|
||||||
t.Fatalf("Content-Type header is expected to be application/x-www-form-urlencoded, %v found.", headerContentType)
|
|
||||||
}
|
|
||||||
body, _ := ioutil.ReadAll(req.Body)
|
|
||||||
if string(body) != "client_id=CLIENT_ID&client_secret=CLIENT_SECRET&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
|
||||||
t.Fatalf("Exchange payload is found to be %v", string(body))
|
|
||||||
}
|
}
|
||||||
return nil, errors.New("no response")
|
return nil, errors.New("no response")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
conf.Client = &http.Client{Transport: tr}
|
||||||
|
conf.Exchange("code")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRefreshRequest(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.String() != "/token" {
|
||||||
|
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
|
||||||
|
}
|
||||||
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
|
}
|
||||||
|
body, _ := ioutil.ReadAll(r.Body)
|
||||||
|
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
||||||
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
conf := newTestConf(ts.URL)
|
||||||
conf.FetchToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
conf.FetchToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangingTransport(t *testing.T) {
|
func TestNewTransportWithCode(t *testing.T) {
|
||||||
oldDefaultTransport := http.DefaultTransport
|
exchanged := false
|
||||||
defer func() {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.DefaultTransport = oldDefaultTransport
|
if r.URL.RequestURI() == "/token" {
|
||||||
}()
|
exchanged = true
|
||||||
|
|
||||||
conf := newTestConf()
|
|
||||||
http.DefaultTransport = &mockTransport{
|
|
||||||
rt: func(req *http.Request) (resp *http.Response, err error) {
|
|
||||||
if req.URL.RequestURI() != "token-url" {
|
|
||||||
t.Fatalf("NewTransportWithCode should have exchanged the code, but it didn't.")
|
|
||||||
}
|
|
||||||
return nil, errors.New("no response")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
conf := newTestConf(ts.URL)
|
||||||
conf.NewTransportWithCode("exchange-code")
|
conf.NewTransportWithCode("exchange-code")
|
||||||
|
if !exchanged {
|
||||||
|
t.Errorf("NewTransportWithCode should have exchanged the code, but it didn't.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFetchWithNoRedirect(t *testing.T) {
|
func TestFetchWithNoRefreshToken(t *testing.T) {
|
||||||
fetcher := newTestConf()
|
fetcher := newTestConf("")
|
||||||
_, err := fetcher.FetchToken(&Token{})
|
_, err := fetcher.FetchToken(&Token{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Fetch should return an error if no refresh token is set")
|
t.Fatalf("Fetch should return an error if no refresh token is set")
|
||||||
|
|
Loading…
Reference in New Issue