diff --git a/service/controllers/controllers.go b/service/controllers/controllers.go index 4b3ec60..d590e72 100644 --- a/service/controllers/controllers.go +++ b/service/controllers/controllers.go @@ -35,6 +35,5 @@ func BuildRouter(r *gin.Engine) { pcoClientMap = make(map[primitive.ObjectID]*pco.PcoApiClient) pco := r.Group("/pco") - pco.Use(ValidatePcoWebhook) - pco.POST("/:userid", ConsumePcoWebhook) + pco.POST("/:userid", ValidatePcoWebhook, ConsumePcoWebhook) } diff --git a/service/controllers/pco_auth_middleware.go b/service/controllers/pco_auth_middleware.go index 0185c54..6eafe8e 100644 --- a/service/controllers/pco_auth_middleware.go +++ b/service/controllers/pco_auth_middleware.go @@ -1,22 +1,21 @@ package controllers import ( + "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "io" - "git.preston-baxter.com/Preston_PLB/capstone/frontend-service/config" - "git.preston-baxter.com/Preston_PLB/capstone/frontend-service/db/models" + "git.preston-baxter.com/Preston_PLB/capstone/webhook-service/vendors/pco/webhooks" "github.com/gin-gonic/gin" + "github.com/google/jsonapi" ) const PCO_VALIDATE_HEADER = "X-PCO-Webhooks-Authenticity" func ValidatePcoWebhook(c *gin.Context) { - conf := config.Config() - //get remote version from header remoteDigestStr := c.GetHeader(PCO_VALIDATE_HEADER) if remoteDigestStr == "" { @@ -26,18 +25,28 @@ func ValidatePcoWebhook(c *gin.Context) { } pcoSig := make([]byte, len(remoteDigestStr)/2) _, err := hex.Decode(pcoSig, []byte(remoteDigestStr)) + if err != nil { + log.WithError(err).Error("Failed to decode byte digest") + _ = c.AbortWithError(501, err) + return + } //clone request to harmlessly inspect the body bodyReader := c.Request.Clone(context.Background()).Body body, err := io.ReadAll(bodyReader) if err != nil { log.WithError(err).Error("Failed to read body while validating PCO webhook") - c.AbortWithError(501, err) + _ = c.AbortWithError(501, err) return } //Get secret - key := conf.Vendors[models.PCO_VENDOR_NAME].WebhookSecret + key, err := getAuthSecret(c, body) + if err != nil { + log.WithError(err).Error("Failed to find auth secret for event. It may not be setup") + _ = c.AbortWithError(501, err) + return + } //Get HMAC hmacSig := hmac.New(sha256.New, []byte(key)) @@ -48,3 +57,20 @@ func ValidatePcoWebhook(c *gin.Context) { c.AbortWithStatus(401) } } + +func getAuthSecret(c *gin.Context, body []byte) (string, error) { + userObjectId := userIdFromContext(c) + + event := &webhooks.EventDelivery{} + err := jsonapi.UnmarshalPayload(bytes.NewBuffer(body), event) + if err != nil { + return "", err + } + + webhook, err := mongo.FindPcoSubscriptionForUser(*userObjectId, event.Name) + if err != nil { + return "", err + } + + return webhook.Details.AuthenticitySecret, nil +} diff --git a/service/controllers/pco_webhook.go b/service/controllers/pco_webhook.go index 4444ed6..c7a7027 100644 --- a/service/controllers/pco_webhook.go +++ b/service/controllers/pco_webhook.go @@ -28,34 +28,47 @@ var ( type actionFunc func(*gin.Context, *webhooks.EventDelivery) error +func userIdFromContext(c *gin.Context) (*primitive.ObjectID) { + if id, ok := c.Get("user_bson_id"); !ok { + userId := c.Param("userid") + + if userId == "" { + log.Warn("Webhook did not contain user id. Rejecting") + c.AbortWithStatus(404) + return nil + } + + userObjectId, err := primitive.ObjectIDFromHex(userId) + if err != nil { + log.WithError(err).Warn("User Id was malformed") + c.AbortWithStatus(400) + return nil + } + c.Set("user_bson_id", userObjectId) + return &userObjectId + } else { + if objId, ok := id.(primitive.ObjectID); ok { + return &objId + } else { + return nil + } + } +} + func ConsumePcoWebhook(c *gin.Context) { - userId := c.Param("userid") - - if userId == "" { - log.Warn("Webhook did not contain user id. Rejecting") - c.AbortWithStatus(404) - return - } - - //get actions for user - userObjectId, err := primitive.ObjectIDFromHex(userId) - if err != nil { - log.WithError(err).Warn("User Id was malformed") - c.AbortWithStatus(400) - return - } - c.Set("user_bson_id", userObjectId) + userObjectId := userIdFromContext(c) //read body and handle io in parallel because IO shenanigains wg := new(sync.WaitGroup) wg.Add(2) + //get actions for user var actionMappings []models.ActionMapping var webhookBody *webhooks.EventDelivery errs := make([]error, 2) go func(wg *sync.WaitGroup) { - actionMappings, errs[0] = mongo.FindActionMappingsByUser(userObjectId) + actionMappings, errs[0] = mongo.FindActionMappingsByUser(*userObjectId) wg.Done() }(wg) @@ -81,7 +94,7 @@ func ConsumePcoWebhook(c *gin.Context) { actionKey := fmt.Sprintf("%s:%s", mapping.Action.VendorName, mapping.Action.Type) //if function exists run the function if action, ok := actionFuncMap[actionKey]; ok { - err = action(c, webhookBody) + err := action(c, webhookBody) //handle error if err != nil { log.WithError(err).Errorf("Failed to execute action: %s. From event source: %s:%s", actionKey, mapping.SourceEvent.VendorName, mapping.SourceEvent.Key)