From ab12fee4d10d255ed365981b5993cbb97a5e1787 Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Wed, 17 Mar 2021 12:29:45 -0700 Subject: [PATCH] authhandler: Make authHandler the last parameter --- authhandler/authhandler.go | 8 ++++---- authhandler/authhandler_test.go | 6 +++--- authhandler/example_test.go | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index 0b53f15..e397346 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -30,7 +30,7 @@ type AuthorizationHandler func(authCodeURL string) (code string, state string, e // // Per the OAuth protocol, a unique "state" string should be sent and verified // before exchanging the auth code for OAuth token to prevent CSRF attacks. -func TokenSource(ctx context.Context, config *oauth2.Config, authHandler AuthorizationHandler, state string) oauth2.TokenSource { +func TokenSource(ctx context.Context, config *oauth2.Config, state string, authHandler AuthorizationHandler) oauth2.TokenSource { return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state}) } @@ -47,8 +47,8 @@ func (source authHandlerSource) Token() (*oauth2.Token, error) { if err != nil { return nil, err } - if state == source.state { - return source.config.Exchange(source.ctx, code) + if state != source.state { + return nil, errors.New("state mismatch in 3-legged-OAuth flow") } - return nil, errors.New("state mismatch in 3-legged-OAuth flow.") + return source.config.Exchange(source.ctx, code) } diff --git a/authhandler/authhandler_test.go b/authhandler/authhandler_test.go index 8ba6ca6..084198f 100644 --- a/authhandler/authhandler_test.go +++ b/authhandler/authhandler_test.go @@ -45,7 +45,7 @@ func TestTokenExchange_Success(t *testing.T) { }, } - tok, err := TokenSource(context.Background(), conf, authhandler, "testState").Token() + tok, err := TokenSource(context.Background(), conf, "testState", authhandler).Token() if err != nil { t.Fatal(err) } @@ -92,8 +92,8 @@ func TestTokenExchange_StateMismatch(t *testing.T) { }, } - _, err := TokenSource(context.Background(), conf, authhandler, "testState").Token() - if want_err := "state mismatch in 3-legged-OAuth flow."; err == nil || err.Error() != want_err { + _, err := TokenSource(context.Background(), conf, "testState", authhandler).Token() + if want_err := "state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != want_err { t.Errorf("err = %q; want %q", err, want_err) } } diff --git a/authhandler/example_test.go b/authhandler/example_test.go index 99456ef..ef4cf31 100644 --- a/authhandler/example_test.go +++ b/authhandler/example_test.go @@ -38,7 +38,7 @@ func ExampleCmdAuthorizationHandler() { } state := "unique_state" - token, err := authhandler.TokenSource(ctx, conf, authhandler.CmdAuthorizationHandler(state), state).Token() + token, err := authhandler.TokenSource(ctx, conf, state, authhandler.CmdAuthorizationHandler(state)).Token() if err != nil { fmt.Println(err)