Capstone/ui/db/token_source.go

153 lines
3.9 KiB
Go

package db
import (
"context"
"errors"
"time"
"git.preston-baxter.com/Preston_PLB/capstone/frontend-service/config"
"git.preston-baxter.com/Preston_PLB/capstone/frontend-service/db/models"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"golang.org/x/oauth2"
)
// VendorTokenSource produces OAuth2 tokens for a single vendor account,
// refreshing them through the vendor's OAuth configuration and persisting the
// updated credentials through the owning DB handle.
type VendorTokenSource struct {
	db     *DB                   // used for the lock collection and for loading/saving the vendor account
	vendor *models.VendorAccount // account whose credentials are refreshed; swapped for a reloaded copy after another caller refreshes
}

// NewVendorTokenSource returns a VendorTokenSource bound to this DB handle and
// the given vendor account. The account pointer is retained, not copied.
func (db *DB) NewVendorTokenSource(vendor *models.VendorAccount) *VendorTokenSource {
	source := new(VendorTokenSource)
	source.db = db
	source.vendor = vendor
	return source
}
// Token returns a freshly refreshed OAuth2 token for the vendor account.
//
// Refreshes are coordinated through a lock document in the configured lock
// collection, identified by the vendor id and the current access token. The
// caller whose insert succeeds performs the refresh, saves the vendor account
// and marks the lock as refreshed. A caller whose insert fails with a
// duplicate-key error (another caller already holds the lock) waits for that
// to happen (see waitForToken) and then reloads the vendor account from the
// database.
//
// Token mutates ts.vendor and is not safe for concurrent use; wrap the source
// in oauth2.ReuseTokenSource, which serializes calls and only invokes Token
// when the cached token is no longer valid.
func (ts *VendorTokenSource) Token() (*oauth2.Token, error) {
	conf := config.Config()
	col := ts.db.client.Database(conf.Mongo.LockDb).Collection(conf.Mongo.LockCol)

	tokenLock := &models.TokenLock{
		VendorId:  ts.vendor.MongoId(),
		TokenId:   ts.vendor.OauthCredentials.AccessToken,
		Refreshed: false,
	}
	// Called for its side effect: make sure the lock has its mongo id before
	// the insert. The same id is used below to release or abandon the lock.
	tokenLock.MongoId()
	tokenLock.UpdateObjectInfo()

	// Try to acquire the lock by inserting it.
	if _, err := col.InsertOne(context.Background(), tokenLock, options.InsertOne()); err != nil {
		if !mongo.IsDuplicateKeyError(err) {
			return nil, err
		}
		// Someone else holds the lock: wait until they have refreshed, then
		// pick up the token they saved.
		if err := ts.waitForToken(tokenLock); err != nil {
			return nil, err
		}
		va, err := ts.db.FindVendorAccountById(ts.vendor.MongoId())
		if err != nil {
			return nil, err
		}
		ts.vendor = va
		return ts.vendor.Token(), nil
	}

	// We hold the lock. If anything fails before the lock is marked refreshed,
	// delete it. A lock left behind with refreshed=false would make every later
	// caller holding the same access token wait out its timeout instead of
	// retrying the refresh.
	abandonLock := func() {
		// Best effort: the original failure is the error the caller needs.
		_, _ = col.DeleteOne(context.Background(), bson.M{"_id": tokenLock.MongoId()})
	}

	token, err := oauth2.RefreshToken(context.Background(), conf.Vendors[ts.vendor.Name].OauthConfig(), ts.vendor.OauthCredentials.RefreshToken)
	if err != nil {
		abandonLock()
		return nil, err
	}

	ts.vendor.OauthCredentials.AccessToken = token.AccessToken
	// A provider may rotate the refresh token on use. Persist a new one so the
	// next refresh does not present a discarded token. An empty value means no
	// new refresh token was issued.
	if token.RefreshToken != "" {
		ts.vendor.OauthCredentials.RefreshToken = token.RefreshToken
	}

	if err := ts.db.SaveModel(ts.vendor); err != nil {
		abandonLock()
		return nil, err
	}

	// Release the lock by marking it refreshed. waitForToken watches for this
	// update.
	if _, err := col.UpdateByID(context.Background(), tokenLock.MongoId(), bson.M{"$set": bson.M{"refreshed": true}}, options.Update()); err != nil {
		return nil, err
	}
	return token, nil
}
// TokenWaitExpired reports that waiting for another caller to finish
// refreshing a token took longer than allowed. It may be joined with a lookup
// error, so match it with errors.Is rather than ==.
var TokenWaitExpired = errors.New("Waiting for token to refresh took too long")
// tokenLockChangeEvent is the part of a change stream event that waitForToken
// decodes: the lock document as it looked after the change. The stream is
// opened with full-document update lookup so that update events carry it.
type tokenLockChangeEvent struct {
	TokenLock *models.TokenLock `bson:"fullDocument"` // left nil when the event has no fullDocument
}
// waitForToken blocks until the lock document matching tl's vendor id and
// token id has been marked refreshed by whichever caller holds it, or until
// about 10 seconds have passed.
//
// It returns nil once the lock is seen as refreshed. If the wait times out
// and a final lookup still does not show the lock as refreshed, it returns an
// error satisfying errors.Is(err, TokenWaitExpired); when the lookup itself
// fails, that error is joined in. Errors from opening or decoding the change
// stream are returned unchanged. tl is only read, never modified.
func (ts *VendorTokenSource) waitForToken(tl *models.TokenLock) error {
	conf := config.Config()
	col := ts.db.client.Database(conf.Mongo.LockDb).Collection(conf.Mongo.LockCol)

	// Identify the lock by the same fields the change stream matches on, so
	// the stream and the direct lookups agree on which document is meant.
	filter := bson.M{"vendor_id": tl.VendorId, "token_id": tl.TokenId}

	// lockRefreshed loads the stored lock and reports whether it is refreshed.
	lockRefreshed := func(ctx context.Context) (bool, error) {
		var current models.TokenLock
		// SingleResult.Decode also surfaces FindOne's own error (e.g. no documents).
		if err := col.FindOne(ctx, filter).Decode(&current); err != nil {
			return false, err
		}
		return current.Refreshed, nil
	}

	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelFunc()

	// UpdateLookup attaches the post-update document to update events. Both
	// the $match on fullDocument fields and the decode below rely on it.
	opts := options.ChangeStream().SetFullDocument(options.UpdateLookup)
	changeStream, err := col.Watch(ctx, mongo.Pipeline{
		{{Key: "$match", Value: bson.D{{Key: "fullDocument.vendor_id", Value: tl.VendorId}, {Key: "fullDocument.token_id", Value: tl.TokenId}}}},
	}, opts)
	if err != nil {
		return err
	}
	defer changeStream.Close(context.Background())

	// The holder may have finished before the stream was opened. That update
	// is never delivered to us, so check the current state once now. A failed
	// check is not fatal: the stream and the final lookup still follow.
	if refreshed, err := lockRefreshed(ctx); err == nil && refreshed {
		return nil
	}

	for changeStream.Next(ctx) {
		var event tokenLockChangeEvent
		if err := changeStream.Decode(&event); err != nil {
			return err
		}
		if event.TokenLock != nil && event.TokenLock.Refreshed {
			return nil
		}
	}

	// The stream stopped, normally because the timeout expired. Take one last
	// look in case the update landed without being observed.
	refreshed, err := lockRefreshed(context.Background())
	if err != nil {
		return errors.Join(TokenWaitExpired, err)
	}
	if refreshed {
		return nil
	}
	return TokenWaitExpired
}