diff --git a/.changeset/beige-ligers-yawn.md b/.changeset/beige-ligers-yawn.md new file mode 100644 index 0000000000..221425b09f --- /dev/null +++ b/.changeset/beige-ligers-yawn.md @@ -0,0 +1,6 @@ +--- +"@medusajs/orchestration": patch +"@medusajs/workflows-sdk": patch +--- + +fix(): Workflow cancellation + gracefully handle non serializable state diff --git a/packages/core/orchestration/src/__tests__/transaction/transaction-orchestrator.ts b/packages/core/orchestration/src/__tests__/transaction/transaction-orchestrator.ts index f88a45d70a..0cfa7c1e23 100644 --- a/packages/core/orchestration/src/__tests__/transaction/transaction-orchestrator.ts +++ b/packages/core/orchestration/src/__tests__/transaction/transaction-orchestrator.ts @@ -1,6 +1,7 @@ import { TransactionStepState, TransactionStepStatus } from "@medusajs/utils" import { setTimeout } from "timers/promises" import { + DistributedTransaction, DistributedTransactionType, TransactionHandlerType, TransactionOrchestrator, @@ -10,6 +11,7 @@ import { TransactionStepTimeoutError, TransactionTimeoutError, } from "../../transaction" +import { BaseInMemoryDistributedTransactionStorage } from "../../transaction/datastore/base-in-memory-storage" describe("Transaction Orchestrator", () => { afterEach(() => { @@ -151,6 +153,104 @@ describe("Transaction Orchestrator", () => { expect(actionOrder).toEqual(["one", "two", "three", "four", "five", "six"]) }) + it("Should gracefully handle non serializable error when an async step fails", async () => { + class BaseInMemoryDistributedTransactionStorage_ extends BaseInMemoryDistributedTransactionStorage { + scheduleRetry() { + return Promise.resolve() + } + } + DistributedTransaction.setStorage( + new BaseInMemoryDistributedTransactionStorage_() + ) + + const actionOrder: string[] = [] + async function handler( + actionId: string, + functionHandlerType: TransactionHandlerType, + payload: TransactionPayload + ) { + if (functionHandlerType === TransactionHandlerType.INVOKE) { + actionOrder.push(actionId) + } + + if ( + functionHandlerType === TransactionHandlerType.INVOKE && + actionId === "three" + ) { + const error = new Error("Step 3 failed") + + const obj: any = {} + obj.self = obj + ;(error as any).metadata = obj + throw error + } + } + + const flow: TransactionStepsDefinition = { + next: [ + { + action: "one", + }, + { + action: "two", + next: { + action: "four", + next: { + action: "six", + }, + }, + }, + { + action: "three", + async: true, + maxRetries: 0, + next: { + action: "five", + }, + }, + ], + } + + const strategy = new TransactionOrchestrator({ + id: "transaction-name", + definition: flow, + }) + + const transaction = await strategy.beginTransaction( + "transaction_id_123", + handler + ) + + await strategy.resume(transaction) + + expect(transaction.getErrors()).toHaveLength(2) + expect(transaction.getErrors()).toEqual([ + { + action: "three", + error: { + message: "Step 3 failed", + name: "Error", + stack: expect.any(String), + }, + handlerType: "invoke", + }, + { + action: "three", + error: expect.objectContaining({ + message: expect.stringContaining( + "Converting circular structure to JSON" + ), + stack: expect.any(String), + }), + handlerType: "invoke", + }, + ]) + + DistributedTransaction.setStorage( + new BaseInMemoryDistributedTransactionStorage() + ) + }) + it("Should not execute next steps when a step fails", async () => { const actionOrder: string[] = [] async function handler( diff --git a/packages/core/orchestration/src/transaction/distributed-transaction.ts b/packages/core/orchestration/src/transaction/distributed-transaction.ts index 0713ec94a7..8c84add8b8 100644 --- a/packages/core/orchestration/src/transaction/distributed-transaction.ts +++ b/packages/core/orchestration/src/transaction/distributed-transaction.ts @@ -9,6 +9,7 @@ import { TransactionHandlerType, TransactionState, } from "./types" +import { NonSerializableCheckPointError } from "./errors" /** * @typedef TransactionMetadata @@ -204,19 +205,14 @@ class DistributedTransaction extends EventEmitter { return } - const data = new TransactionCheckpoint( - this.getFlow(), - this.getContext(), - this.getErrors() - ) - const key = TransactionOrchestrator.getKeyName( DistributedTransaction.keyPrefix, this.modelId, this.transactionId ) - const rawData = JSON.parse(JSON.stringify(data)) + const rawData = this.#serializeCheckpointData() + await DistributedTransaction.keyValueStore.save(key, rawData, ttl, options) return rawData @@ -320,6 +316,76 @@ class DistributedTransaction extends EventEmitter { public hasTemporaryData(key: string) { return this.#temporaryStorage.has(key) } + + /** + * Try to serialize the checkpoint data + * If it fails, it means that the context or the errors are not serializable + * and we should handle it + * + * @internal + * @returns + */ + #serializeCheckpointData() { + const data = new TransactionCheckpoint( + this.getFlow(), + this.getContext(), + this.getErrors() + ) + + const isSerializable = (obj) => { + try { + JSON.parse(JSON.stringify(obj)) + return true + } catch { + return false + } + } + + let rawData + try { + rawData = JSON.parse(JSON.stringify(data)) + } catch (e) { + if (!isSerializable(this.context)) { + // This is a safe guard, we should never reach this point + // If we do, it means that the context is not serializable + // and we should throw an error + throw new NonSerializableCheckPointError( + "Unable to serialize context object. Please make sure the workflow input and steps response are serializable." + ) + } + + if (!isSerializable(this.errors)) { + const nonSerializableErrors: TransactionStepError[] = [] + for (const error of this.errors) { + if (!isSerializable(error.error)) { + error.error = { + name: error.error.name, + message: error.error.message, + stack: error.error.stack, + } + nonSerializableErrors.push({ + ...error, + error: e, + }) + } + } + + if (nonSerializableErrors.length) { + this.errors.push(...nonSerializableErrors) + } + } + + const data = new TransactionCheckpoint( + this.getFlow(), + this.getContext(), + this.getErrors() + ) + + rawData = JSON.parse(JSON.stringify(data)) + } + + return rawData + } } DistributedTransaction.setStorage( diff --git a/packages/core/orchestration/src/transaction/errors.ts b/packages/core/orchestration/src/transaction/errors.ts index 9fe72059bf..8a3ff7bb7f 100644 --- a/packages/core/orchestration/src/transaction/errors.ts +++ b/packages/core/orchestration/src/transaction/errors.ts @@ -68,3 +68,19 @@ export class TransactionTimeoutError extends BaseStepErrror { super("TransactionTimeoutError", message, stepResponse) } } + +export class NonSerializableCheckPointError extends Error { + static isNonSerializableCheckPointError( + error: Error + ): error is NonSerializableCheckPointError { + return ( + error instanceof NonSerializableCheckPointError || + error?.name === "NonSerializableCheckPointError" + ) + } + + constructor(message?: string) { + super(message) + this.name = "NonSerializableCheckPointError" + } +} diff --git a/packages/core/workflows-sdk/src/utils/composer/__tests__/index.spec.ts b/packages/core/workflows-sdk/src/utils/composer/__tests__/index.spec.ts index e44765239b..4720f818d6 100644 --- a/packages/core/workflows-sdk/src/utils/composer/__tests__/index.spec.ts +++ b/packages/core/workflows-sdk/src/utils/composer/__tests__/index.spec.ts @@ -1,3 +1,4 @@ +import { TransactionState } from "@medusajs/utils" import { createStep } from "../create-step" import { createWorkflow } from "../create-workflow" import { StepResponse } from "../helpers" @@ -42,6 +43,44 @@ describe("Workflow composer", () => { expect(result).toEqual({ result: "hi from outside" }) }) + it("should cancel transaction on failed sub workflow call", async function () { + const step1 = createStep("step1", async (_, context) => { + return new StepResponse("step1") + }) + + const step2 = createStep("step2", async (input: string, context) => { + return new StepResponse({ result: input }) + }) + const step3 = createStep("step3", async (input: string, context) => { + throw new Error("I have failed") + }) + + const subWorkflow = createWorkflow( + getNewWorkflowId(), + function (input: WorkflowData) { + step1() + return new WorkflowResponse(step2(input)) + } + ) + + const workflow = createWorkflow(getNewWorkflowId(), function () { + const subWorkflowRes = subWorkflow.runAsStep({ + input: "hi from outside", + }) + return new WorkflowResponse(step3(subWorkflowRes.result)) + }) + + const { errors, transaction } = await workflow.run({ + input: {}, + throwOnError: false, + }) + + expect(errors).toHaveLength(1) + expect(errors[0].error.message).toEqual("I have failed") + + expect(transaction.getState()).toEqual(TransactionState.REVERTED) + }) + it("should skip step if condition is false", async function () { const step1 = createStep("step1", async (_, context) => { return new StepResponse({ result: "step1" }) diff --git a/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts b/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts index ef5a158601..329b79ba78 100644 --- a/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts +++ b/packages/core/workflows-sdk/src/utils/composer/create-workflow.ts @@ -10,7 +10,7 @@ import { OrchestrationUtils, } from "@medusajs/utils" import { ulid } from "ulid" -import { exportWorkflow } from "../../helper" +import { exportWorkflow, WorkflowResult } from "../../helper" import { createStep } from "./create-step" import { proxify } from "./helpers/proxy" import { StepResponse } from "./helpers/step-response" @@ -201,20 +201,29 @@ export function createWorkflow( }, }) - const { result, transaction: flowTransaction } = transaction + const { result } = transaction - if (!context.isAsync || flowTransaction.hasFinished()) { - return new StepResponse(result, transaction) - } - - return + return new StepResponse( + result, + context.isAsync ? stepContext.transactionId : transaction + ) }, - async (transaction, { container }) => { + async (transaction, stepContext) => { if (!transaction) { return } - await workflow(container).cancel(transaction) + const { container, ...sharedContext } = stepContext + + await workflow(container).cancel({ + transaction: (transaction as WorkflowResult).transaction, + transactionId: isString(transaction) ? transaction : undefined, + container, + context: { + ...sharedContext, + parentStepIdempotencyKey: stepContext.idempotencyKey, + }, + }) } )(input) as ReturnType> diff --git a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/index.ts b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/index.ts index 244f32c5c2..f3183ed070 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/index.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/index.ts @@ -4,3 +4,4 @@ export * from "./workflow_async" export * from "./workflow_step_timeout" export * from "./workflow_transaction_timeout" export * from "./workflow_when" +export * from "./workflow_async_compensate" diff --git a/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_async_compensate.ts b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_async_compensate.ts new file mode 100644 index 0000000000..06850027af --- /dev/null +++ b/packages/modules/workflow-engine-redis/integration-tests/__fixtures__/workflow_async_compensate.ts @@ -0,0 +1,51 @@ +import { + createStep, + createWorkflow, + parallelize, + StepResponse, + WorkflowResponse, +} from "@medusajs/framework/workflows-sdk" + +const step_1_background = createStep( + { + name: "step_1_background_fail", + async: true, + }, + jest.fn(async (input) => { + return new StepResponse(input) + }) +) + +const nestedWorkflow = createWorkflow( + { + name: "nested_sub_flow_async_fail", + }, + function (input) { + const resp = step_1_background(input) + + return resp + } +) + +const step_2 = createStep( + { + name: "step_2_fail", + }, + jest.fn(async () => { + throw new Error("step_2_fail") + }) +) + +createWorkflow( + { + name: "workflow_async_background_fail", + }, + function (input) { + const ret = nestedWorkflow.runAsStep({ + input, + }) + + step_2() + return new WorkflowResponse(ret) + } +) diff --git a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts index 00f60a894a..6d6f2353e1 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts @@ -1,5 +1,6 @@ import { DistributedTransactionType, + TransactionStep, TransactionStepTimeoutError, TransactionTimeoutError, WorkflowManager, @@ -473,6 +474,44 @@ moduleIntegrationTestRunner({ failTrap(done) }) + + it("should cancel an async sub workflow when compensating", (done) => { + const workflowId = "workflow_async_background_fail" + + void workflowOrcModule.run(workflowId, { + input: { + callSubFlow: true, + }, + transactionId: "trx_123_compensate_async_sub_workflow", + throwOnError: false, + logOnError: false, + }) + + let onCompensateStepSuccess: { step: TransactionStep } | null = null + + void workflowOrcModule.subscribe({ + workflowId, + subscriber: (event) => { + if (event.eventType === "onCompensateStepSuccess") { + onCompensateStepSuccess = event + } + if (event.eventType === "onFinish") { + expect(onCompensateStepSuccess).toBeDefined() + expect(onCompensateStepSuccess!.step.id).toEqual( + "_root.nested_sub_flow_async_fail-as-step" // The workflow as step + ) + expect(onCompensateStepSuccess!.step.compensate).toEqual({ + state: "reverted", + status: "ok", + }) + + done() + } + }, + }) + + failTrap(done) + }) }) // Note: These tests depend on actual Redis instance and waiting for the scheduled jobs to run, which isn't great.