forked from Mirrors/oauth2
google: adding support for external account authorized user
To support a new type of credential: `ExternalAccountAuthorizedUser`
* Refactor the common dependency STS to a separate package.
* Adding the `externalaccountauthorizeduser` package.
Change-Id: I9b9624f912d216b67a0d31945a50f057f747710b
GitHub-Last-Rev: 6e2aaff345
GitHub-Pull-Request: golang/oauth2#671
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/531095
Reviewed-by: Leo Siracusa <leosiracusa@google.com>
Reviewed-by: Alex Eitzman <eitzman@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
Reviewed-by: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
parent
14b275c918
commit
43b6a7ba19
|
@ -16,6 +16,7 @@ import (
|
||||||
"cloud.google.com/go/compute/metadata"
|
"cloud.google.com/go/compute/metadata"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google/internal/externalaccount"
|
"golang.org/x/oauth2/google/internal/externalaccount"
|
||||||
|
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
|
||||||
"golang.org/x/oauth2/jwt"
|
"golang.org/x/oauth2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,6 +100,7 @@ const (
|
||||||
serviceAccountKey = "service_account"
|
serviceAccountKey = "service_account"
|
||||||
userCredentialsKey = "authorized_user"
|
userCredentialsKey = "authorized_user"
|
||||||
externalAccountKey = "external_account"
|
externalAccountKey = "external_account"
|
||||||
|
externalAccountAuthorizedUserKey = "external_account_authorized_user"
|
||||||
impersonatedServiceAccount = "impersonated_service_account"
|
impersonatedServiceAccount = "impersonated_service_account"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -132,6 +134,9 @@ type credentialsFile struct {
|
||||||
QuotaProjectID string `json:"quota_project_id"`
|
QuotaProjectID string `json:"quota_project_id"`
|
||||||
WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
|
WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
|
||||||
|
|
||||||
|
// External Account Authorized User fields
|
||||||
|
RevokeURL string `json:"revoke_url"`
|
||||||
|
|
||||||
// Service account impersonation
|
// Service account impersonation
|
||||||
SourceCredentials *credentialsFile `json:"source_credentials"`
|
SourceCredentials *credentialsFile `json:"source_credentials"`
|
||||||
}
|
}
|
||||||
|
@ -200,6 +205,19 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
|
||||||
WorkforcePoolUserProject: f.WorkforcePoolUserProject,
|
WorkforcePoolUserProject: f.WorkforcePoolUserProject,
|
||||||
}
|
}
|
||||||
return cfg.TokenSource(ctx)
|
return cfg.TokenSource(ctx)
|
||||||
|
case externalAccountAuthorizedUserKey:
|
||||||
|
cfg := &externalaccountauthorizeduser.Config{
|
||||||
|
Audience: f.Audience,
|
||||||
|
RefreshToken: f.RefreshToken,
|
||||||
|
TokenURL: f.TokenURLExternal,
|
||||||
|
TokenInfoURL: f.TokenInfoURL,
|
||||||
|
ClientID: f.ClientID,
|
||||||
|
ClientSecret: f.ClientSecret,
|
||||||
|
RevokeURL: f.RevokeURL,
|
||||||
|
QuotaProjectID: f.QuotaProjectID,
|
||||||
|
Scopes: params.Scopes,
|
||||||
|
}
|
||||||
|
return cfg.TokenSource(ctx)
|
||||||
case impersonatedServiceAccount:
|
case impersonatedServiceAccount:
|
||||||
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
|
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
|
||||||
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")
|
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")
|
||||||
|
|
|
@ -8,13 +8,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
)
|
)
|
||||||
|
|
||||||
// now aliases time.Now for testing
|
// now aliases time.Now for testing
|
||||||
|
@ -63,31 +62,10 @@ type Config struct {
|
||||||
WorkforcePoolUserProject string
|
WorkforcePoolUserProject string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each element consists of a list of patterns. validateURLs checks for matches
|
|
||||||
// that include all elements in a given list, in that order.
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
|
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
|
|
||||||
parsed, err := url.Parse(input)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(parsed.Scheme, scheme) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
toTest := parsed.Host
|
|
||||||
|
|
||||||
for _, pattern := range patterns {
|
|
||||||
if pattern.MatchString(toTest) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateWorkforceAudience(input string) bool {
|
func validateWorkforceAudience(input string) bool {
|
||||||
return validWorkforceAudiencePattern.MatchString(input)
|
return validWorkforceAudiencePattern.MatchString(input)
|
||||||
}
|
}
|
||||||
|
@ -230,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stsRequest := stsTokenExchangeRequest{
|
stsRequest := stsexchange.TokenExchangeRequest{
|
||||||
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
|
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
|
||||||
Audience: conf.Audience,
|
Audience: conf.Audience,
|
||||||
Scope: conf.Scopes,
|
Scope: conf.Scopes,
|
||||||
|
@ -241,7 +219,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
header := make(http.Header)
|
header := make(http.Header)
|
||||||
header.Add("Content-Type", "application/x-www-form-urlencoded")
|
header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
|
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
|
||||||
clientAuth := clientAuthentication{
|
clientAuth := stsexchange.ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
ClientID: conf.ClientID,
|
ClientID: conf.ClientID,
|
||||||
ClientSecret: conf.ClientSecret,
|
ClientSecret: conf.ClientSecret,
|
||||||
|
@ -254,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
"userProject": conf.WorkforcePoolUserProject,
|
"userProject": conf.WorkforcePoolUserProject,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
|
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package externalaccountauthorizeduser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
|
)
|
||||||
|
|
||||||
|
// now aliases time.Now for testing.
|
||||||
|
var now = func() time.Time {
|
||||||
|
return time.Now().UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenValid = func(token oauth2.Token) bool {
|
||||||
|
return token.Valid()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
|
||||||
|
// the provider identifier in that pool.
|
||||||
|
Audience string
|
||||||
|
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
|
||||||
|
RefreshToken string
|
||||||
|
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
|
||||||
|
// None if the token can not be refreshed.
|
||||||
|
TokenURL string
|
||||||
|
// TokenInfoURL is the optional STS endpoint URL for token introspection.
|
||||||
|
TokenInfoURL string
|
||||||
|
// ClientID is only required in conjunction with ClientSecret, as described above.
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
|
||||||
|
// access token. When provided, STS will be called with additional basic authentication using client_id as username
|
||||||
|
// and client_secret as password.
|
||||||
|
ClientSecret string
|
||||||
|
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
|
||||||
|
Token string
|
||||||
|
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
|
||||||
|
Expiry time.Time
|
||||||
|
// RevokeURL is the optional STS endpoint URL for revoking tokens.
|
||||||
|
RevokeURL string
|
||||||
|
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
|
||||||
|
// project used to create the credentials.
|
||||||
|
QuotaProjectID string
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) canRefresh() bool {
|
||||||
|
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
|
||||||
|
var token oauth2.Token
|
||||||
|
if c.Token != "" && !c.Expiry.IsZero() {
|
||||||
|
token = oauth2.Token{
|
||||||
|
AccessToken: c.Token,
|
||||||
|
Expiry: c.Expiry,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !tokenValid(token) && !c.canRefresh() {
|
||||||
|
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := tokenSource{
|
||||||
|
ctx: ctx,
|
||||||
|
conf: c,
|
||||||
|
}
|
||||||
|
|
||||||
|
return oauth2.ReuseTokenSource(&token, ts), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenSource struct {
|
||||||
|
ctx context.Context
|
||||||
|
conf *Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts tokenSource) Token() (*oauth2.Token, error) {
|
||||||
|
conf := ts.conf
|
||||||
|
if !conf.canRefresh() {
|
||||||
|
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAuth := stsexchange.ClientAuthentication{
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
ClientID: conf.ClientID,
|
||||||
|
ClientSecret: conf.ClientSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if stsResponse.ExpiresIn < 0 {
|
||||||
|
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stsResponse.RefreshToken != "" {
|
||||||
|
conf.RefreshToken = stsResponse.RefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: stsResponse.AccessToken,
|
||||||
|
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
|
@ -0,0 +1,259 @@
|
||||||
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package externalaccountauthorizeduser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
||||||
|
)
|
||||||
|
|
||||||
|
const expiryDelta = 10 * time.Second
|
||||||
|
|
||||||
|
var (
|
||||||
|
expiry = time.Unix(234852, 0)
|
||||||
|
testNow = func() time.Time { return expiry }
|
||||||
|
testValid = func(t oauth2.Token) bool {
|
||||||
|
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type testRefreshTokenServer struct {
|
||||||
|
URL string
|
||||||
|
Authorization string
|
||||||
|
ContentType string
|
||||||
|
Body string
|
||||||
|
ResponsePayload *stsexchange.Response
|
||||||
|
Response string
|
||||||
|
server *httptest.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
Token: "AAAAAAA",
|
||||||
|
Expiry: now().Add(time.Hour),
|
||||||
|
}
|
||||||
|
ts, err := config.TokenSource(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting token source: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := ts.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error retrieving Token: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||||
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
|
||||||
|
server := &testRefreshTokenServer{
|
||||||
|
URL: "/",
|
||||||
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
|
ResponsePayload: &stsexchange.Response{
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
AccessToken: "AAAAAAA",
|
||||||
|
RefreshToken: "CCCCCCC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := server.run(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error starting server")
|
||||||
|
}
|
||||||
|
defer server.close(t)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
RefreshToken: "BBBBBBBBB",
|
||||||
|
TokenURL: url,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
}
|
||||||
|
ts, err := config.TokenSource(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting token source: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := ts.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error retrieving Token: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||||
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if config.RefreshToken != "CCCCCCC" {
|
||||||
|
t.Fatalf("Refresh token not updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
|
||||||
|
server := &testRefreshTokenServer{
|
||||||
|
URL: "/",
|
||||||
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
|
ResponsePayload: &stsexchange.Response{
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
AccessToken: "AAAAAAA",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := server.run(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error starting server")
|
||||||
|
}
|
||||||
|
defer server.close(t)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
RefreshToken: "BBBBBBBBB",
|
||||||
|
TokenURL: url,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
}
|
||||||
|
ts, err := config.TokenSource(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting token source: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := ts.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error retrieving Token: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
||||||
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
||||||
|
server := &testRefreshTokenServer{
|
||||||
|
URL: "/",
|
||||||
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
|
ContentType: "application/x-www-form-urlencoded",
|
||||||
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
||||||
|
ResponsePayload: &stsexchange.Response{
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
AccessToken: "AAAAAAA",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := server.run(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error starting server")
|
||||||
|
}
|
||||||
|
defer server.close(t)
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty config",
|
||||||
|
config: Config{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing refresh token",
|
||||||
|
config: Config{
|
||||||
|
TokenURL: url,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing token url",
|
||||||
|
config: Config{
|
||||||
|
RefreshToken: "BBBBBBBBB",
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing client id",
|
||||||
|
config: Config{
|
||||||
|
RefreshToken: "BBBBBBBBB",
|
||||||
|
TokenURL: url,
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing client secrect",
|
||||||
|
config: Config{
|
||||||
|
RefreshToken: "BBBBBBBBB",
|
||||||
|
TokenURL: url,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
|
||||||
|
_, err := tc.config.TokenSource((context.Background()))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected error, but received none")
|
||||||
|
}
|
||||||
|
if got := err.Error(); got != expectErrMsg {
|
||||||
|
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
|
||||||
|
t.Helper()
|
||||||
|
if trts.server != nil {
|
||||||
|
return "", errors.New("Server is already running")
|
||||||
|
}
|
||||||
|
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if got, want := r.URL.String(), trts.URL; got != want {
|
||||||
|
t.Errorf("URL.String(): got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
headerAuth := r.Header.Get("Authorization")
|
||||||
|
if got, want := headerAuth, trts.Authorization; got != want {
|
||||||
|
t.Errorf("got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
headerContentType := r.Header.Get("Content-Type")
|
||||||
|
if got, want := headerContentType, trts.ContentType; got != want {
|
||||||
|
t.Errorf("got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed reading request body: %s.", err)
|
||||||
|
}
|
||||||
|
if got, want := string(body), trts.Body; got != want {
|
||||||
|
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if trts.ResponsePayload != nil {
|
||||||
|
content, err := json.Marshal(trts.ResponsePayload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to marshall response JSON")
|
||||||
|
}
|
||||||
|
w.Write(content)
|
||||||
|
} else {
|
||||||
|
w.Write([]byte(trts.Response))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return trts.server.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trts *testRefreshTokenServer) close(t *testing.T) error {
|
||||||
|
t.Helper()
|
||||||
|
if trts.server == nil {
|
||||||
|
return errors.New("No server is running")
|
||||||
|
}
|
||||||
|
trts.server.Close()
|
||||||
|
trts.server = nil
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -2,7 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package externalaccount
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -12,8 +12,8 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
|
// ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
|
||||||
type clientAuthentication struct {
|
type ClientAuthentication struct {
|
||||||
// AuthStyle can be either basic or request-body
|
// AuthStyle can be either basic or request-body
|
||||||
AuthStyle oauth2.AuthStyle
|
AuthStyle oauth2.AuthStyle
|
||||||
ClientID string
|
ClientID string
|
||||||
|
@ -23,7 +23,7 @@ type clientAuthentication struct {
|
||||||
// InjectAuthentication is used to add authentication to a Secure Token Service exchange
|
// InjectAuthentication is used to add authentication to a Secure Token Service exchange
|
||||||
// request. It modifies either the passed url.Values or http.Header depending on the desired
|
// request. It modifies either the passed url.Values or http.Header depending on the desired
|
||||||
// authentication format.
|
// authentication format.
|
||||||
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
|
func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
|
||||||
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
|
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -2,7 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package externalaccount
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
|
||||||
"Content-Type": ContentType,
|
"Content-Type": ContentType,
|
||||||
}
|
}
|
||||||
|
|
||||||
headerAuthentication := clientAuthentication{
|
headerAuthentication := ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
|
@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
|
||||||
headerP := http.Header{
|
headerP := http.Header{
|
||||||
"Content-Type": ContentType,
|
"Content-Type": ContentType,
|
||||||
}
|
}
|
||||||
paramsAuthentication := clientAuthentication{
|
paramsAuthentication := ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInParams,
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
|
@ -2,7 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package externalaccount
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -18,14 +18,17 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// exchangeToken performs an oauth2 token exchange with the provided endpoint.
|
func defaultHeader() http.Header {
|
||||||
|
header := make(http.Header)
|
||||||
|
header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
|
||||||
// The first 4 fields are all mandatory. headers can be used to pass additional
|
// The first 4 fields are all mandatory. headers can be used to pass additional
|
||||||
// headers beyond the bare minimum required by the token exchange. options can
|
// headers beyond the bare minimum required by the token exchange. options can
|
||||||
// be used to pass additional JSON-structured options to the remote server.
|
// be used to pass additional JSON-structured options to the remote server.
|
||||||
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) {
|
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
|
||||||
|
|
||||||
client := oauth2.NewClient(ctx, nil)
|
|
||||||
|
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("audience", request.Audience)
|
data.Set("audience", request.Audience)
|
||||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
|
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
|
||||||
|
@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
|
||||||
data.Set("options", string(opts))
|
data.Set("options", string(opts))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return makeRequest(ctx, endpoint, data, authentication, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("grant_type", "refresh_token")
|
||||||
|
data.Set("refresh_token", refreshToken)
|
||||||
|
|
||||||
|
return makeRequest(ctx, endpoint, data, authentication, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
|
||||||
|
if headers == nil {
|
||||||
|
headers = defaultHeader()
|
||||||
|
}
|
||||||
|
client := oauth2.NewClient(ctx, nil)
|
||||||
authentication.InjectAuthentication(data, headers)
|
authentication.InjectAuthentication(data, headers)
|
||||||
encodedData := data.Encode()
|
encodedData := data.Encode()
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
|
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
|
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
|
||||||
|
|
||||||
}
|
}
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
for key, list := range headers {
|
for key, list := range headers {
|
||||||
|
@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
|
||||||
if c := resp.StatusCode; c < 200 || c > 299 {
|
if c := resp.StatusCode; c < 200 || c > 299 {
|
||||||
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
|
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
|
||||||
}
|
}
|
||||||
var stsResp stsTokenExchangeResponse
|
var stsResp Response
|
||||||
err = json.Unmarshal(body, &stsResp)
|
err = json.Unmarshal(body, &stsResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
|
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
|
||||||
|
@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
|
||||||
return &stsResp, nil
|
return &stsResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
|
// TokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
|
||||||
type stsTokenExchangeRequest struct {
|
type TokenExchangeRequest struct {
|
||||||
ActingParty struct {
|
ActingParty struct {
|
||||||
ActorToken string
|
ActorToken string
|
||||||
ActorTokenType string
|
ActorTokenType string
|
||||||
|
@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
|
||||||
SubjectTokenType string
|
SubjectTokenType string
|
||||||
}
|
}
|
||||||
|
|
||||||
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange.
|
// Response is used to decode the remote server response during an oauth2 token exchange.
|
||||||
type stsTokenExchangeResponse struct {
|
type Response struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
IssuedTokenType string `json:"issued_token_type"`
|
IssuedTokenType string `json:"issued_token_type"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
|
@ -2,7 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package externalaccount
|
package stsexchange
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -16,13 +16,13 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var auth = clientAuthentication{
|
var auth = ClientAuthentication{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenRequest = stsTokenExchangeRequest{
|
var exchangeTokenRequest = TokenExchangeRequest{
|
||||||
ActingParty: struct {
|
ActingParty: struct {
|
||||||
ActorToken string
|
ActorToken string
|
||||||
ActorTokenType string
|
ActorTokenType string
|
||||||
|
@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
|
||||||
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
|
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
|
var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
|
||||||
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
|
var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
|
||||||
var expectedToken = stsTokenExchangeResponse{
|
var expectedExchangeToken = Response{
|
||||||
AccessToken: "Sample.Access.Token",
|
AccessToken: "Sample.Access.Token",
|
||||||
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
|
@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
|
||||||
RefreshToken: "",
|
RefreshToken: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var refreshToken = "ReFrEsHtOkEn"
|
||||||
|
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
|
||||||
|
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
|
||||||
|
var expectedRefreshResponse = Response{
|
||||||
|
AccessToken: "Sample.Access.Token",
|
||||||
|
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
Scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||||
|
RefreshToken: "REFRESHED_REFRESH",
|
||||||
|
}
|
||||||
|
|
||||||
func TestExchangeToken(t *testing.T) {
|
func TestExchangeToken(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "POST" {
|
if r.Method != "POST" {
|
||||||
|
@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %v.", err)
|
t.Errorf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
if got, want := string(body), requestbody; got != want {
|
if got, want := string(body), exchangeRequestBody; got != want {
|
||||||
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
|
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(responseBody))
|
w.Write([]byte(exchangeResponseBody))
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
|
resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("exchangeToken failed with error: %v", err)
|
t.Fatalf("exchangeToken failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if expectedToken != *resp {
|
if expectedExchangeToken != *resp {
|
||||||
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp)
|
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("exchangeToken failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedExchangeToken != *resp {
|
||||||
|
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExchangeToken_Err(t *testing.T) {
|
func TestExchangeToken_Err(t *testing.T) {
|
||||||
|
@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
|
||||||
|
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil)
|
_, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected handled error; instead got nil.")
|
t.Errorf("Expected handled error; instead got nil.")
|
||||||
}
|
}
|
||||||
|
@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
|
||||||
|
|
||||||
// Send a proper reply so that no other errors crop up.
|
// Send a proper reply so that no other errors crop up.
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(responseBody))
|
w.Write([]byte(exchangeResponseBody))
|
||||||
|
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
|
||||||
inputOpts := make(map[string]interface{})
|
inputOpts := make(map[string]interface{})
|
||||||
inputOpts["one"] = firstOption
|
inputOpts["one"] = firstOption
|
||||||
inputOpts["two"] = secondOption
|
inputOpts["two"] = secondOption
|
||||||
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts)
|
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != "POST" {
|
||||||
|
t.Errorf("Unexpected request method, %v is found", r.Method)
|
||||||
|
}
|
||||||
|
if r.URL.String() != "/" {
|
||||||
|
t.Errorf("Unexpected request URL, %v is found", r.URL)
|
||||||
|
}
|
||||||
|
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
|
||||||
|
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
||||||
|
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed reading request body: %v.", err)
|
||||||
|
}
|
||||||
|
if got, want := string(body), refreshRequestBody; got != want {
|
||||||
|
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(refreshResponseBody))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("exchangeToken failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedRefreshResponse != *resp {
|
||||||
|
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("exchangeToken failed with error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedRefreshResponse != *resp {
|
||||||
|
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_Err(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte("what's wrong with this response?"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected handled error; instead got nil.")
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue