diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index da59975..0b53f15 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -13,20 +13,23 @@ import ( "golang.org/x/oauth2" ) -// AuthorizationHandler is a 3-legged-OAuth helper that -// prompts the user for OAuth consent at the specified Auth URL +// AuthorizationHandler is a 3-legged-OAuth helper that prompts +// the user for OAuth consent at the specified auth code URL // and returns an auth code and state upon approval. -type AuthorizationHandler func(string) (string, string, error) +type AuthorizationHandler func(authCodeURL string) (code string, state string, err error) // TokenSource returns an oauth2.TokenSource that fetches access tokens // using 3-legged-OAuth flow. // -// The provided oauth2.Config should be a full configuration containing AuthURL, -// TokenURL, and scope. An environment-specific AuthorizationHandler is used to -// obtain user consent. +// The provided context.Context is used for oauth2 Exchange operation. // -// Per OAuth protocol, a unique "state" string should be sent and verified -// before exchanging auth code for OAuth token to prevent CSRF attacks. +// The provided oauth2.Config should be a full configuration containing AuthURL, +// TokenURL, and Scope. +// +// An environment-specific AuthorizationHandler is used to obtain user consent. +// +// Per the OAuth protocol, a unique "state" string should be sent and verified +// before exchanging the auth code for OAuth token to prevent CSRF attacks. func TokenSource(ctx context.Context, config *oauth2.Config, authHandler AuthorizationHandler, state string) oauth2.TokenSource { return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state}) } @@ -47,5 +50,5 @@ func (source authHandlerSource) Token() (*oauth2.Token, error) { if state == source.state { return source.config.Exchange(source.ctx, code) } - return nil, errors.New("State mismatch in 3-legged-OAuth flow.") + return nil, errors.New("state mismatch in 3-legged-OAuth flow.") } diff --git a/authhandler/authhandler_test.go b/authhandler/authhandler_test.go new file mode 100644 index 0000000..8ba6ca6 --- /dev/null +++ b/authhandler/authhandler_test.go @@ -0,0 +1,99 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package authhandler + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/oauth2" +) + +func TestTokenExchange_Success(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" { + return "testCode", "testState", nil + } + return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if r.Form.Get("code") == "testCode" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "pubsub", + "token_type": "bearer", + "expires_in": 3600 + }`)) + } + })) + defer ts.Close() + + conf := &oauth2.Config{ + ClientID: "testClientID", + Scopes: []string{"pubsub"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "testAuthCodeURL", + TokenURL: ts.URL, + }, + } + + tok, err := TokenSource(context.Background(), conf, authhandler, "testState").Token() + if err != nil { + t.Fatal(err) + } + if !tok.Valid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.TokenType, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + if got := tok.Expiry.IsZero(); got { + t.Errorf("token expiry is zero = %v, want false", got) + } + scope := tok.Extra("scope") + if got, want := scope, "pubsub"; got != want { + t.Errorf("scope = %q; want %q", got, want) + } +} + +func TestTokenExchange_StateMismatch(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + return "testCode", "testStateMismatch", nil + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "pubsub", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + conf := &oauth2.Config{ + ClientID: "testClientID", + Scopes: []string{"pubsub"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "testAuthCodeURL", + TokenURL: ts.URL, + }, + } + + _, err := TokenSource(context.Background(), conf, authhandler, "testState").Token() + if want_err := "state mismatch in 3-legged-OAuth flow."; err == nil || err.Error() != want_err { + t.Errorf("err = %q; want %q", err, want_err) + } +}