forked from Mirrors/oauth2
Merge branch 'byoid-metrics' of https://github.com/aeitzman/oauth2 into byoid-metrics
This commit is contained in:
commit
523ee10d3c
|
@ -0,0 +1,188 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
|
||||||
|
const (
|
||||||
|
errAuthorizationPending = "authorization_pending"
|
||||||
|
errSlowDown = "slow_down"
|
||||||
|
errAccessDenied = "access_denied"
|
||||||
|
errExpiredToken = "expired_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||||
|
type DeviceAuthResponse struct {
|
||||||
|
// DeviceCode
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
// UserCode is the code the user should enter at the verification uri
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
// VerificationURI is where user should enter the user code
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
|
||||||
|
// Expiry is when the device code and user code expire
|
||||||
|
Expiry time.Time `json:"expires_in,omitempty"`
|
||||||
|
// Interval is the duration in seconds that Poll should wait between requests
|
||||||
|
Interval int64 `json:"interval,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
type Alias DeviceAuthResponse
|
||||||
|
var expiresIn int64
|
||||||
|
if !d.Expiry.IsZero() {
|
||||||
|
expiresIn = int64(time.Until(d.Expiry).Seconds())
|
||||||
|
}
|
||||||
|
return json.Marshal(&struct {
|
||||||
|
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||||
|
*Alias
|
||||||
|
}{
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
Alias: (*Alias)(&d),
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias DeviceAuthResponse
|
||||||
|
aux := &struct {
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
*Alias
|
||||||
|
}{
|
||||||
|
Alias: (*Alias)(c),
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if aux.ExpiresIn != 0 {
|
||||||
|
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuth returns a device auth struct which contains a device code
|
||||||
|
// and authorization information provided for users to enter on another device.
|
||||||
|
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
|
||||||
|
v := url.Values{
|
||||||
|
"client_id": {c.ClientID},
|
||||||
|
}
|
||||||
|
if len(c.Scopes) > 0 {
|
||||||
|
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.setValue(v)
|
||||||
|
}
|
||||||
|
return retrieveDeviceAuth(ctx, c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
|
||||||
|
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
t := time.Now()
|
||||||
|
r, err := internal.ContextClient(ctx).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
|
||||||
|
}
|
||||||
|
if code := r.StatusCode; code < 200 || code > 299 {
|
||||||
|
return nil, &RetrieveError{
|
||||||
|
Response: r,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
da := &DeviceAuthResponse{}
|
||||||
|
err = json.Unmarshal(body, &da)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !da.Expiry.IsZero() {
|
||||||
|
// Make a small adjustment to account for time taken by the request
|
||||||
|
da.Expiry = da.Expiry.Add(-time.Since(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
return da, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAccessToken polls the server to exchange a device code for a token.
|
||||||
|
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
|
||||||
|
if !da.Expiry.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
|
||||||
|
v := url.Values{
|
||||||
|
"client_id": {c.ClientID},
|
||||||
|
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
|
||||||
|
"device_code": {da.DeviceCode},
|
||||||
|
}
|
||||||
|
if len(c.Scopes) > 0 {
|
||||||
|
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.setValue(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// "If no value is provided, clients MUST use 5 as the default."
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
|
||||||
|
interval := da.Interval
|
||||||
|
if interval == 0 {
|
||||||
|
interval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
tok, err := retrieveToken(ctx, c, v)
|
||||||
|
if err == nil {
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e, ok := err.(*RetrieveError)
|
||||||
|
if !ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch e.ErrorCode {
|
||||||
|
case errSlowDown:
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
|
||||||
|
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
|
||||||
|
interval += 5
|
||||||
|
ticker.Reset(time.Duration(interval) * time.Second)
|
||||||
|
case errAuthorizationPending:
|
||||||
|
// Do nothing.
|
||||||
|
case errAccessDenied, errExpiredToken:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return tok, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response DeviceAuthResponse
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
response: DeviceAuthResponse{},
|
||||||
|
want: `{"device_code":"","user_code":"","verification_uri":""}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "soon",
|
||||||
|
response: DeviceAuthResponse{
|
||||||
|
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
|
||||||
|
},
|
||||||
|
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
begin := time.Now()
|
||||||
|
gotBytes, err := json.Marshal(tc.response)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
|
||||||
|
t.Skip("test ran too slowly to compare `expires_in`")
|
||||||
|
}
|
||||||
|
got := string(gotBytes)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("want=%s, got=%s", tc.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data string
|
||||||
|
want DeviceAuthResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
data: `{}`,
|
||||||
|
want: DeviceAuthResponse{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "soon",
|
||||||
|
data: `{"expires_in":100}`,
|
||||||
|
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
begin := time.Now()
|
||||||
|
got := DeviceAuthResponse{}
|
||||||
|
err := json.Unmarshal([]byte(tc.data), &got)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
|
||||||
|
t.Errorf("want=%#v, got=%#v", tc.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleConfig_DeviceAuth() {
|
||||||
|
var config Config
|
||||||
|
ctx := context.Background()
|
||||||
|
response, err := config.DeviceAuth(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
|
||||||
|
token, err := config.DeviceAccessToken(ctx, response)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Println(token)
|
||||||
|
}
|
|
@ -57,6 +57,7 @@ var Fitbit = oauth2.Endpoint{
|
||||||
var GitHub = oauth2.Endpoint{
|
var GitHub = oauth2.Endpoint{
|
||||||
AuthURL: "https://github.com/login/oauth/authorize",
|
AuthURL: "https://github.com/login/oauth/authorize",
|
||||||
TokenURL: "https://github.com/login/oauth/access_token",
|
TokenURL: "https://github.com/login/oauth/access_token",
|
||||||
|
DeviceAuthURL: "https://github.com/login/device/code",
|
||||||
}
|
}
|
||||||
|
|
||||||
// GitLab is the endpoint for GitLab.
|
// GitLab is the endpoint for GitLab.
|
||||||
|
@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
|
||||||
var Google = oauth2.Endpoint{
|
var Google = oauth2.Endpoint{
|
||||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
TokenURL: "https://oauth2.googleapis.com/token",
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heroku is the endpoint for Heroku.
|
// Heroku is the endpoint for Heroku.
|
||||||
|
|
|
@ -26,9 +26,13 @@ func ExampleConfig() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// use PKCE to protect against CSRF attacks
|
||||||
|
// https://www.ietf.org/archive/id/draft-ietf-oauth-security-topics-22.html#name-countermeasures-6
|
||||||
|
verifier := oauth2.GenerateVerifier()
|
||||||
|
|
||||||
// Redirect user to consent page to ask for permission
|
// Redirect user to consent page to ask for permission
|
||||||
// for the scopes specified above.
|
// for the scopes specified above.
|
||||||
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
url := conf.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
|
||||||
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
fmt.Printf("Visit the URL for the auth dialog: %v", url)
|
||||||
|
|
||||||
// Use the authorization code that is pushed to the redirect
|
// Use the authorization code that is pushed to the redirect
|
||||||
|
@ -39,7 +43,7 @@ func ExampleConfig() {
|
||||||
if _, err := fmt.Scan(&code); err != nil {
|
if _, err := fmt.Scan(&code); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
tok, err := conf.Exchange(ctx, code)
|
tok, err := conf.Exchange(ctx, code, oauth2.VerifierOption(verifier))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,11 +6,8 @@
|
||||||
package github // import "golang.org/x/oauth2/github"
|
package github // import "golang.org/x/oauth2/github"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2/endpoints"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Endpoint is Github's OAuth 2.0 endpoint.
|
// Endpoint is Github's OAuth 2.0 endpoint.
|
||||||
var Endpoint = oauth2.Endpoint{
|
var Endpoint = endpoints.GitHub
|
||||||
AuthURL: "https://github.com/login/oauth/authorize",
|
|
||||||
TokenURL: "https://github.com/login/oauth/access_token",
|
|
||||||
}
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -11,6 +11,6 @@ require (
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go/compute v1.20.1 // indirect
|
cloud.google.com/go/compute v1.20.1 // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
golang.org/x/net v0.14.0 // indirect
|
golang.org/x/net v0.15.0 // indirect
|
||||||
google.golang.org/protobuf v1.31.0 // indirect
|
google.golang.org/protobuf v1.31.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -11,8 +11,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||||
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
|
||||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
var Endpoint = oauth2.Endpoint{
|
var Endpoint = oauth2.Endpoint{
|
||||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
TokenURL: "https://oauth2.googleapis.com/token",
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
|
||||||
AuthStyle: oauth2.AuthStyleInParams,
|
AuthStyle: oauth2.AuthStyleInParams,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -274,49 +274,6 @@ type awsRequest struct {
|
||||||
Headers []awsRequestHeader `json:"headers"`
|
Headers []awsRequestHeader `json:"headers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs awsCredentialSource) validateMetadataServers() error {
|
|
||||||
if err := cs.validateMetadataServer(cs.RegionURL, "region_url"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := cs.validateMetadataServer(cs.CredVerificationURL, "url"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return cs.validateMetadataServer(cs.IMDSv2SessionTokenURL, "imdsv2_session_token_url")
|
|
||||||
}
|
|
||||||
|
|
||||||
var validHostnames []string = []string{"169.254.169.254", "fd00:ec2::254"}
|
|
||||||
|
|
||||||
func (cs awsCredentialSource) isValidMetadataServer(metadataUrl string) bool {
|
|
||||||
if metadataUrl == "" {
|
|
||||||
// Zero value means use default, which is valid.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(metadataUrl)
|
|
||||||
if err != nil {
|
|
||||||
// Unparseable URL means invalid
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, validHostname := range validHostnames {
|
|
||||||
if u.Hostname() == validHostname {
|
|
||||||
// If it's one of the valid hostnames, everything is good
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// hostname not found in our allowlist, so not valid
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs awsCredentialSource) validateMetadataServer(metadataUrl, urlName string) error {
|
|
||||||
if !cs.isValidMetadataServer(metadataUrl) {
|
|
||||||
return fmt.Errorf("oauth2/google: invalid hostname %s for %s", metadataUrl, urlName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
|
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
|
||||||
if cs.client == nil {
|
if cs.client == nil {
|
||||||
cs.client = oauth2.NewClient(cs.ctx, nil)
|
cs.client = oauth2.NewClient(cs.ctx, nil)
|
||||||
|
|
|
@ -585,25 +585,18 @@ func getExpectedSubjectToken(url, region, accessKeyID, secretAccessKey, security
|
||||||
func TestAWSCredential_BasicRequest(t *testing.T) {
|
func TestAWSCredential_BasicRequest(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -631,25 +624,18 @@ func TestAWSCredential_BasicRequest(t *testing.T) {
|
||||||
func TestAWSCredential_IMDSv2(t *testing.T) {
|
func TestAWSCredential_IMDSv2(t *testing.T) {
|
||||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -677,10 +663,6 @@ func TestAWSCredential_IMDSv2(t *testing.T) {
|
||||||
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
delete(server.Credentials, "Token")
|
delete(server.Credentials, "Token")
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
|
@ -688,15 +670,12 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -724,21 +703,15 @@ func TestAWSCredential_BasicRequestWithoutSecurityToken(t *testing.T) {
|
||||||
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||||
|
@ -746,7 +719,6 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -774,21 +746,15 @@ func TestAWSCredential_BasicRequestWithEnv(t *testing.T) {
|
||||||
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||||
|
@ -796,7 +762,6 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -823,21 +788,15 @@ func TestAWSCredential_BasicRequestWithDefaultEnv(t *testing.T) {
|
||||||
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||||
|
@ -846,7 +805,6 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
||||||
"AWS_DEFAULT_REGION": "us-east-1",
|
"AWS_DEFAULT_REGION": "us-east-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -873,25 +831,18 @@ func TestAWSCredential_BasicRequestWithTwoRegions(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
tfc.CredentialSource.EnvironmentID = "aws3"
|
tfc.CredentialSource.EnvironmentID = "aws3"
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
_, err = tfc.parse(context.Background())
|
_, err := tfc.parse(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("parse() should have failed")
|
t.Fatalf("parse() should have failed")
|
||||||
}
|
}
|
||||||
|
@ -903,23 +854,16 @@ func TestAWSCredential_RequestWithBadVersion(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
tfc.CredentialSource.RegionURL = ""
|
tfc.CredentialSource.RegionURL = ""
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -939,23 +883,17 @@ func TestAWSCredential_RequestWithNoRegionURL(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
server.WriteRegion = notFound
|
server.WriteRegion = notFound
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -975,10 +913,7 @@ func TestAWSCredential_RequestWithBadRegionURL(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("{}"))
|
w.Write([]byte("{}"))
|
||||||
}
|
}
|
||||||
|
@ -987,13 +922,10 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1013,10 +945,7 @@ func TestAWSCredential_RequestWithMissingCredential(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
server.WriteSecurityCredentials = func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
|
w.Write([]byte(`{"AccessKeyId":"FOOBARBAS"}`))
|
||||||
}
|
}
|
||||||
|
@ -1025,13 +954,10 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1051,23 +977,16 @@ func TestAWSCredential_RequestWithIncompleteCredential(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
tfc.CredentialSource.URL = ""
|
tfc.CredentialSource.URL = ""
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1087,23 +1006,16 @@ func TestAWSCredential_RequestWithNoCredentialURL(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
server.WriteRolename = notFound
|
server.WriteRolename = notFound
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1123,23 +1035,16 @@ func TestAWSCredential_RequestWithBadCredentialURL(t *testing.T) {
|
||||||
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
server.WriteSecurityCredentials = notFound
|
server.WriteSecurityCredentials = notFound
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{})
|
getenv = setEnvironment(map[string]string{})
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1159,10 +1064,6 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) {
|
||||||
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
|
func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) {
|
||||||
server := createDefaultAwsTestServer()
|
server := createDefaultAwsTestServer()
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Error("Metadata server should not have been called.")
|
t.Error("Metadata server should not have been called.")
|
||||||
|
@ -1174,11 +1075,9 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||||
|
@ -1186,7 +1085,6 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1214,28 +1112,21 @@ func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing
|
||||||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
|
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
|
||||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": accessKeyID,
|
"AWS_ACCESS_KEY_ID": accessKeyID,
|
||||||
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
|
"AWS_SECRET_ACCESS_KEY": secretAccessKey,
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1263,28 +1154,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) {
|
||||||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
|
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
|
||||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
|
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1312,28 +1196,21 @@ func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) {
|
||||||
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
|
func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) {
|
||||||
server := createDefaultAwsTestServerWithImdsv2(t)
|
server := createDefaultAwsTestServerWithImdsv2(t)
|
||||||
ts := httptest.NewServer(server)
|
ts := httptest.NewServer(server)
|
||||||
tsURL, err := neturl.Parse(ts.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't parse httptest servername")
|
|
||||||
}
|
|
||||||
|
|
||||||
tfc := testFileConfig
|
tfc := testFileConfig
|
||||||
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
tfc.CredentialSource = server.getCredentialSource(ts.URL)
|
||||||
|
|
||||||
oldGetenv := getenv
|
oldGetenv := getenv
|
||||||
oldNow := now
|
oldNow := now
|
||||||
oldValidHostnames := validHostnames
|
|
||||||
defer func() {
|
defer func() {
|
||||||
getenv = oldGetenv
|
getenv = oldGetenv
|
||||||
now = oldNow
|
now = oldNow
|
||||||
validHostnames = oldValidHostnames
|
|
||||||
}()
|
}()
|
||||||
getenv = setEnvironment(map[string]string{
|
getenv = setEnvironment(map[string]string{
|
||||||
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
"AWS_ACCESS_KEY_ID": "AKIDEXAMPLE",
|
||||||
"AWS_REGION": "us-west-1",
|
"AWS_REGION": "us-west-1",
|
||||||
})
|
})
|
||||||
now = setTime(defaultTime)
|
now = setTime(defaultTime)
|
||||||
validHostnames = []string{tsURL.Hostname()}
|
|
||||||
|
|
||||||
base, err := tfc.parse(context.Background())
|
base, err := tfc.parse(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -185,10 +185,6 @@ func (c *Config) parse(ctx context.Context) (baseCredentialSource, error) {
|
||||||
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
|
awsCredSource.IMDSv2SessionTokenURL = c.CredentialSource.IMDSv2SessionTokenURL
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := awsCredSource.validateMetadataServers(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return awsCredSource, nil
|
return awsCredSource, nil
|
||||||
}
|
}
|
||||||
} else if c.CredentialSource.File != "" {
|
} else if c.CredentialSource.File != "" {
|
||||||
|
|
25
oauth2.go
25
oauth2.go
|
@ -76,6 +76,7 @@ type TokenSource interface {
|
||||||
// endpoint URLs.
|
// endpoint URLs.
|
||||||
type Endpoint struct {
|
type Endpoint struct {
|
||||||
AuthURL string
|
AuthURL string
|
||||||
|
DeviceAuthURL string
|
||||||
TokenURL string
|
TokenURL string
|
||||||
|
|
||||||
// AuthStyle optionally specifies how the endpoint wants the
|
// AuthStyle optionally specifies how the endpoint wants the
|
||||||
|
@ -143,15 +144,19 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
|
||||||
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||||
// that asks for permissions for the required scopes explicitly.
|
// that asks for permissions for the required scopes explicitly.
|
||||||
//
|
//
|
||||||
// State is a token to protect the user from CSRF attacks. You must
|
// State is an opaque value used by the client to maintain state between the
|
||||||
// always provide a non-empty string and validate that it matches the
|
// request and callback. The authorization server includes this value when
|
||||||
// state query parameter on your redirect callback.
|
// redirecting the user agent back to the client.
|
||||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
|
||||||
//
|
//
|
||||||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||||
// as ApprovalForce.
|
// as ApprovalForce.
|
||||||
// It can also be used to pass the PKCE challenge.
|
//
|
||||||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
// To protect against CSRF attacks, opts should include a PKCE challenge
|
||||||
|
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
|
||||||
|
// generate a random state parameter and verify it after exchange.
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc6749#section-10.12 (predating
|
||||||
|
// PKCE), https://www.oauth.com/oauth2-servers/pkce/ and
|
||||||
|
// https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-09.html#name-cross-site-request-forgery (describing both approaches)
|
||||||
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString(c.Endpoint.AuthURL)
|
buf.WriteString(c.Endpoint.AuthURL)
|
||||||
|
@ -166,7 +171,6 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
v.Set("scope", strings.Join(c.Scopes, " "))
|
v.Set("scope", strings.Join(c.Scopes, " "))
|
||||||
}
|
}
|
||||||
if state != "" {
|
if state != "" {
|
||||||
// TODO(light): Docs say never to omit state; don't allow empty.
|
|
||||||
v.Set("state", state)
|
v.Set("state", state)
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -211,10 +215,11 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
|
||||||
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
|
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
|
||||||
//
|
//
|
||||||
// The code will be in the *http.Request.FormValue("code"). Before
|
// The code will be in the *http.Request.FormValue("code"). Before
|
||||||
// calling Exchange, be sure to validate FormValue("state").
|
// calling Exchange, be sure to validate FormValue("state") if you are
|
||||||
|
// using it to protect against CSRF attacks.
|
||||||
//
|
//
|
||||||
// Opts may include the PKCE verifier code if previously used in AuthCodeURL.
|
// If using PKCE to protect against CSRF attacks, opts should include a
|
||||||
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
|
// VerifierOption.
|
||||||
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
|
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
|
||||||
v := url.Values{
|
v := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
codeChallengeKey = "code_challenge"
|
||||||
|
codeChallengeMethodKey = "code_challenge_method"
|
||||||
|
codeVerifierKey = "code_verifier"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateVerifier generates a PKCE code verifier with 32 octets of randomness.
|
||||||
|
// This follows recommendations in RFC 7636.
|
||||||
|
//
|
||||||
|
// A fresh verifier should be generated for each authorization.
|
||||||
|
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
|
||||||
|
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
|
||||||
|
// (or Config.DeviceAccessToken).
|
||||||
|
func GenerateVerifier() string {
|
||||||
|
// "RECOMMENDED that the output of a suitable random number generator be
|
||||||
|
// used to create a 32-octet sequence. The octet sequence is then
|
||||||
|
// base64url-encoded to produce a 43-octet URL-safe string to use as the
|
||||||
|
// code verifier."
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
|
||||||
|
data := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(data); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
|
||||||
|
// passed to Config.Exchange or Config.DeviceAccessToken only.
|
||||||
|
func VerifierOption(verifier string) AuthCodeOption {
|
||||||
|
return setParam{k: codeVerifierKey, v: verifier}
|
||||||
|
}
|
||||||
|
|
||||||
|
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
|
||||||
|
//
|
||||||
|
// Prefer to use S256ChallengeOption where possible.
|
||||||
|
func S256ChallengeFromVerifier(verifier string) string {
|
||||||
|
sha := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sha[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
|
||||||
|
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
|
||||||
|
// only.
|
||||||
|
func S256ChallengeOption(verifier string) AuthCodeOption {
|
||||||
|
return challengeOption{
|
||||||
|
challenge_method: "S256",
|
||||||
|
challenge: S256ChallengeFromVerifier(verifier),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type challengeOption struct{ challenge_method, challenge string }
|
||||||
|
|
||||||
|
func (p challengeOption) setValue(m url.Values) {
|
||||||
|
m.Set(codeChallengeMethodKey, p.challenge_method)
|
||||||
|
m.Set(codeChallengeKey, p.challenge)
|
||||||
|
}
|
Loading…
Reference in New Issue