forked from Mirrors/oauth2
oauth2: close request body if errors occur before base RoundTripper is invoked
Fixes golang/oauth#269 Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a Reviewed-on: https://go-review.googlesource.com/114956 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
bee4e0a411
commit
30b72dfc06
13
transport.go
13
transport.go
|
@ -34,6 +34,15 @@ type Transport struct {
|
|||
// access token. If no token exists or token is expired,
|
||||
// tries to refresh/fetch a new token.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
reqBodyClosed := false
|
||||
if req.Body != nil {
|
||||
defer func() {
|
||||
if !reqBodyClosed {
|
||||
req.Body.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if t.Source == nil {
|
||||
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||
}
|
||||
|
@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
token.SetAuthHeader(req2)
|
||||
t.setModReq(req, req2)
|
||||
res, err := t.base().RoundTrip(req2)
|
||||
|
||||
// req.Body is assumed to have been closed by the base RoundTripper.
|
||||
reqBodyClosed = true
|
||||
|
||||
if err != nil {
|
||||
t.setModReq(req, nil)
|
||||
return nil, err
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type readCloseCounter struct {
|
||||
CloseCount int
|
||||
ReadErr error
|
||||
}
|
||||
|
||||
func (r *readCloseCounter) Read(b []byte) (int, error) {
|
||||
return 0, r.ReadErr
|
||||
}
|
||||
|
||||
func (r *readCloseCounter) Close() error {
|
||||
r.CloseCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTransportCloseRequestBody(t *testing.T) {
|
||||
tr := &Transport{}
|
||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
||||
defer server.Close()
|
||||
client := &http.Client{Transport: tr}
|
||||
body := &readCloseCounter{
|
||||
ReadErr: errors.New("readCloseCounter.Read not implemented"),
|
||||
}
|
||||
resp, err := client.Post(server.URL, "application/json", body)
|
||||
if err == nil {
|
||||
t.Errorf("got no errors, want an error with nil token source")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Errorf("Response = %v; want nil", resp)
|
||||
}
|
||||
if expected := 1; body.CloseCount != expected {
|
||||
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportCloseRequestBodySuccess(t *testing.T) {
|
||||
tr := &Transport{
|
||||
Source: StaticTokenSource(&Token{
|
||||
AccessToken: "abc",
|
||||
}),
|
||||
}
|
||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
||||
defer server.Close()
|
||||
client := &http.Client{Transport: tr}
|
||||
body := &readCloseCounter{
|
||||
ReadErr: io.EOF,
|
||||
}
|
||||
resp, err := client.Post(server.URL, "application/json", body)
|
||||
if err != nil {
|
||||
t.Errorf("got error %v; expected none", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Errorf("Response is nil; expected non-nil")
|
||||
}
|
||||
if expected := 1; body.CloseCount != expected {
|
||||
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportTokenSource(t *testing.T) {
|
||||
ts := &tokenSource{
|
||||
token: &Token{
|
||||
|
|
Loading…
Reference in New Issue