2014-05-17 11:26:57 -04:00
|
|
|
// Copyright 2014 The oauth2 Authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
// license that can be found in the LICENSE file.
|
2014-05-13 14:06:46 -04:00
|
|
|
|
2014-05-05 17:54:23 -04:00
|
|
|
package oauth2
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
2014-09-01 00:39:59 -04:00
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
2014-05-05 17:54:23 -04:00
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
2014-11-26 14:44:45 -05:00
|
|
|
"golang.org/x/oauth2/internal"
|
|
|
|
"golang.org/x/oauth2/jws"
|
2014-05-05 17:54:23 -04:00
|
|
|
)
|
|
|
|
|
|
|
|
var (
|
|
|
|
defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
|
|
|
|
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
|
|
|
|
)
|
|
|
|
|
2014-11-06 19:36:41 -05:00
|
|
|
// JWTClient requires OAuth 2.0 JWT credentials.
|
|
|
|
// Required for the 2-legged JWT flow.
|
|
|
|
func JWTClient(email string, key []byte) Option {
|
|
|
|
return func(o *Options) error {
|
|
|
|
pk, err := internal.ParseKey(key)
|
2014-05-05 17:54:23 -04:00
|
|
|
if err != nil {
|
2014-11-06 19:36:41 -05:00
|
|
|
return err
|
2014-05-05 17:54:23 -04:00
|
|
|
}
|
2014-11-06 19:36:41 -05:00
|
|
|
o.Email = email
|
|
|
|
o.PrivateKey = pk
|
|
|
|
return nil
|
2014-05-05 17:54:23 -04:00
|
|
|
}
|
|
|
|
}
|
2014-08-11 20:54:04 -04:00
|
|
|
|
2014-11-06 19:36:41 -05:00
|
|
|
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
|
|
|
|
func JWTEndpoint(aud string) Option {
|
|
|
|
return func(o *Options) error {
|
|
|
|
au, err := url.Parse(aud)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
o.AUD = au
|
|
|
|
return nil
|
2014-09-02 17:06:51 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-11-06 19:36:41 -05:00
|
|
|
// Subject requires a user to impersonate.
|
|
|
|
// Optional.
|
|
|
|
func Subject(user string) Option {
|
|
|
|
return func(o *Options) error {
|
|
|
|
o.Subject = user
|
|
|
|
return nil
|
2014-09-02 17:06:51 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-11-06 19:36:41 -05:00
|
|
|
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
|
|
|
return func(t *Token) (*Token, error) {
|
|
|
|
if t == nil {
|
|
|
|
t = &Token{}
|
|
|
|
}
|
|
|
|
claimSet := &jws.ClaimSet{
|
|
|
|
Iss: o.Email,
|
|
|
|
Scope: strings.Join(o.Scopes, " "),
|
|
|
|
Aud: o.AUD.String(),
|
|
|
|
}
|
|
|
|
if o.Subject != "" {
|
|
|
|
claimSet.Sub = o.Subject
|
|
|
|
// prn is the old name of sub. Keep setting it
|
|
|
|
// to be compatible with legacy OAuth 2.0 providers.
|
|
|
|
claimSet.Prn = o.Subject
|
|
|
|
}
|
|
|
|
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
|
2014-08-11 20:54:04 -04:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2014-11-06 19:36:41 -05:00
|
|
|
v := url.Values{}
|
|
|
|
v.Set("grant_type", defaultGrantType)
|
|
|
|
v.Set("assertion", payload)
|
|
|
|
c := o.Client
|
|
|
|
if c == nil {
|
|
|
|
c = &http.Client{}
|
|
|
|
}
|
|
|
|
resp, err := c.PostForm(o.AUD.String(), v)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
|
|
|
}
|
|
|
|
if c := resp.StatusCode; c < 200 || c > 299 {
|
|
|
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
|
|
|
|
}
|
|
|
|
b := make(map[string]interface{})
|
|
|
|
if err := json.Unmarshal(body, &b); err != nil {
|
|
|
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
|
|
|
}
|
|
|
|
token := &Token{}
|
|
|
|
token.AccessToken, _ = b["access_token"].(string)
|
|
|
|
token.TokenType, _ = b["token_type"].(string)
|
|
|
|
token.raw = b
|
|
|
|
if e, ok := b["expires_in"].(int); ok {
|
|
|
|
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
|
|
|
|
}
|
|
|
|
if idtoken, ok := b["id_token"].(string); ok {
|
|
|
|
// decode returned id token to get expiry
|
|
|
|
claimSet, err := jws.Decode(idtoken)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
|
|
|
}
|
|
|
|
token.Expiry = time.Unix(claimSet.Exp, 0)
|
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
return token, nil
|
2014-08-11 20:54:04 -04:00
|
|
|
}
|
|
|
|
}
|