diff --git a/oauth2.go b/oauth2.go index f757875..a98a0a6 100644 --- a/oauth2.go +++ b/oauth2.go @@ -129,22 +129,16 @@ type Config struct { // that asks for permissions for the required scopes explicitly. func (c *Config) AuthCodeURL(state string) (authURL string) { u := *c.authURL - vals := url.Values{ - "response_type": {"code"}, - "client_id": {c.opts.ClientID}, - "scope": {strings.Join(c.opts.Scopes, " ")}, - "state": {state}, + v := url.Values{ + "response_type": {"code"}, + "client_id": {c.opts.ClientID}, + "redirect_uri": condVal(c.opts.RedirectURL), + "scope": condVal(strings.Join(c.opts.Scopes, " ")), + "state": condVal(state), + "access_type": condVal(c.opts.AccessType), + "approval_prompt": condVal(c.opts.ApprovalPrompt), } - if c.opts.AccessType != "" { - vals.Set("access_type", c.opts.AccessType) - } - if c.opts.ApprovalPrompt != "" { - vals.Set("approval_prompt", c.opts.ApprovalPrompt) - } - if c.opts.RedirectURL != "" { - vals.Set("redirect_uri", c.opts.RedirectURL) - } - q := vals.Encode() + q := v.Encode() if u.RawQuery == "" { u.RawQuery = q } else { @@ -178,11 +172,10 @@ func (c *Config) NewTransportWithCode(code string) (*Transport, error) { // contain a refresh token, it returns an error. func (c *Config) FetchToken(existing *Token) (*Token, error) { if existing == nil || existing.RefreshToken == "" { - return nil, errors.New("cannot fetch access token without refresh token.") + return nil, errors.New("auth2: cannot fetch access token without refresh token") } return c.retrieveToken(url.Values{ "grant_type": {"refresh_token"}, - "client_secret": {c.opts.ClientSecret}, "refresh_token": {existing.RefreshToken}, }) } @@ -190,36 +183,36 @@ func (c *Config) FetchToken(existing *Token) (*Token, error) { // Exchange exchanges the authorization code with the OAuth 2.0 provider // to retrieve a new access token. func (c *Config) Exchange(code string) (*Token, error) { - vals := url.Values{ - "grant_type": {"authorization_code"}, - "client_secret": {c.opts.ClientSecret}, - "code": {code}, - } - if len(c.opts.Scopes) != 0 { - vals.Set("scope", strings.Join(c.opts.Scopes, " ")) - } - if c.opts.RedirectURL != "" { - vals.Set("redirect_uri", c.opts.RedirectURL) - } - return c.retrieveToken(vals) + return c.retrieveToken(url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": condVal(c.opts.RedirectURL), + "scope": condVal(strings.Join(c.opts.Scopes, " ")), + }) } func (c *Config) retrieveToken(v url.Values) (*Token, error) { v.Set("client_id", c.opts.ClientID) - // Note that we're not setting v's client_secret to t.ClientSecret, due - // to https://code.google.com/p/goauth2/issues/detail?id=31 - // Reddit only accepts client_secret in Authorization header. - // Dropbox accepts either, but not both. - // The spec requires servers to always support the Authorization header, - // so that's all we use. - r, err := c.client().PostForm(c.tokenURL.String(), v) + bustedAuth := !providerAuthHeaderWorks(c.tokenURL.String()) + if bustedAuth && c.opts.ClientSecret != "" { + v.Set("client_secret", c.opts.ClientSecret) + } + req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if !bustedAuth && c.opts.ClientSecret != "" { + req.SetBasicAuth(c.opts.ClientID, c.opts.ClientSecret) + } + r, err := c.client().Do(req) if err != nil { return nil, err } defer r.Body.Close() if r.StatusCode != 200 { // TODO(jbd): Add status code or error message - return nil, errors.New("error during updating token") + return nil, errors.New("oauth2: can't retrieve a new token") } resp := &tokenRespBody{} @@ -281,3 +274,33 @@ func (c *Config) client() *http.Client { } return http.DefaultClient } + +func condVal(v string) []string { + if v == "" { + return nil + } + return []string{v} +} + +// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL +// implements the OAuth2 spec correctly +// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. +// In summary: +// - Reddit only accepts client secret in the Authorization header +// - Dropbox accepts either it in URL param or Auth header, but not both. +// - Google only accepts URL param (not spec compliant?), not Auth header +func providerAuthHeaderWorks(tokenURL string) bool { + if strings.HasPrefix(tokenURL, "https://accounts.google.com/") || + strings.HasPrefix(tokenURL, "https://github.com/") || + strings.HasPrefix(tokenURL, "https://api.instagram.com/") || + strings.HasPrefix(tokenURL, "https://www.douban.com/") { + // Some sites fail to implement the OAuth2 spec fully. + return false + } + + // Assume the provider implements the spec properly + // otherwise. We can add more exceptions as they're + // discovered. We will _not_ be adding configurable hooks + // to this package to let users select server bugs. + return true +} diff --git a/oauth2_test.go b/oauth2_test.go index bac8656..7882dc9 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -8,6 +8,7 @@ import ( "errors" "io/ioutil" "net/http" + "net/http/httptest" "testing" ) @@ -19,7 +20,7 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e return t.rt(req) } -func newTestConf() *Config { +func newTestConf(url string) *Config { conf, _ := NewConfig(&Options{ ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", @@ -30,83 +31,105 @@ func newTestConf() *Config { }, AccessType: "offline", ApprovalPrompt: "force", - }, "auth-url", "token-url") + }, url+"/auth", url+"/token") return conf } func TestAuthCodeURL(t *testing.T) { - conf := newTestConf() + conf := newTestConf("server") url := conf.AuthCodeURL("foo") - if url != "auth-url?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { - t.Fatalf("Generated auth URL is not the expected. Found %v.", url) + if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" { + t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) } } -func TestExchangePayload(t *testing.T) { - oldDefaultTransport := http.DefaultTransport - defer func() { - http.DefaultTransport = oldDefaultTransport - }() - - conf := newTestConf() - http.DefaultTransport = &mockTransport{ - rt: func(req *http.Request) (resp *http.Response, err error) { - headerContentType := req.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Fatalf("Content-Type header is expected to be application/x-www-form-urlencoded, %v found.", headerContentType) - } - body, _ := ioutil.ReadAll(req.Body) - if string(body) != "client_id=CLIENT_ID&client_secret=CLIENT_SECRET&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { - t.Fatalf("Exchange payload is found to be %v", string(body)) - } - return nil, errors.New("no response") - }, +func TestAuthCodeURL_Optional(t *testing.T) { + conf, _ := NewConfig(&Options{ + ClientID: "CLIENT_ID", + }, "auth-url", "token-url") + url := conf.AuthCodeURL("") + if url != "auth-url?client_id=CLIENT_ID&response_type=code" { + t.Fatalf("Auth code URL doesn't match the expected, found: %v", url) } +} + +func TestExchangeRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) + } + headerAuth := r.Header.Get("Authorization") + if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { + t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, _ := ioutil.ReadAll(r.Body) + if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" { + t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + } + })) + defer ts.Close() + conf := newTestConf(ts.URL) conf.Exchange("exchange-code") } -func TestRefreshPayload(t *testing.T) { - oldDefaultTransport := http.DefaultTransport - defer func() { - http.DefaultTransport = oldDefaultTransport - }() - conf := newTestConf() - http.DefaultTransport = &mockTransport{ - rt: func(req *http.Request) (resp *http.Response, err error) { - headerContentType := req.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Fatalf("Content-Type header is expected to be application/x-www-form-urlencoded, %v found.", headerContentType) - } - body, _ := ioutil.ReadAll(req.Body) - if string(body) != "client_id=CLIENT_ID&client_secret=CLIENT_SECRET&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { - t.Fatalf("Exchange payload is found to be %v", string(body)) +func TestExchangeRequest_NonBasicAuth(t *testing.T) { + conf, _ := NewConfig(&Options{ + ClientID: "CLIENT_ID", + }, "https://accounts.google.com/auth", + "https://accounts.google.com/token") + tr := &mockTransport{ + rt: func(r *http.Request) (w *http.Response, err error) { + headerAuth := r.Header.Get("Authorization") + if headerAuth != "" { + t.Errorf("Unexpected authorization header, %v is found.", headerAuth) } return nil, errors.New("no response") }, } + conf.Client = &http.Client{Transport: tr} + conf.Exchange("code") +} + +func TestTokenRefreshRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, _ := ioutil.ReadAll(r.Body) + if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { + t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + } + })) + defer ts.Close() + conf := newTestConf(ts.URL) conf.FetchToken(&Token{RefreshToken: "REFRESH_TOKEN"}) } -func TestExchangingTransport(t *testing.T) { - oldDefaultTransport := http.DefaultTransport - defer func() { - http.DefaultTransport = oldDefaultTransport - }() - - conf := newTestConf() - http.DefaultTransport = &mockTransport{ - rt: func(req *http.Request) (resp *http.Response, err error) { - if req.URL.RequestURI() != "token-url" { - t.Fatalf("NewTransportWithCode should have exchanged the code, but it didn't.") - } - return nil, errors.New("no response") - }, - } +func TestNewTransportWithCode(t *testing.T) { + exchanged := false + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.RequestURI() == "/token" { + exchanged = true + } + })) + defer ts.Close() + conf := newTestConf(ts.URL) conf.NewTransportWithCode("exchange-code") + if !exchanged { + t.Errorf("NewTransportWithCode should have exchanged the code, but it didn't.") + } } -func TestFetchWithNoRedirect(t *testing.T) { - fetcher := newTestConf() +func TestFetchWithNoRefreshToken(t *testing.T) { + fetcher := newTestConf("") _, err := fetcher.FetchToken(&Token{}) if err == nil { t.Fatalf("Fetch should return an error if no refresh token is set")