downscope: refactor some code to remove an extraneous function and instead run that code inside of Token()

This commit is contained in:
Patrick Jones 2021-06-22 13:14:47 -07:00
parent 304d28ba9e
commit 1024258a24
2 changed files with 18 additions and 22 deletions

View File

@ -22,7 +22,7 @@ import (
"golang.org/x/oauth2"
)
const (
var (
identityBindingEndpoint = "https://sts.googleapis.com/v1/token"
)
@ -96,19 +96,22 @@ func NewTokenSource(ctx context.Context, conf DownscopingConfig) downscopingToke
return downscopingTokenSource{ctx: ctx, config: conf}
}
// downscopedTokenWithEndpoint is a helper function used for unit testing
// purposes, as it allows us to pass in a locally mocked endpoint.
func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig, endpoint string) (*oauth2.Token, error) {
if config.RootSource == nil {
// Token() uses a downscopingTokenSource to generate an oauth2 Token.
// Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish
// to refresh this token automatically, then initialize a locally defined
// TokenSource struct with the Token held by the StaticTokenSource and wrap
// that TokenSource in an oauth2.ReuseTokenSource.
func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
if dts.config.RootSource == nil {
return nil, fmt.Errorf("downscope: rootTokenSource cannot be nil")
}
if len(config.Rules) == 0 {
if len(dts.config.Rules) == 0 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules must be at least 1")
}
if len(config.Rules) > 10 {
if len(dts.config.Rules) > 10 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules may not be greater than 10")
}
for _, val := range config.Rules {
for _, val := range dts.config.Rules {
if val.AvailableResource == "" {
return nil, fmt.Errorf("downscope: all rules must have a nonempty AvailableResource: %+v", val)
}
@ -121,11 +124,11 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
Boundary accessBoundary `json:"accessBoundary"`
}{
Boundary: accessBoundary{
AccessBoundaryRules: config.Rules,
AccessBoundaryRules: dts.config.Rules,
},
}
tok, err := config.RootSource.Token()
tok, err := dts.config.RootSource.Token()
if err != nil {
return nil, fmt.Errorf("downscope: unable to obtain root token: %v", err)
}
@ -142,8 +145,8 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
form.Add("subject_token", tok.AccessToken)
form.Add("options", string(b))
myClient := oauth2.NewClient(ctx, nil)
resp, err := myClient.PostForm(endpoint, form)
myClient := oauth2.NewClient(dts.ctx, nil)
resp, err := myClient.PostForm(identityBindingEndpoint, form)
if err != nil {
return nil, fmt.Errorf("unable to generate POST Request %v", err)
}
@ -180,12 +183,3 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
}
return newToken, nil
}
// Token() uses a downscopingTokenSource to generate an oauth2 Token.
// Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish
// to refresh this token automatically, then initialize a locally defined
// TokenSource struct with the Token held by the StaticTokenSource and wrap
// that TokenSource in an oauth2.ReuseTokenSource.
func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
return downscopedTokenWithEndpoint(dts.ctx, dts.config, identityBindingEndpoint)
}

View File

@ -46,7 +46,9 @@ func Test_DownscopedTokenSource(t *testing.T) {
}
myTok := oauth2.Token{AccessToken: "Mellon"}
tmpSrc := oauth2.StaticTokenSource(&myTok)
_, err := downscopedTokenWithEndpoint(context.Background(), DownscopingConfig{tmpSrc, new}, ts.URL)
dts := downscopingTokenSource{context.Background(), DownscopingConfig{tmpSrc, new}}
identityBindingEndpoint = ts.URL
_, err := dts.Token()
if err != nil {
t.Fatalf("NewDownscopedTokenSource failed with error: %v", err)
}