diff --git a/internal/transport.go b/internal/transport.go index 33d3669..f1f173e 100644 --- a/internal/transport.go +++ b/internal/transport.go @@ -33,6 +33,11 @@ func RegisterContextClientFunc(fn ContextClientFunc) { } func ContextClient(ctx context.Context) (*http.Client, error) { + if ctx != nil { + if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { + return hc, nil + } + } for _, fn := range contextClientFuncs { c, err := fn(ctx) if err != nil { @@ -42,9 +47,6 @@ func ContextClient(ctx context.Context) (*http.Client, error) { return c, nil } } - if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { - return hc, nil - } return http.DefaultClient, nil } diff --git a/internal/transport_test.go b/internal/transport_test.go new file mode 100644 index 0000000..313f637 --- /dev/null +++ b/internal/transport_test.go @@ -0,0 +1,38 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package internal + +import ( + "net/http" + "testing" + + "golang.org/x/net/context" +) + +func TestContextClient(t *testing.T) { + rc := &http.Client{} + RegisterContextClientFunc(func(context.Context) (*http.Client, error) { + return rc, nil + }) + + c := &http.Client{} + ctx := context.WithValue(nil, HTTPClient, c) + + hc, err := ContextClient(ctx) + if err != nil { + t.Fatalf("want valid client; got err = %v", err) + } + if hc != c { + t.Fatalf("want context client = %p; got = %p", c, hc) + } + + hc, err = ContextClient(context.TODO()) + if err != nil { + t.Fatalf("want valid client; got err = %v", err) + } + if hc != rc { + t.Fatalf("want registered client = %p; got = %p", c, hc) + } +}