feat: adding support for external account authorized user

This commit is contained in:
Jin Qin 2023-09-25 22:52:50 +00:00
parent 14b275c918
commit b621b331ae
8 changed files with 534 additions and 40 deletions

View File

@ -16,6 +16,7 @@ import (
"cloud.google.com/go/compute/metadata" "cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/externalaccount" "golang.org/x/oauth2/google/internal/externalaccount"
"golang.org/x/oauth2/google/internal/externalaccountauthorizeduser"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
) )
@ -99,6 +100,7 @@ const (
serviceAccountKey = "service_account" serviceAccountKey = "service_account"
userCredentialsKey = "authorized_user" userCredentialsKey = "authorized_user"
externalAccountKey = "external_account" externalAccountKey = "external_account"
externalAccountAuthorizedUserKey = "external_account_authorized_user"
impersonatedServiceAccount = "impersonated_service_account" impersonatedServiceAccount = "impersonated_service_account"
) )
@ -132,6 +134,9 @@ type credentialsFile struct {
QuotaProjectID string `json:"quota_project_id"` QuotaProjectID string `json:"quota_project_id"`
WorkforcePoolUserProject string `json:"workforce_pool_user_project"` WorkforcePoolUserProject string `json:"workforce_pool_user_project"`
// External Account Authorized User fields
RevokeURL string `json:"revoke_url"`
// Service account impersonation // Service account impersonation
SourceCredentials *credentialsFile `json:"source_credentials"` SourceCredentials *credentialsFile `json:"source_credentials"`
} }
@ -200,6 +205,19 @@ func (f *credentialsFile) tokenSource(ctx context.Context, params CredentialsPar
WorkforcePoolUserProject: f.WorkforcePoolUserProject, WorkforcePoolUserProject: f.WorkforcePoolUserProject,
} }
return cfg.TokenSource(ctx) return cfg.TokenSource(ctx)
case externalAccountAuthorizedUserKey:
cfg := &externalaccountauthorizeduser.Config{
Audience: f.Audience,
RefreshToken: f.RefreshToken,
TokenURL: f.TokenURLExternal,
TokenInfoURL: f.TokenInfoURL,
ClientID: f.ClientID,
ClientSecret: f.ClientSecret,
RevokeURL: f.RevokeURL,
QuotaProjectID: f.QuotaProjectID,
Scopes: params.Scopes,
}
return cfg.TokenSource(ctx)
case impersonatedServiceAccount: case impersonatedServiceAccount:
if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil { if f.ServiceAccountImpersonationURL == "" || f.SourceCredentials == nil {
return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials") return nil, errors.New("missing 'source_credentials' field or 'service_account_impersonation_url' in credentials")

View File

@ -15,6 +15,7 @@ import (
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
) )
// now aliases time.Now for testing // now aliases time.Now for testing
@ -230,7 +231,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
stsRequest := stsTokenExchangeRequest{ stsRequest := sts_exchange.StsTokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange", GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience, Audience: conf.Audience,
Scope: conf.Scopes, Scope: conf.Scopes,
@ -241,7 +242,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
header := make(http.Header) header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded") header.Add("Content-Type", "application/x-www-form-urlencoded")
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource)) header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := clientAuthentication{ clientAuth := sts_exchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID, ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret, ClientSecret: conf.ClientSecret,
@ -254,7 +255,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
"userProject": conf.WorkforcePoolUserProject, "userProject": conf.WorkforcePoolUserProject,
} }
} }
stsResp, err := exchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options) stsResp, err := sts_exchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, nil, options)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -0,0 +1,114 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"errors"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
)
// now aliases time.Now for testing.
var now = func() time.Time {
return time.Now().UTC()
}
var tokenValid = func(token oauth2.Token) bool {
return token.Valid()
}
type Config struct {
// Audience is the Secure Token Service (STS) audience which contains the resource name for the workforce pool and
// the provider identifier in that pool.
Audience string
// RefreshToken is the optional OAuth 2.0 refresh token. If specified, credentials can be refreshed.
RefreshToken string
// TokenURL is the optional STS token exchange endpoint for refresh. Must be specified for refresh, can be left as
// None if the token can not be refreshed.
TokenURL string
// TokenInfoURL is the optional STS endpoint URL for token introspection.
TokenInfoURL string
// ClientID is only required in conjunction with ClientSecret, as described above.
ClientID string
// ClientSecret is currently only required if token_info endpoint also needs to be called with the generated GCP
// access token. When provided, STS will be called with additional basic authentication using client_id as username
// and client_secret as password.
ClientSecret string
// Token is the OAuth2.0 access token. Can be nil if refresh information is provided.
Token string
// Expiry is the optional expiration datetime of the OAuth 2.0 access token.
Expiry time.Time
// RevokeURL is the optional STS endpoint URL for revoking tokens.
RevokeURL string
// QuotaProjectID is the optional project ID used for quota and billing. This project may be different from the
// project used to create the credentials.
QuotaProjectID string
Scopes []string
}
func (c *Config) canRefresh() bool {
return c.ClientID != "" && c.ClientSecret != "" && c.RefreshToken != "" && c.TokenURL != ""
}
func (c *Config) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
var token oauth2.Token
if c.Token != "" && !c.Expiry.IsZero() {
token = oauth2.Token{
AccessToken: c.Token,
Expiry: c.Expiry,
TokenType: "Bearer",
}
}
if !tokenValid(token) && !c.canRefresh() {
return nil, errors.New("oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`).")
}
ts := tokenSource{
ctx: ctx,
conf: c,
}
return oauth2.ReuseTokenSource(&token, ts), nil
}
type tokenSource struct {
ctx context.Context
conf *Config
}
func (ts tokenSource) Token() (*oauth2.Token, error) {
conf := ts.conf
if !conf.canRefresh() {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}
clientAuth := sts_exchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
stsResponse, err := sts_exchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
if stsResponse.ExpiresIn < 0 {
return nil, errors.New("oauth2/google: got invalid expiry from security token service")
}
if stsResponse.RefreshToken != "" {
conf.RefreshToken = stsResponse.RefreshToken
}
token := &oauth2.Token{
AccessToken: stsResponse.AccessToken,
Expiry: now().Add(time.Duration(stsResponse.ExpiresIn) * time.Second),
TokenType: "Bearer",
}
return token, nil
}

View File

@ -0,0 +1,259 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package externalaccountauthorizeduser
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
)
const expiryDelta = 10 * time.Second
var (
expiry = time.Unix(234852, 0)
testNow = func() time.Time { return expiry }
testValid = func(t oauth2.Token) bool {
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
}
)
type testRefreshTokenServer struct {
URL string
Authorization string
ContentType string
Body string
ResponsePayload *sts_exchange.Response
Response string
server *httptest.Server
}
func (trts *testRefreshTokenServer) Run(t *testing.T) (string, error) {
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}
func (trts *testRefreshTokenServer) Close() error {
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}
// Tests
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponds(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}
url, err := server.Run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
if config.RefreshToken != "CCCCCCC" {
t.Fatalf("Refresh token not updated")
}
}
func TestExernalAccountAuthorizedUser_minimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.Run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
config := &Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
}
ts, err := config.TokenSource(context.Background())
if err != nil {
t.Fatalf("Error getting token source: %v", err)
}
token, err := ts.Token()
if err != nil {
t.Fatalf("Error retrieving Token: %v", err)
}
if got, want := token.AccessToken, "AAAAAAA"; got != want {
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
}
}
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}
url, err := server.Run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
testCases := []struct {
name string
config Config
}{
{
name: "empty config",
config: Config{},
},
{
name: "missing refresh token",
config: Config{
TokenURL: url,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing token url",
config: Config{
RefreshToken: "BBBBBBBBB",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client id",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientSecret: "CLIENT_SECRET",
},
},
{
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
ClientID: "CLIENT_ID",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
_, err := tc.config.TokenSource((context.Background()))
if err == nil {
t.Fatalf("Expected error, but received none")
}
if got := err.Error(); got != expectErrMsg {
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
}
})
}
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package sts_exchange
import ( import (
"encoding/base64" "encoding/base64"
@ -12,8 +12,8 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// clientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1. // ClientAuthentication represents an OAuth client ID and secret and the mechanism for passing these credentials as stated in rfc6749#2.3.1.
type clientAuthentication struct { type ClientAuthentication struct {
// AuthStyle can be either basic or request-body // AuthStyle can be either basic or request-body
AuthStyle oauth2.AuthStyle AuthStyle oauth2.AuthStyle
ClientID string ClientID string
@ -23,7 +23,7 @@ type clientAuthentication struct {
// InjectAuthentication is used to add authentication to a Secure Token Service exchange // InjectAuthentication is used to add authentication to a Secure Token Service exchange
// request. It modifies either the passed url.Values or http.Header depending on the desired // request. It modifies either the passed url.Values or http.Header depending on the desired
// authentication format. // authentication format.
func (c *clientAuthentication) InjectAuthentication(values url.Values, headers http.Header) { func (c *ClientAuthentication) InjectAuthentication(values url.Values, headers http.Header) {
if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil { if c.ClientID == "" || c.ClientSecret == "" || values == nil || headers == nil {
return return
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package sts_exchange
import ( import (
"net/http" "net/http"
@ -38,7 +38,7 @@ func TestClientAuthentication_InjectHeaderAuthentication(t *testing.T) {
"Content-Type": ContentType, "Content-Type": ContentType,
} }
headerAuthentication := clientAuthentication{ headerAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
@ -80,7 +80,7 @@ func TestClientAuthentication_ParamsAuthentication(t *testing.T) {
headerP := http.Header{ headerP := http.Header{
"Content-Type": ContentType, "Content-Type": ContentType,
} }
paramsAuthentication := clientAuthentication{ paramsAuthentication := ClientAuthentication{
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package sts_exchange
import ( import (
"context" "context"
@ -18,14 +18,17 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// exchangeToken performs an oauth2 token exchange with the provided endpoint. func defaultHeader() http.Header {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
return header
}
// ExchangeToken performs an oauth2 token exchange with the provided endpoint.
// The first 4 fields are all mandatory. headers can be used to pass additional // The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can // headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server. // be used to pass additional JSON-structured options to the remote server.
func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchangeRequest, authentication clientAuthentication, headers http.Header, options map[string]interface{}) (*stsTokenExchangeResponse, error) { func ExchangeToken(ctx context.Context, endpoint string, request *StsTokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
client := oauth2.NewClient(ctx, nil)
data := url.Values{} data := url.Values{}
data.Set("audience", request.Audience) data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@ -41,13 +44,28 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
data.Set("options", string(opts)) data.Set("options", string(opts))
} }
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func RefreshAccessToken(ctx context.Context, endpoint string, refreshToken string, authentication ClientAuthentication, headers http.Header) (*Response, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)
return makeRequest(ctx, endpoint, data, authentication, headers)
}
func makeRequest(ctx context.Context, endpoint string, data url.Values, authentication ClientAuthentication, headers http.Header) (*Response, error) {
if headers == nil {
headers = defaultHeader()
}
client := oauth2.NewClient(ctx, nil)
authentication.InjectAuthentication(data, headers) authentication.InjectAuthentication(data, headers)
encodedData := data.Encode() encodedData := data.Encode()
req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData)) req, err := http.NewRequest("POST", endpoint, strings.NewReader(encodedData))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err) return nil, fmt.Errorf("oauth2/google: failed to properly build http request: %v", err)
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
for key, list := range headers { for key, list := range headers {
@ -71,7 +89,7 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
if c := resp.StatusCode; c < 200 || c > 299 { if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body) return nil, fmt.Errorf("oauth2/google: status code %d: %s", c, body)
} }
var stsResp stsTokenExchangeResponse var stsResp Response
err = json.Unmarshal(body, &stsResp) err = json.Unmarshal(body, &stsResp)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err) return nil, fmt.Errorf("oauth2/google: failed to unmarshal response body from Secure Token Server: %v", err)
@ -81,8 +99,8 @@ func exchangeToken(ctx context.Context, endpoint string, request *stsTokenExchan
return &stsResp, nil return &stsResp, nil
} }
// stsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange. // StsTokenExchangeRequest contains fields necessary to make an oauth2 token exchange.
type stsTokenExchangeRequest struct { type StsTokenExchangeRequest struct {
ActingParty struct { ActingParty struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@ -96,8 +114,8 @@ type stsTokenExchangeRequest struct {
SubjectTokenType string SubjectTokenType string
} }
// stsTokenExchangeResponse is used to decode the remote server response during an oauth2 token exchange. // Response is used to decode the remote server response during an oauth2 token exchange.
type stsTokenExchangeResponse struct { type Response struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"` IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package externalaccount package sts_exchange
import ( import (
"context" "context"
@ -16,13 +16,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
var auth = clientAuthentication{ var auth = ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
} }
var tokenRequest = stsTokenExchangeRequest{ var exchangeTokenRequest = StsTokenExchangeRequest{
ActingParty: struct { ActingParty: struct {
ActorToken string ActorToken string
ActorTokenType string ActorTokenType string
@ -36,9 +36,9 @@ var tokenRequest = stsTokenExchangeRequest{
SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt", SubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
} }
var requestbody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt" var exchangeRequestBody = "audience=32555940559.apps.googleusercontent.com&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange&requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdevstorage.full_control&subject_token=Sample.Subject.Token&subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Ajwt"
var responseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}` var exchangeResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform"}`
var expectedToken = stsTokenExchangeResponse{ var expectedExchangeToken = Response{
AccessToken: "Sample.Access.Token", AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer", TokenType: "Bearer",
@ -47,6 +47,18 @@ var expectedToken = stsTokenExchangeResponse{
RefreshToken: "", RefreshToken: "",
} }
var refreshToken = "ReFrEsHtOkEn"
var refreshRequestBody = "grant_type=refresh_token&refresh_token=" + refreshToken
var refreshResponseBody = `{"access_token":"Sample.Access.Token","issued_token_type":"urn:ietf:params:oauth:token-type:access_token","token_type":"Bearer","expires_in":3600,"scope":"https://www.googleapis.com/auth/cloud-platform","refresh_token":"REFRESHED_REFRESH"}`
var expectedRefreshResponse = Response{
AccessToken: "Sample.Access.Token",
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
TokenType: "Bearer",
ExpiresIn: 3600,
Scope: "https://www.googleapis.com/auth/cloud-platform",
RefreshToken: "REFRESHED_REFRESH",
}
func TestExchangeToken(t *testing.T) { func TestExchangeToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
@ -65,26 +77,34 @@ func TestExchangeToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %v.", err) t.Errorf("Failed reading request body: %v.", err)
} }
if got, want := string(body), requestbody; got != want { if got, want := string(body), exchangeRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want) t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) resp, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err != nil { if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err) t.Fatalf("exchangeToken failed with error: %v", err)
} }
if expectedToken != *resp { if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedToken, *resp) t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
} }
resp, err = ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, nil, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedExchangeToken != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedExchangeToken, *resp)
}
} }
func TestExchangeToken_Err(t *testing.T) { func TestExchangeToken_Err(t *testing.T) {
@ -96,7 +116,7 @@ func TestExchangeToken_Err(t *testing.T) {
headers := http.Header{} headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded") headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, nil) _, err := ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, nil)
if err == nil { if err == nil {
t.Errorf("Expected handled error; instead got nil.") t.Errorf("Expected handled error; instead got nil.")
} }
@ -171,7 +191,7 @@ func TestExchangeToken_Opts(t *testing.T) {
// Send a proper reply so that no other errors crop up. // Send a proper reply so that no other errors crop up.
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Write([]byte(responseBody)) w.Write([]byte(exchangeResponseBody))
})) }))
defer ts.Close() defer ts.Close()
@ -183,5 +203,69 @@ func TestExchangeToken_Opts(t *testing.T) {
inputOpts := make(map[string]interface{}) inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption inputOpts["one"] = firstOption
inputOpts["two"] = secondOption inputOpts["two"] = secondOption
exchangeToken(context.Background(), ts.URL, &tokenRequest, auth, headers, inputOpts) ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
}
func TestRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Unexpected request method, %v is found", r.Method)
}
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
if got, want := r.Header.Get("Authorization"), "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ="; got != want {
t.Errorf("Unexpected authorization header, got %v, want %v", got, want)
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
if got, want := string(body), refreshRequestBody; got != want {
t.Errorf("Unexpected exchange payload, got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(refreshResponseBody))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
resp, err = RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, nil)
if err != nil {
t.Fatalf("exchangeToken failed with error: %v", err)
}
if expectedRefreshResponse != *resp {
t.Errorf("mismatched messages received by mock server. \nWant: \n%v\n\nGot:\n%v", expectedRefreshResponse, *resp)
}
}
func TestRefreshToken_Err(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("what's wrong with this response?"))
}))
defer ts.Close()
headers := http.Header{}
headers.Add("Content-Type", "application/x-www-form-urlencoded")
_, err := RefreshAccessToken(context.Background(), ts.URL, refreshToken, auth, headers)
if err == nil {
t.Errorf("Expected handled error; instead got nil.")
}
} }