oauth2: close request body if errors occur before base RoundTripper is invoked

Fixes golang/oauth#269

Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a
Reviewed-on: https://go-review.googlesource.com/114956
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Tim Cooper 2018-05-28 18:26:48 -03:00 committed by Brad Fitzpatrick
parent bee4e0a411
commit 30b72dfc06
2 changed files with 73 additions and 0 deletions

View File

@ -34,6 +34,15 @@ type Transport struct {
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}
if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)
// req.Body is assumed to have been closed by the base RoundTripper.
reqBodyClosed = true
if err != nil {
t.setModReq(req, nil)
return nil, err

View File

@ -1,6 +1,8 @@
package oauth2
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) {
}
}
type readCloseCounter struct {
CloseCount int
ReadErr error
}
func (r *readCloseCounter) Read(b []byte) (int, error) {
return 0, r.ReadErr
}
func (r *readCloseCounter) Close() error {
r.CloseCount++
return nil
}
func TestTransportCloseRequestBody(t *testing.T) {
tr := &Transport{}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
defer server.Close()
client := &http.Client{Transport: tr}
body := &readCloseCounter{
ReadErr: errors.New("readCloseCounter.Read not implemented"),
}
resp, err := client.Post(server.URL, "application/json", body)
if err == nil {
t.Errorf("got no errors, want an error with nil token source")
}
if resp != nil {
t.Errorf("Response = %v; want nil", resp)
}
if expected := 1; body.CloseCount != expected {
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
}
}
func TestTransportCloseRequestBodySuccess(t *testing.T) {
tr := &Transport{
Source: StaticTokenSource(&Token{
AccessToken: "abc",
}),
}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
defer server.Close()
client := &http.Client{Transport: tr}
body := &readCloseCounter{
ReadErr: io.EOF,
}
resp, err := client.Post(server.URL, "application/json", body)
if err != nil {
t.Errorf("got error %v; expected none", err)
}
if resp == nil {
t.Errorf("Response is nil; expected non-nil")
}
if expected := 1; body.CloseCount != expected {
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
}
}
func TestTransportTokenSource(t *testing.T) {
ts := &tokenSource{
token: &Token{