2023-09-06 02:37:54 -04:00
package oauth2
import (
"context"
"encoding/json"
2023-09-07 01:25:12 -04:00
"errors"
2023-09-06 02:37:54 -04:00
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2/internal"
)
// Error codes a token endpoint can answer with while a device-flow client is
// polling; see https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
const (
	// errAuthorizationPending: the user has not completed authorization yet,
	// so the client keeps polling.
	errAuthorizationPending = "authorization_pending"

	// errSlowDown: keep polling, but with a longer interval.
	errSlowDown = "slow_down"

	// errAccessDenied: the user declined the authorization request.
	errAccessDenied = "access_denied"

	// errExpiredToken: the device code expired before it was authorized.
	errExpiredToken = "expired_token"
)
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
//
// On the wire the lifetime of the codes is the relative "expires_in" value in
// seconds; MarshalJSON and UnmarshalJSON convert between that and the absolute
// Expiry time.
type DeviceAuthResponse struct {
	// DeviceCode is the device verification code the client sends as
	// "device_code" when polling the token endpoint.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user should enter at the verification URI.
	UserCode string `json:"user_code"`
	// VerificationURI is where the user should enter the user code.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete (if populated) includes the user code in the verification URI.
	// This is typically shown to the user in non-textual form, such as a QR code.
	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
	// Expiry is when the device code and user code expire. This tag is shadowed
	// by the ExpiresIn field used in MarshalJSON/UnmarshalJSON, which translate
	// it to and from the relative "expires_in" value.
	Expiry time.Time `json:"expires_in,omitempty"`
	// Interval is the number of seconds to wait between polling requests.
	// Zero means the server did not send one (RFC 8628 then mandates 5).
	Interval int64 `json:"interval,omitempty"`
}

// MarshalJSON encodes d with Expiry replaced by "expires_in": the number of
// whole seconds (truncated toward zero) from now until Expiry. An unset Expiry,
// or one less than a second away, produces no "expires_in" field; an Expiry in
// the past produces a negative value.
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but no methods, so json.Marshal does not
	// recurse back into MarshalJSON.
	type Alias DeviceAuthResponse
	var expiresIn int64
	if !d.Expiry.IsZero() {
		expiresIn = int64(time.Until(d.Expiry).Seconds())
	}
	return json.Marshal(&struct {
		// Declared at a shallower depth than Alias.Expiry, so this field wins
		// the "expires_in" key.
		ExpiresIn int64 `json:"expires_in,omitempty"`
		*Alias
	}{
		ExpiresIn: expiresIn,
		Alias:     (*Alias)(&d),
	})
}

// UnmarshalJSON decodes a device authorization response into d. It converts
// the relative "expires_in" value (seconds) into an absolute UTC Expiry,
// measured from the moment of decoding. If "expires_in" is absent or zero,
// d.Expiry is left unchanged.
func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
	// Alias breaks the recursion into this method, as in MarshalJSON.
	type Alias DeviceAuthResponse
	aux := &struct {
		ExpiresIn int64 `json:"expires_in"`
		*Alias
	}{
		Alias: (*Alias)(d),
	}
	// aux is already a pointer; decoding writes the shared fields straight into d.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	if aux.ExpiresIn != 0 {
		d.Expiry = time.Now().UTC().Add(time.Duration(aux.ExpiresIn) * time.Second)
	}
	return nil
}
// DeviceAuth begins the RFC 8628 device authorization flow. It sends the
// client ID, the configured scopes (space-separated, if any) and any extra
// options to the endpoint's device authorization URL. It returns the device
// code and the user-facing codes/URIs the user enters on another device.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
	// Request parameters: https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
	params := make(url.Values)
	params.Set("client_id", c.ClientID)
	if scopes := c.Scopes; len(scopes) != 0 {
		params.Set("scope", strings.Join(scopes, " "))
	}
	for i := range opts {
		opts[i].setValue(params)
	}
	return retrieveDeviceAuth(ctx, c, params)
}
func retrieveDeviceAuth ( ctx context . Context , c * Config , v url . Values ) ( * DeviceAuthResponse , error ) {
2023-09-07 01:25:12 -04:00
if c . Endpoint . DeviceAuthURL == "" {
return nil , errors . New ( "endpoint missing DeviceAuthURL" )
}
2023-09-06 02:37:54 -04:00
req , err := http . NewRequest ( "POST" , c . Endpoint . DeviceAuthURL , strings . NewReader ( v . Encode ( ) ) )
if err != nil {
return nil , err
}
req . Header . Set ( "Content-Type" , "application/x-www-form-urlencoded" )
req . Header . Set ( "Accept" , "application/json" )
t := time . Now ( )
r , err := internal . ContextClient ( ctx ) . Do ( req )
if err != nil {
return nil , err
}
body , err := io . ReadAll ( io . LimitReader ( r . Body , 1 << 20 ) )
if err != nil {
return nil , fmt . Errorf ( "oauth2: cannot auth device: %v" , err )
}
if code := r . StatusCode ; code < 200 || code > 299 {
return nil , & RetrieveError {
Response : r ,
Body : body ,
}
}
da := & DeviceAuthResponse { }
err = json . Unmarshal ( body , & da )
if err != nil {
return nil , fmt . Errorf ( "unmarshal %s" , err )
}
if ! da . Expiry . IsZero ( ) {
// Make a small adjustment to account for time taken by the request
da . Expiry = da . Expiry . Add ( - time . Since ( t ) )
}
return da , nil
}
// DeviceAccessToken polls the server to exchange a device code for a token.
//
// It waits da.Interval seconds between token requests, or 5 seconds if the
// server sent no positive interval. Each "slow_down" answer adds 5 seconds to
// the interval, and "authorization_pending" answers keep polling going. Any
// other error, including "access_denied" and "expired_token", ends polling and
// is returned. If da.Expiry is set, ctx is given that deadline, so polling
// stops once the device code expires.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
	if !da.Expiry.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
		defer cancel()
	}

	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
	v := url.Values{
		"client_id":   {c.ClientID},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
		"device_code": {da.DeviceCode},
	}
	if len(c.Scopes) > 0 {
		v.Set("scope", strings.Join(c.Scopes, " "))
	}
	for _, opt := range opts {
		opt.setValue(v)
	}

	// "If no value is provided, clients MUST use 5 as the default."
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
	// Negative values are replaced too: time.NewTicker panics on a
	// non-positive duration, and the interval comes from the server.
	interval := da.Interval
	if interval <= 0 {
		interval = 5
	}

	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-ticker.C:
			tok, err := retrieveToken(ctx, c, v)
			if err == nil {
				return tok, nil
			}

			// errors.As also recognizes a *RetrieveError that has been wrapped.
			var e *RetrieveError
			if !errors.As(err, &e) {
				return nil, err
			}
			switch e.ErrorCode {
			case errSlowDown:
				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
				interval += 5
				ticker.Reset(time.Duration(interval) * time.Second)
			case errAuthorizationPending:
				// The user has not finished authorizing yet; keep polling.
			case errAccessDenied, errExpiredToken:
				// Terminal answers; listed explicitly for clarity.
				fallthrough
			default:
				return nil, err
			}
		}
	}
}