oauth2/google/internal/externalaccount/aws.go

251 lines
6.5 KiB
Go
Raw Normal View History

2021-01-19 12:26:12 -05:00
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
2021-01-20 13:44:08 -05:00
2021-01-19 12:26:12 -05:00
package externalaccount
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
2021-01-20 13:44:08 -05:00
"errors"
2021-01-19 12:26:12 -05:00
"fmt"
2021-01-20 13:44:08 -05:00
"io"
2021-01-19 12:26:12 -05:00
"io/ioutil"
"net/http"
"path"
"sort"
"strings"
"time"
)
2021-01-20 13:44:08 -05:00
// RequestSigner signs http requests with an AWS Signature Version 4
// Authorization header.
//
// RegionName is the AWS region placed in the credential scope.
// AwsSecurityCredentials holds the signing credentials; the signer reads the
// "access_key_id", "secret_access_key" and optional "security_token" keys.
type RequestSigner struct {
	RegionName             string
	AwsSecurityCredentials map[string]string
}

// NewRequestSigner returns a RequestSigner for the given region and
// credentials. The credentials map is stored as given, not copied.
func NewRequestSigner(regionName string, awsSecurityCredentials map[string]string) *RequestSigner {
	signer := new(RequestSigner)
	signer.RegionName = regionName
	signer.AwsSecurityCredentials = awsSecurityCredentials
	return signer
}
2021-01-20 13:44:08 -05:00
// Identifiers and header names used when signing AWS Signature Version 4
// requests.
const (
	// awsAlgorithm is the AWS Signature Version 4 signing algorithm identifier.
	awsAlgorithm = "AWS4-HMAC-SHA256"

	// awsRequestType terminates the AWS credential scope value, as defined in
	// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
	awsRequestType = "aws4_request"

	// awsSecurityTokenHeader carries the session token, when one is available.
	awsSecurityTokenHeader = "x-amz-security-token"

	// awsDateHeader carries the auto-generated request timestamp.
	awsDateHeader = "x-amz-date"
)

