forked from Mirrors/oauth2
oauth2: close request body if errors occur before base RoundTripper is invoked
Fixes golang/oauth#269 Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a Reviewed-on: https://go-review.googlesource.com/114956 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
bee4e0a411
commit
30b72dfc06
13
transport.go
13
transport.go
|
@ -34,6 +34,15 @@ type Transport struct {
|
||||||
// access token. If no token exists or token is expired,
|
// access token. If no token exists or token is expired,
|
||||||
// tries to refresh/fetch a new token.
|
// tries to refresh/fetch a new token.
|
||||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
reqBodyClosed := false
|
||||||
|
if req.Body != nil {
|
||||||
|
defer func() {
|
||||||
|
if !reqBodyClosed {
|
||||||
|
req.Body.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
if t.Source == nil {
|
if t.Source == nil {
|
||||||
return nil, errors.New("oauth2: Transport's Source is nil")
|
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||||
}
|
}
|
||||||
|
@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
token.SetAuthHeader(req2)
|
token.SetAuthHeader(req2)
|
||||||
t.setModReq(req, req2)
|
t.setModReq(req, req2)
|
||||||
res, err := t.base().RoundTrip(req2)
|
res, err := t.base().RoundTrip(req2)
|
||||||
|
|
||||||
|
// req.Body is assumed to have been closed by the base RoundTripper.
|
||||||
|
reqBodyClosed = true
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.setModReq(req, nil)
|
t.setModReq(req, nil)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type readCloseCounter struct {
|
||||||
|
CloseCount int
|
||||||
|
ReadErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *readCloseCounter) Read(b []byte) (int, error) {
|
||||||
|
return 0, r.ReadErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *readCloseCounter) Close() error {
|
||||||
|
r.CloseCount++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportCloseRequestBody(t *testing.T) {
|
||||||
|
tr := &Transport{}
|
||||||
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
defer server.Close()
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
body := &readCloseCounter{
|
||||||
|
ReadErr: errors.New("readCloseCounter.Read not implemented"),
|
||||||
|
}
|
||||||
|
resp, err := client.Post(server.URL, "application/json", body)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("got no errors, want an error with nil token source")
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Errorf("Response = %v; want nil", resp)
|
||||||
|
}
|
||||||
|
if expected := 1; body.CloseCount != expected {
|
||||||
|
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportCloseRequestBodySuccess(t *testing.T) {
|
||||||
|
tr := &Transport{
|
||||||
|
Source: StaticTokenSource(&Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
|
defer server.Close()
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
body := &readCloseCounter{
|
||||||
|
ReadErr: io.EOF,
|
||||||
|
}
|
||||||
|
resp, err := client.Post(server.URL, "application/json", body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got error %v; expected none", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
t.Errorf("Response is nil; expected non-nil")
|
||||||
|
}
|
||||||
|
if expected := 1; body.CloseCount != expected {
|
||||||
|
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransportTokenSource(t *testing.T) {
|
func TestTransportTokenSource(t *testing.T) {
|
||||||
ts := &tokenSource{
|
ts := &tokenSource{
|
||||||
token: &Token{
|
token: &Token{
|
||||||
|
|
Loading…
Reference in New Issue