oauth2: Allow use of arbitrary RSA private keys to sign JWT token retrieving requests.

This commit is contained in:
Burcu Dogan 2014-08-11 17:54:04 -07:00
parent 4c579cbd0d
commit eb7270d354
2 changed files with 47 additions and 40 deletions

View File

@ -12,10 +12,8 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -123,7 +121,7 @@ func Decode(payload string) (c *ClaimSet, err error) {
} }
// Encode encodes a signed JWS with provided header and claim set. // Encode encodes a signed JWS with provided header and claim set.
func Encode(header *Header, c *ClaimSet, signature []byte) (payload string, err error) { func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (payload string, err error) {
var encodedHeader, encodedClaimSet string var encodedHeader, encodedClaimSet string
encodedHeader, err = header.encode() encodedHeader, err = header.encode()
if err != nil { if err != nil {
@ -135,14 +133,9 @@ func Encode(header *Header, c *ClaimSet, signature []byte) (payload string, err
} }
ss := fmt.Sprintf("%s.%s", encodedHeader, encodedClaimSet) ss := fmt.Sprintf("%s.%s", encodedHeader, encodedClaimSet)
parsed, err := parsePrivateKey(signature)
if err != nil {
return
}
h := sha256.New() h := sha256.New()
h.Write([]byte(ss)) h.Write([]byte(ss))
b, err := rsa.SignPKCS1v15(rand.Reader, parsed, crypto.SHA256, h.Sum(nil)) b, err := rsa.SignPKCS1v15(rand.Reader, signature, crypto.SHA256, h.Sum(nil))
if err != nil { if err != nil {
return return
} }
@ -168,26 +161,3 @@ func base64Decode(s string) ([]byte, error) {
} }
return base64.URLEncoding.DecodeString(s) return base64.URLEncoding.DecodeString(s)
} }
// parsePrivateKey parses the key to extract the private key.
// It returns an error if private key is not provided or the
// provided key is invalid.
func parsePrivateKey(key []byte) (*rsa.PrivateKey, error) {
invalidPrivateKeyErr := errors.New("Private key is invalid.")
block, _ := pem.Decode(key)
if block == nil {
return nil, invalidPrivateKeyErr
}
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, invalidPrivateKeyErr
}
return parsed, nil
}

49
jwt.go
View File

@ -5,7 +5,10 @@
package oauth2 package oauth2
import ( import (
"crypto/rsa"
"crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -28,7 +31,13 @@ type JWTOptions struct {
// the configured OAuth provider. // the configured OAuth provider.
Email string `json:"email"` Email string `json:"email"`
// The path to the pem file. If you have a p12 file instead, you // Private key to sign JWS payloads.
PrivateKey *rsa.PrivateKey `json:"-"`
// The path to a pem container that includes your private key.
// If PrivateKey is set, this field is ignored.
//
// If you have a p12 file instead, you
// can use `openssl` to export the private key into a pem file. // can use `openssl` to export the private key into a pem file.
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes // $ openssl pkcs12 -in key.p12 -out key.pem -nodes
// Pem file should contain your private key. // Pem file should contain your private key.
@ -38,8 +47,6 @@ type JWTOptions struct {
Scopes []string `json:"scopes"` Scopes []string `json:"scopes"`
} }
// TODO(jbd): Add p12 support.
// NewJWTConfig creates a new configuration with the specified options // NewJWTConfig creates a new configuration with the specified options
// and OAuth2 provider endpoint. // and OAuth2 provider endpoint.
func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) { func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) {
@ -47,11 +54,18 @@ func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if opts.PrivateKey != nil {
return &JWTConfig{opts: opts, aud: audURL, key: opts.PrivateKey}, nil
}
contents, err := ioutil.ReadFile(opts.PemFilename) contents, err := ioutil.ReadFile(opts.PemFilename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &JWTConfig{opts: opts, aud: audURL, signature: contents}, nil parsedKey, err := parsePemKey(contents)
if err != nil {
return nil, err
}
return &JWTConfig{opts: opts, aud: audURL, key: parsedKey}, nil
} }
// JWTConfig represents an OAuth 2.0 provider and client options to // JWTConfig represents an OAuth 2.0 provider and client options to
@ -59,7 +73,7 @@ func NewJWTConfig(opts *JWTOptions, aud string) (*JWTConfig, error) {
type JWTConfig struct { type JWTConfig struct {
opts *JWTOptions opts *JWTOptions
aud *url.URL aud *url.URL
signature []byte key *rsa.PrivateKey
} }
// NewTransport creates a transport that is authorize with the // NewTransport creates a transport that is authorize with the
@ -94,7 +108,7 @@ func (c *JWTConfig) FetchToken(existing *Token) (token *Token, err error) {
claimSet.Prn = existing.Subject claimSet.Prn = existing.Subject
} }
payload, err := jws.Encode(defaultHeader, claimSet, c.signature) payload, err := jws.Encode(defaultHeader, claimSet, c.key)
if err != nil { if err != nil {
return return
} }
@ -140,3 +154,26 @@ func (c *JWTConfig) FetchToken(existing *Token) (token *Token, err error) {
token.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second) token.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second)
return return
} }
// parsePemKey parses the pem file to extract the private key.
// It returns an error if private key is not provided or the
// provided key is invalid.
func parsePemKey(key []byte) (*rsa.PrivateKey, error) {
invalidPrivateKeyErr := errors.New("oauth2: private key is invalid")
block, _ := pem.Decode(key)
if block == nil {
return nil, invalidPrivateKeyErr
}
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, invalidPrivateKeyErr
}
return parsed, nil
}