// Go reference-time layouts for AWS timestamps. The long form is used for the
// x-amz-date header and in the string to sign. The short date stamp is used in
// the credential scope and in signing-key derivation. The long layout ends in
// a literal 'Z', so it is only correct for UTC times.
const (
	awsTimeFormatLong  = "20060102T150405Z"
	awsTimeFormatShort = "20060102"
)
2021-01-19 12:26:12 -05:00
2021-01-20 13:44:08 -05:00
// getSha256 returns the lowercase hex encoding of the SHA-256 digest of input.
// The error result is always nil. It is kept so callers can treat hashing like
// the other fallible signing steps.
func getSha256(input []byte) (string, error) {
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:]), nil
}
2021-01-20 13:44:08 -05:00
// getHmacSha256 returns the raw (not hex-encoded) HMAC-SHA256 of input, keyed
// by key. The error result is always nil because writes to a hash.Hash never
// fail.
func getHmacSha256(key, input []byte) ([]byte, error) {
	mac := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented never to return an error.
	_, _ = mac.Write(input)
	return mac.Sum(nil), nil
}
// cloneRequest returns a shallow copy of r whose Header map is a deep copy, so
// headers can be added to the copy without affecting r.Header. All other
// fields (URL, Body, GetBody, ...) are shared with r. A nil r.Header stays nil
// in the copy.
func cloneRequest(r *http.Request) *http.Request {
	dup := *r
	if r.Header == nil {
		return &dup
	}

	// Allocate one backing array for every header value, so all values are
	// copied with a single allocation.
	total := 0
	for _, vals := range r.Header {
		total += len(vals)
	}
	pool := make([]string, total)

	dup.Header = make(http.Header, len(r.Header))
	for key, vals := range r.Header {
		n := copy(pool, vals)
		// Cap each slice at its length. An append to one header then
		// reallocates instead of overwriting the next header's values in the
		// shared array.
		dup.Header[key] = pool[:n:n]
		pool = pool[n:]
	}
	return &dup
}
// canonicalPath returns the CanonicalURI component of an AWS SigV4 canonical
// request. This is the request's escaped URL path cleaned with path.Clean, or
// "/" when the path is empty.
func canonicalPath(req *http.Request) string {
	if escaped := req.URL.EscapedPath(); escaped != "" {
		return path.Clean(escaped)
	}
	return "/"
}
// canonicalQuery returns the CanonicalQueryString component of an AWS SigV4
// canonical request. Parameters are ordered by name (url.Values.Encode sorts
// keys), and the values of a repeated parameter are sorted.
//
// url.Values.Encode escapes a space as "+", but SigV4 requires "%20". Encode
// already escapes a literal "+" as "%2B", so every "+" left in its output
// stands for a space and can safely be rewritten.
func canonicalQuery(req *http.Request) string {
	queryValues := req.URL.Query()
	for queryKey := range queryValues {
		sort.Strings(queryValues[queryKey])
	}
	return strings.Replace(queryValues.Encode(), "+", "%20", -1)
}
// canonicalHeaders returns the SignedHeaders and CanonicalHeaders components
// of an AWS SigV4 canonical request for req.
//
// The first result holds the lowercase header names, sorted and joined with
// ";". The second result holds one "name:value\n" line per lowercase name, in
// the same order, where value is that header's values joined with ",".
//
// Keys that differ only in case are merged into one entry. This can happen
// when req.Header is written to directly instead of through Add/Set. The
// original keys are visited in sorted order, which is the order
// http.Header.Write serializes them in, so the merged value no longer depends
// on map iteration order. The request's own value slices are never appended
// to.
func canonicalHeaders(req *http.Request) (string, string) {
	// Fix the visiting order so that merging case-variant keys is
	// deterministic.
	keys := make([]string, 0, len(req.Header))
	for k := range req.Header {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Header names must be lowercased and sorted alphabetically.
	var headers []string
	lowerCaseHeaders := make(map[string][]string, len(keys))
	for _, k := range keys {
		lower := strings.ToLower(k)
		if _, seen := lowerCaseHeaders[lower]; !seen {
			headers = append(headers, lower)
		}
		// The destination is nil on first sight, then a slice this map owns,
		// so the request's own value slices are never written to.
		lowerCaseHeaders[lower] = append(lowerCaseHeaders[lower], req.Header[k]...)
	}
	sort.Strings(headers)

	var fullHeaders strings.Builder
	for _, header := range headers {
		fullHeaders.WriteString(header)
		fullHeaders.WriteByte(':')
		fullHeaders.WriteString(strings.Join(lowerCaseHeaders[header], ","))
		fullHeaders.WriteByte('\n')
	}

	return strings.Join(headers, ";"), fullHeaders.String()
}
2021-01-20 13:44:08 -05:00
func requestDataHash(req *http.Request) (string, error) {
var requestData []byte
2021-01-19 12:26:12 -05:00
if req.Body != nil {
2021-01-20 13:44:08 -05:00
requestBody, err := req.GetBody()
if err != nil {
return "", err
}
defer requestBody.Close()
requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
if err != nil {
return "", err
}
2021-01-19 12:26:12 -05:00
}
return getSha256(requestData)
}
// requestHost returns the host the request is addressed to: req.Host when it
// is set, otherwise the host from req.URL.
func requestHost(req *http.Request) string {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	return host
}
2021-01-20 13:44:08 -05:00
func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
dataHash, err := requestDataHash(req)
if err != nil {
return "", err
}
2021-01-19 12:26:12 -05:00
return strings.Join([]string{
req.Method,
canonicalPath(req),
canonicalQuery(req),
canonicalHeaderData,
canonicalHeaderColumns,
2021-01-20 13:44:08 -05:00
dataHash,
}, "\n"), nil
2021-01-19 12:26:12 -05:00
}
2021-01-20 13:44:08 -05:00
// SignRequest adds the appropriate headers to an http.Request
// or returns an error if something prevented this.
func (rs *RequestSigner) SignRequest(req *http.Request) error {
signedRequest := cloneRequest(req)
timestamp := now()
2021-01-19 12:26:12 -05:00
signedRequest.Header.Add("host", requestHost(req))
2021-01-20 13:44:08 -05:00
if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
2021-01-19 12:26:12 -05:00
}
if signedRequest.Header.Get("date") == "" {
2021-01-20 13:44:08 -05:00
signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
2021-01-19 12:26:12 -05:00
}
2021-01-20 13:44:08 -05:00
authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
if err != nil {
return err
}
signedRequest.Header.Set("Authorization", authorizationCode)
2021-01-19 12:26:12 -05:00
2021-01-20 13:44:08 -05:00
req.Header = signedRequest.Header
return nil
2021-01-19 12:26:12 -05:00
}
2021-01-20 13:44:08 -05:00
func (rs *RequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
if !ok {
return "", errors.New("Missing Secret Access Key")
}
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
if !ok {
return "", errors.New("Missing Access Key Id")
}
2021-01-19 12:26:12 -05:00
canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
dateStamp := timestamp.Format(awsTimeFormatShort)
2021-01-20 13:44:08 -05:00
serviceName := ""
if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
serviceName = splitHost[0]
}
credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
2021-01-19 12:26:12 -05:00
2021-01-20 13:44:08 -05:00
requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
if err != nil {
return "", err
}
requestHash, err := getSha256([]byte(requestString))
if err != nil{
return "", err
}
2021-01-19 12:26:12 -05:00
stringToSign := strings.Join([]string{
awsAlgorithm,
timestamp.Format(awsTimeFormatLong),
credentialScope,
2021-01-20 13:44:08 -05:00
requestHash,
2021-01-19 12:26:12 -05:00
}, "\n")
2021-01-20 13:44:08 -05:00
signingKey := []byte("AWS4" + secretAccessKey)
2021-01-19 12:26:12 -05:00
for _, signingInput := range []string{
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
} {
2021-01-20 13:44:08 -05:00
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
if err != nil{
return "", err
}
2021-01-19 12:26:12 -05:00
}
2021-01-20 13:44:08 -05:00
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
2021-01-19 12:26:12 -05:00
}