forked from Mirrors/oauth2
Merge branch 'byoid-metrics' of https://github.com/aeitzman/oauth2 into byoid-metrics
This commit is contained in:
commit
523ee10d3c
|
@ -0,0 +1,188 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2/internal"
|
||||
)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
|
||||
const (
|
||||
errAuthorizationPending = "authorization_pending"
|
||||
errSlowDown = "slow_down"
|
||||
errAccessDenied = "access_denied"
|
||||
errExpiredToken = "expired_token"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||
type DeviceAuthResponse struct {
|
||||
// DeviceCode
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user should enter at the verification uri
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is where user should enter the user code
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
|
||||
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
|
||||
// Expiry is when the device code and user code expire
|
||||
Expiry time.Time `json:"expires_in,omitempty"`
|
||||
// Interval is the duration in seconds that Poll should wait between requests
|
||||
Interval int64 `json:"interval,omitempty"`
|
||||
}
|
||||
|
||||
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
|
||||
type Alias DeviceAuthResponse
|
||||
var expiresIn int64
|
||||
if !d.Expiry.IsZero() {
|
||||
expiresIn = int64(time.Until(d.Expiry).Seconds())
|
||||
}
|
||||
return json.Marshal(&struct {
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
ExpiresIn: expiresIn,
|
||||
Alias: (*Alias)(&d),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
|
||||
type Alias DeviceAuthResponse
|
||||
aux := &struct {
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(c),
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
if aux.ExpiresIn != 0 {
|
||||
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceAuth returns a device auth struct which contains a device code
|
||||
// and authorization information provided for users to enter on another device.
|
||||
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
|
||||
v := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
}
|
||||
if len(c.Scopes) > 0 {
|
||||
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt.setValue(v)
|
||||
}
|
||||
return retrieveDeviceAuth(ctx, c, v)
|
||||
}
|
||||
|
||||
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
|
||||
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
t := time.Now()
|
||||
r, err := internal.ContextClient(ctx).Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
|
||||
}
|
||||
if code := r.StatusCode; code < 200 || code > 299 {
|
||||
return nil, &RetrieveError{
|
||||
Response: r,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
|
||||
da := &DeviceAuthResponse{}
|
||||
err = json.Unmarshal(body, &da)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal %s", err)
|
||||
}
|
||||
|
||||
if !da.Expiry.IsZero() {
|
||||
// Make a small adjustment to account for time taken by the request
|
||||
da.Expiry = da.Expiry.Add(-time.Since(t))
|
||||
}
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
// DeviceAccessToken polls the server to exchange a device code for a token.
|
||||
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
|
||||
if !da.Expiry.IsZero() {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
|
||||
v := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
|
||||
"device_code": {da.DeviceCode},
|
||||
}
|
||||
if len(c.Scopes) > 0 {
|
||||
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt.setValue(v)
|
||||
}
|
||||
|
||||
// "If no value is provided, clients MUST use 5 as the default."
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||
interval := da.Interval
|
||||
if interval == 0 {
|
||||
interval = 5
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
tok, err := retrieveToken(ctx, c, v)
|
||||
if err == nil {
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
e, ok := err.(*RetrieveError)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
switch e.ErrorCode {
|
||||
case errSlowDown:
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
|
||||
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
|
||||
interval += 5
|
||||
ticker.Reset(time.Duration(interval) * time.Second)
|
||||
case errAuthorizationPending:
|
||||
// Do nothing.
|
||||
case errAccessDenied, errExpiredToken:
|
||||
fallthrough
|
||||
default:
|
||||
return tok, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response DeviceAuthResponse
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
response: DeviceAuthResponse{},
|
||||
want: `{"device_code":"","user_code":"","verification_uri":""}`,
|
||||
},
|
||||
{
|
||||
name: "soon",
|
||||
response: DeviceAuthResponse{
|
||||
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
|
||||
},
|
||||
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
begin := time.Now()
|
||||
gotBytes, err := json.Marshal(tc.response)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
|
||||
t.Skip("test ran too slowly to compare `expires_in`")
|
||||
}
|
||||
got := string(gotBytes)
|
||||
if got != tc.want {
|
||||
t.Errorf("want=%s, got=%s", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
want DeviceAuthResponse
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
data: `{}`,
|
||||
want: DeviceAuthResponse{},
|
||||
},
|
||||
{
|
||||
name: "soon",
|
||||
data: `{"expires_in":100}`,
|
||||
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
begin := time.Now()
|
||||
got := DeviceAuthResponse{}
|
||||
err := json.Unmarshal([]byte(tc.data), &got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
|
||||
t.Errorf("want=%#v, got=%#v", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleConfig_DeviceAuth() {
|
||||
var config Config
|
||||
ctx := context.Background()
|
||||
response, err := config.DeviceAuth(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
|
||||
token, err := config.DeviceAccessToken(ctx, response)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(token)
|
||||
}
|
|
@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
|
|||
var GitHub = oauth2.Endpoint{
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
DeviceAuthURL: "https://github.com/login/device/code",
|
||||
}
|
||||
|
||||
// GitLab is the endpoint for GitLab.
|
||||
|
@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
|
|||
var Google = oauth2.Endpoint{
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
|
||||
}
|
||||
|
||||
// Heroku is the endpoint for Heroku.
|
||||
|
|
|
@ -26,9 +26,13 @@ func ExampleConfig() {
|
|||
},
|
||||
}
|
||||
|
||||
// use PKCE to protect against CSRF attacks
|
||||
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
|
||||
// Redirect user to consent page to ask for permission
|
||||
// for the scopes specified above.
|
||||
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
||||
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
|
||||
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
||||
|
||||
// Use the authorization code that is pushed to the redirect
|
||||
|
@ -39,7 +43,7 @@ func ExampleConfig() {
|
|||
if _, err := fmt.Scan(&code); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
tok, err := conf.Exchange(ctx, code)
|
||||
tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -6,11 +6,8 @@
|
|||
package github // import "golang.org/x/oauth2/github"
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
// Endpoint is Github's OAuth 2.0 endpoint.
|
||||
var Endpoint = oauth2.Endpoint{
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
}
|
||||
var Endpoint = endpoints.GitHub
|
||||
|
|
2
go.mod
2
go.mod
|
@ -11,6 +11,6 @@ require (
|
|||
require (
|
||||
cloud.google.com/go/compute v1.20.1 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/net v0.15.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
)
|
||||
|
|
4
go.sum
4
go.sum
|
@ -11,8 +11,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
|||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
var Endpoint = oauth2.Endpoint{
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
}
|
||||
|
||||
|
|
|
@ -274,49 +274,6 @@ type awsRequest struct {
|
|||
Headers []awsRequestHeader `json:"headers"`
|
||||
}
|
||||
|
||||
func (cs awsCredentialSource) validateMetadataServers() error {
|
||||
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
|
||||
return err
|
||||
}
|
||||
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
|
||||
}
|
||||
|
||||
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
|
||||
|
||||
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
|
||||
if metadataUrl == "" {
|
||||
// Zero value means use default, which is valid.
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(metadataUrl)
|
||||
if err != nil {
|
||||
// Unparseable URL means invalid
|
||||
return false
|
||||
}
|
||||
|
||||
for _, validHostname := range validHostnames {
|
||||
if u.Hostname() == validHostname {
|
||||
// If it's one of the valid hostnames, everything is good
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// hostname not found in our allowlist, so not valid
|
||||
return false
|
||||
}
|
||||
|
||||
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
|
||||
if !cs.isValidMetadataServer(metadataUrl) {
|
||||
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
|
||||
if cs.client == nil {
|
||||
cs.client = oauth2.NewClient(cs.ctx, nil)
|
||||
|
|
|
@ -585,25 +585,18 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
|
|||
func TestAWSCredential_BasicRequest(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
|
|||
func TestAWSCredential_IMDSv2(t *testing.T) {
|
||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
|
|||
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
delete(server.Credentials, "Token")
|
||||
|
||||
tfc := testFileConfig
|
||||
|
@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
|||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
|||
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||
|
@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
|||
"AWS_REGION": "us-west-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
|||
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||
|
@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
|||
"AWS_REGION": "us-west-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
|||
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||
|
@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
|||
"AWS_DEFAULT_REGION": "us-east-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -873,25 +831,18 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
tfc.CredentialSource.EnvironmentID = "aws3"
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
_, err = tfc.parse(context.Background())
|
||||
_, err := tfc.parse(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("parse() should have failed")
|
||||
}
|
||||
|
@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
tfc.CredentialSource.RegionURL = ""
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
server.WriteRegion = notFound
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("{}"))
|
||||
}
|
||||
|
@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
|||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
|
||||
}
|
||||
|
@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
|||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
tfc.CredentialSource.URL = ""
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
server.WriteRolename = notFound
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
|||
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
server.WriteSecurityCredentials = notFound
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{})
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
|||
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
|
||||
server := createDefaultAwsTestServer()
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("Metadata server should not have been called.")
|
||||
|
@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
|||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||
|
@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
|||
"AWS_REGION": "us-west-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
|||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
|
||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": accessKeyID,
|
||||
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
|
|||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
|
||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
|
||||
"AWS_REGION": "us-west-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
|
|||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
|
||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||
ts := httptest.NewServer(server)
|
||||
tsURL, err := neturl.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse httptest servername")
|
||||
}
|
||||
|
||||
tfc := testFileConfig
|
||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||
|
||||
oldGetenv := getenv
|
||||
oldNow := now
|
||||
oldValidHostnames := validHostnames
|
||||
defer func() {
|
||||
getenv = oldGetenv
|
||||
now = oldNow
|
||||
validHostnames = oldValidHostnames
|
||||
}()
|
||||
getenv = setEnvironment(map[string]string{
|
||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||
"AWS_REGION": "us-west-1",
|
||||
})
|
||||
now = setTime(defaultTime)
|
||||
validHostnames = []string{tsURL.Hostname()}
|
||||
|
||||
base, err := tfc.parse(context.Background())
|
||||
if err != nil {
|
||||
|
|
|
@ -185,10 +185,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
|
|||
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
|
||||
}
|
||||
|
||||
if err := awsCredSource.validateMetadataServers(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return awsCredSource, nil
|
||||
}
|
||||
} else if c.CredentialSource.File != "" {
|
||||
|
|
25
oauth2.go
25
oauth2.go
|
@ -76,6 +76,7 @@ type TokenSource interface {
|
|||
// endpoint URLs.
|
||||
type Endpoint struct {
|
||||
AuthURL string
|
||||
DeviceAuthURL string
|
||||
TokenURL string
|
||||
|
||||
// AuthStyle optionally specifies how the endpoint wants the
|
||||
|
@ -143,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
|
|||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
//
|
||||
// State is a token to protect the user from CSRF attacks. You must
|
||||
// always provide a non-empty string and validate that it matches the
|
||||
// state query parameter on your redirect callback.
|
||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
// State is an opaque value used by the client to maintain state between the
|
||||
// request and callback. The authorization server includes this value when
|
||||
// redirecting the user agent back to the client.
|
||||
//
|
||||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||
// as ApprovalForce.
|
||||
// It can also be used to pass the PKCE challenge.
|
||||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
||||
//
|
||||
// To protect against CSRF attacks, opts should include a PKCE challenge
|
||||
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
|
||||
// generate a random state parameter and verify it after exchange.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
|
||||
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
|
||||
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
|
||||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(c.Endpoint.AuthURL)
|
||||
|
@ -166,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
|||
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||
}
|
||||
if state != "" {
|
||||
// TODO(light): Docs say never to omit state; don't allow empty.
|
||||
v.Set("state", state)
|
||||
}
|
||||
for _, opt := range opts {
|
||||
|
@ -211,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
|
|||
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
|
||||
//
|
||||
// The code will be in the *http.Request.FormValue("code"). Before
|
||||
// calling Exchange, be sure to validate FormValue("state").
|
||||
// calling Exchange, be sure to validate FormValue("state") if you are
|
||||
// using it to protect against CSRF attacks.
|
||||
//
|
||||
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
|
||||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
||||
// If using PKCE to protect against CSRF attacks, opts should include a
|
||||
// VerifierOption.
|
||||
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
|
||||
v := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
codeChallengeKey = "code_challenge"
|
||||
codeChallengeMethodKey = "code_challenge_method"
|
||||
codeVerifierKey = "code_verifier"
|
||||
)
|
||||
|
||||
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
|
||||
// This follows recommendations in RFC 7636.
|
||||
//
|
||||
// A fresh verifier should be generated for each authorization.
|
||||
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
|
||||
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
|
||||
// (or Config.DeviceAccessToken).
|
||||
func GenerateVerifier() string {
|
||||
// "RECOMMENDED that the output of a suitable random number generator be
|
||||
// used to create a 32-octet sequence. The octet sequence is then
|
||||
// base64url-encoded to produce a 43-octet URL-safe string to use as the
|
||||
// code verifier."
|
||||
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
|
||||
data := make([]byte, 32)
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
|
||||
// passed to Config.Exchange or Config.DeviceAccessToken only.
|
||||
func VerifierOption(verifier string) AuthCodeOption {
|
||||
return setParam{k: codeVerifierKey, v: verifier}
|
||||
}
|
||||
|
||||
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
|
||||
//
|
||||
// Prefer to use S256ChallengeOption where possible.
|
||||
func S256ChallengeFromVerifier(verifier string) string {
|
||||
sha := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sha[:])
|
||||
}
|
||||
|
||||
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
|
||||
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
|
||||
// only.
|
||||
func S256ChallengeOption(verifier string) AuthCodeOption {
|
||||
return challengeOption{
|
||||
challenge_method: "S256",
|
||||
challenge: S256ChallengeFromVerifier(verifier),
|
||||
}
|
||||
}
|
||||
|
||||
type challengeOption struct{ challenge_method, challenge string }
|
||||
|
||||
func (p challengeOption) setValue(m url.Values) {
|
||||
m.Set(codeChallengeMethodKey, p.challenge_method)
|
||||
m.Set(codeChallengeKey, p.challenge)
|
||||
}
|
Loading…
Reference in New Issue