downscope: refactor some code to remove an extraneous function and instead run that code inside of Token()

This commit is contained in:
Patrick Jones 2021-06-22 13:14:47 -07:00
parent 304d28ba9e
commit 1024258a24
2 changed files with 18 additions and 22 deletions

View File

@ -22,7 +22,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
const ( var (
identityBindingEndpoint = "https://sts.googleapis.com/v1/token" identityBindingEndpoint = "https://sts.googleapis.com/v1/token"
) )
@ -96,19 +96,22 @@ func NewTokenSource(ctx context.Context, conf DownscopingConfig) downscopingToke
return downscopingTokenSource{ctx: ctx, config: conf} return downscopingTokenSource{ctx: ctx, config: conf}
} }
// downscopedTokenWithEndpoint is a helper function used for unit testing // Token() uses a downscopingTokenSource to generate an oauth2 Token.
// purposes, as it allows us to pass in a locally mocked endpoint. // Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish
func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig, endpoint string) (*oauth2.Token, error) { // to refresh this token automatically, then initialize a locally defined
if config.RootSource == nil { // TokenSource struct with the Token held by the StaticTokenSource and wrap
// that TokenSource in an oauth2.ReuseTokenSource.
func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
if dts.config.RootSource == nil {
return nil, fmt.Errorf("downscope: rootTokenSource cannot be nil") return nil, fmt.Errorf("downscope: rootTokenSource cannot be nil")
} }
if len(config.Rules) == 0 { if len(dts.config.Rules) == 0 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules must be at least 1") return nil, fmt.Errorf("downscope: length of AccessBoundaryRules must be at least 1")
} }
if len(config.Rules) > 10 { if len(dts.config.Rules) > 10 {
return nil, fmt.Errorf("downscope: length of AccessBoundaryRules may not be greater than 10") return nil, fmt.Errorf("downscope: length of AccessBoundaryRules may not be greater than 10")
} }
for _, val := range config.Rules { for _, val := range dts.config.Rules {
if val.AvailableResource == "" { if val.AvailableResource == "" {
return nil, fmt.Errorf("downscope: all rules must have a nonempty AvailableResource: %+v", val) return nil, fmt.Errorf("downscope: all rules must have a nonempty AvailableResource: %+v", val)
} }
@ -121,11 +124,11 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
Boundary accessBoundary `json:"accessBoundary"` Boundary accessBoundary `json:"accessBoundary"`
}{ }{
Boundary: accessBoundary{ Boundary: accessBoundary{
AccessBoundaryRules: config.Rules, AccessBoundaryRules: dts.config.Rules,
}, },
} }
tok, err := config.RootSource.Token() tok, err := dts.config.RootSource.Token()
if err != nil { if err != nil {
return nil, fmt.Errorf("downscope: unable to obtain root token: %v", err) return nil, fmt.Errorf("downscope: unable to obtain root token: %v", err)
} }
@ -142,8 +145,8 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
form.Add("subject_token", tok.AccessToken) form.Add("subject_token", tok.AccessToken)
form.Add("options", string(b)) form.Add("options", string(b))
myClient := oauth2.NewClient(ctx, nil) myClient := oauth2.NewClient(dts.ctx, nil)
resp, err := myClient.PostForm(endpoint, form) resp, err := myClient.PostForm(identityBindingEndpoint, form)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to generate POST Request %v", err) return nil, fmt.Errorf("unable to generate POST Request %v", err)
} }
@ -180,12 +183,3 @@ func downscopedTokenWithEndpoint(ctx context.Context, config DownscopingConfig,
} }
return newToken, nil return newToken, nil
} }
// Token() uses a downscopingTokenSource to generate an oauth2 Token.
// Do note that the returned TokenSource is an oauth2.StaticTokenSource. If you wish
// to refresh this token automatically, then initialize a locally defined
// TokenSource struct with the Token held by the StaticTokenSource and wrap
// that TokenSource in an oauth2.ReuseTokenSource.
func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
return downscopedTokenWithEndpoint(dts.ctx, dts.config, identityBindingEndpoint)
}

View File

@ -46,7 +46,9 @@ func Test_DownscopedTokenSource(t *testing.T) {
} }
myTok := oauth2.Token{AccessToken: "Mellon"} myTok := oauth2.Token{AccessToken: "Mellon"}
tmpSrc := oauth2.StaticTokenSource(&myTok) tmpSrc := oauth2.StaticTokenSource(&myTok)
_, err := downscopedTokenWithEndpoint(context.Background(), DownscopingConfig{tmpSrc, new}, ts.URL) dts := downscopingTokenSource{context.Background(), DownscopingConfig{tmpSrc, new}}
identityBindingEndpoint = ts.URL
_, err := dts.Token()
if err != nil { if err != nil {
t.Fatalf("NewDownscopedTokenSource failed with error: %v", err) t.Fatalf("NewDownscopedTokenSource failed with error: %v", err)
} }