From c5d609d09cb29c6cf01d1c6c65305cc566f391c5 Mon Sep 17 00:00:00 2001 From: Adrien de Peretti Date: Wed, 16 Jul 2025 16:44:09 +0200 Subject: [PATCH] fix(orchestration): Prevent workf. cancellation to execute while rescheduling (#12903) **What** Currently, when cancelling async workflows, the step gets rescheduled while the current worker tries to continue the execution, leading to concurrency failures on compensation. This PR prevents the current worker from executing while an async step gets rescheduled. Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com> --- .changeset/spotty-files-wonder.md | 6 ++ .../transaction/transaction-orchestrator.ts | 24 +++++- .../src/__tests__/events.spec.ts | 71 ++++++++++++++--- packages/medusa-test-utils/src/events.ts | 59 ++++++++++---- .../wait-workflow-executions.ts | 2 +- .../integration-tests/__fixtures__/index.ts | 3 +- .../__fixtures__/workflow_parallel_async.ts | 76 ++++++++++++++++++ .../integration-tests/__tests__/index.spec.ts | 78 ++++++++++++++----- .../integration-tests/__tests__/index.spec.ts | 49 ++++++++---- 9 files changed, 305 insertions(+), 63 deletions(-) create mode 100644 .changeset/spotty-files-wonder.md create mode 100644 packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/workflow_parallel_async.ts diff --git a/.changeset/spotty-files-wonder.md b/.changeset/spotty-files-wonder.md new file mode 100644 index 0000000000..99e348fd3f --- /dev/null +++ b/.changeset/spotty-files-wonder.md @@ -0,0 +1,6 @@ +--- +"@medusajs/test-utils": patch +"@medusajs/orchestration": patch +--- + +fix(orchestration): Prevent workf. 
cancellation to execute while rescheduling diff --git a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts index 26cd8aec32..870fae6ba5 100644 --- a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts +++ b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts @@ -1185,7 +1185,9 @@ export class TransactionOrchestrator extends EventEmitter { ) if (ret.transactionIsCancelling) { - return await this.cancelTransaction(transaction) + await this.cancelTransaction(transaction, { + preventExecuteNext: true, + }) } if (isAsync && !ret.stopExecution) { @@ -1204,6 +1206,10 @@ export class TransactionOrchestrator extends EventEmitter { isPermanent: boolean, response?: unknown ): Promise { + const isAsync = step.isCompensating() + ? step.definition.compensateAsync + : step.definition.async + if (isDefined(response) && step.saveResponse) { transaction.addResponse( step.definition.action!, @@ -1222,7 +1228,14 @@ export class TransactionOrchestrator extends EventEmitter { ) if (ret.transactionIsCancelling) { - return await this.cancelTransaction(transaction) + await this.cancelTransaction(transaction, { + preventExecuteNext: true, + }) + } + + if (isAsync && !ret.stopExecution) { + // Schedule to continue the execution of async steps because they are not awaited on purpose and can be handled by another machine + await transaction.scheduleRetry(step, 0) } } @@ -1287,7 +1300,8 @@ export class TransactionOrchestrator extends EventEmitter { * @param transaction - The transaction to be reverted */ public async cancelTransaction( - transaction: DistributedTransactionType + transaction: DistributedTransactionType, + options?: { preventExecuteNext?: boolean } ): Promise { if (transaction.modelId !== this.id) { throw new MedusaError( @@ -1319,6 +1333,10 @@ export class TransactionOrchestrator extends EventEmitter { await transaction.saveCheckpoint() + if 
(options?.preventExecuteNext) { + return + } + await this.executeNext(transaction) } diff --git a/packages/medusa-test-utils/src/__tests__/events.spec.ts b/packages/medusa-test-utils/src/__tests__/events.spec.ts index ea8a5a0580..15c788dd4d 100644 --- a/packages/medusa-test-utils/src/__tests__/events.spec.ts +++ b/packages/medusa-test-utils/src/__tests__/events.spec.ts @@ -1,5 +1,8 @@ import { EventEmitter } from "events" import { waitSubscribersExecution } from "../events" +import { setTimeout } from "timers/promises" + +jest.setTimeout(30000) // Mock the IEventBusModuleService class MockEventBus { @@ -31,11 +34,13 @@ describe("waitSubscribersExecution", () => { describe("with no existing listeners", () => { it("should resolve when event is fired before timeout", async () => { const waitPromise = waitSubscribersExecution(TEST_EVENT, eventBus as any) - setTimeout(() => eventBus.emit(TEST_EVENT, "test-data"), 100).unref() + await setTimeout(100) + eventBus.emit(TEST_EVENT, "test-data") jest.advanceTimersByTime(100) - await expect(waitPromise).resolves.toEqual(["test-data"]) + const res = await waitPromise + expect(res).toEqual(["test-data"]) }) it("should reject when timeout is reached before event is fired", async () => { @@ -70,12 +75,29 @@ describe("waitSubscribersExecution", () => { `Timeout of ${customTimeout}ms exceeded while waiting for event "${TEST_EVENT}"` ) }) + + it("should resolve when event is fired multiple times", async () => { + const waitPromise = waitSubscribersExecution( + TEST_EVENT, + eventBus as any, + { triggerCount: 2 } + ) + eventBus.emit(TEST_EVENT, "test-data") + eventBus.emit(TEST_EVENT, "test-data") + + const promisesRes = await waitPromise + const res = promisesRes.pop() + expect(res).toHaveLength(2) + expect(res[0]).toEqual(["test-data"]) + expect(res[1]).toEqual(["test-data"]) + }) }) describe("with existing listeners", () => { it("should resolve when all listeners complete successfully", async () => { - const listener = 
jest.fn().mockImplementation(() => { - return new Promise((resolve) => setTimeout(resolve, 200).unref()) + const listener = jest.fn().mockImplementation(async () => { + await setTimeout(200) + return "res" }) eventBus.eventEmitter_.on(TEST_EVENT, listener) @@ -132,20 +154,49 @@ describe("waitSubscribersExecution", () => { expect(listener).not.toHaveBeenCalled() }) + + it("should resolve when event is fired multiple times", async () => { + const listener = jest.fn().mockImplementation(async () => { + await setTimeout(200) + return "res" + }) + + eventBus.eventEmitter_.on(TEST_EVENT, listener) + + const waitPromise = waitSubscribersExecution( + TEST_EVENT, + eventBus as any, + { + triggerCount: 2, + } + ) + + eventBus.emit(TEST_EVENT, "test-data") + eventBus.emit(TEST_EVENT, "test-data") + + const promisesRes = await waitPromise + const res = promisesRes.pop() + expect(res).toHaveLength(2) + expect(res[0]).toEqual("res") + expect(res[1]).toEqual("res") + }) }) describe("with multiple listeners", () => { it("should resolve when all listeners complete", async () => { - const listener1 = jest.fn().mockImplementation(() => { - return new Promise((resolve) => setTimeout(resolve, 100).unref()) + const listener1 = jest.fn().mockImplementation(async () => { + await setTimeout(100) + return "res" }) - const listener2 = jest.fn().mockImplementation(() => { - return new Promise((resolve) => setTimeout(resolve, 200).unref()) + const listener2 = jest.fn().mockImplementation(async () => { + await setTimeout(200) + return "res" }) - const listener3 = jest.fn().mockImplementation(() => { - return new Promise((resolve) => setTimeout(resolve, 300).unref()) + const listener3 = jest.fn().mockImplementation(async () => { + await setTimeout(300) + return "res" }) eventBus.eventEmitter_.on(TEST_EVENT, listener1) diff --git a/packages/medusa-test-utils/src/events.ts b/packages/medusa-test-utils/src/events.ts index 1e7be09956..4b9cf9e04b 100644 --- a/packages/medusa-test-utils/src/events.ts 
+++ b/packages/medusa-test-utils/src/events.ts @@ -5,7 +5,10 @@ type EventBus = { } type WaitSubscribersExecutionOptions = { + /** Timeout in milliseconds for waiting for the event. Defaults to 15000ms. */ timeout?: number + /** Number of times the event should be triggered before resolving. Defaults to 1. */ + triggerCount?: number } // Map to hold pending promises for each event. @@ -41,7 +44,7 @@ const createTimeoutPromise = ( const doWaitSubscribersExecution = ( eventName: string | symbol, eventBus: EventBus, - { timeout = 15000 }: WaitSubscribersExecutionOptions = {} + { timeout = 15000, triggerCount = 1 }: WaitSubscribersExecutionOptions = {} ): Promise => { const eventEmitter = eventBus.eventEmitter_ const subscriberPromises: Promise[] = [] @@ -50,6 +53,8 @@ const doWaitSubscribersExecution = ( eventName ) + let currentCount = 0 + if (!eventEmitter.listeners(eventName).length) { let ok: (value?: any) => void const promise = new Promise((resolve) => { @@ -57,9 +62,19 @@ const doWaitSubscribersExecution = ( }) subscriberPromises.push(promise) + let res: any[] = [] const newListener = async (...args: any[]) => { - eventEmitter.removeListener(eventName, newListener) - ok(...args) + currentCount++ + res.push(args) + + if (currentCount >= triggerCount) { + eventEmitter.removeListener(eventName, newListener) + if (triggerCount === 1) { + ok(...args) + } else { + ok(res) + } + } } Object.defineProperty(newListener, "__isSubscribersExecutionWrapper", { @@ -83,22 +98,38 @@ const doWaitSubscribersExecution = ( nok = reject }) subscriberPromises.push(promise) + let res: any[] = [] const newListener = async (...args2: any[]) => { - // As soon as the subscriber is executed, we restore the original listener - eventEmitter.removeListener(eventName, newListener) - let listenerToAdd = listener - while (listenerToAdd.originalListener) { - listenerToAdd = listenerToAdd.originalListener - } - eventEmitter.on(eventName, listenerToAdd) - try { - const res = await 
listener.apply(eventBus, args2) - ok(res) + const listenerRes = listener.apply(eventBus, args2) + if (typeof listenerRes?.then === "function") { + await listenerRes.then((res_) => { + res.push(res_) + currentCount++ + }) + } else { + res.push(listenerRes) + currentCount++ + } + + if (currentCount >= triggerCount) { + const res_ = triggerCount === 1 ? res[0] : res + ok(res_) + } } catch (error) { nok(error) } + + if (currentCount >= triggerCount) { + // As soon as the subscriber is executed the required number of times, we restore the original listener + eventEmitter.removeListener(eventName, newListener) + let listenerToAdd = listener + while (listenerToAdd.originalListener) { + listenerToAdd = listenerToAdd.originalListener + } + eventEmitter.on(eventName, listenerToAdd) + } } Object.defineProperty(newListener, "__isSubscribersExecutionWrapper", { @@ -130,7 +161,7 @@ const doWaitSubscribersExecution = ( * * @param eventName - The name of the event to wait for. * @param eventBus - The event bus instance. - * @param options - Options including timeout. + * @param options - Options including timeout and triggerCount. 
*/ export const waitSubscribersExecution = ( eventName: string | symbol, diff --git a/packages/medusa-test-utils/src/medusa-test-runner-utils/wait-workflow-executions.ts b/packages/medusa-test-utils/src/medusa-test-runner-utils/wait-workflow-executions.ts index 8d425f87af..5d2406567c 100644 --- a/packages/medusa-test-utils/src/medusa-test-runner-utils/wait-workflow-executions.ts +++ b/packages/medusa-test-utils/src/medusa-test-runner-utils/wait-workflow-executions.ts @@ -17,7 +17,7 @@ export async function waitWorkflowExecutions(container: MedusaContainer) { const timeout = setTimeout(() => { throw new Error("Timeout waiting for workflow executions to finish") - }, 10000).unref() + }, 60000).unref() let waitWorkflowsToFinish = true while (waitWorkflowsToFinish) { diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/index.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/index.ts index 04635d6d64..7f581c011a 100644 --- a/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/index.ts +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/index.ts @@ -3,7 +3,8 @@ export * from "./workflow_2" export * from "./workflow_async" export * from "./workflow_conditional_step" export * from "./workflow_idempotent" +export * from "./workflow_not_idempotent_with_retention" +export * from "./workflow_parallel_async" export * from "./workflow_step_timeout" export * from "./workflow_sync" export * from "./workflow_transaction_timeout" -export * from "./workflow_not_idempotent_with_retention" diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/workflow_parallel_async.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/workflow_parallel_async.ts new file mode 100644 index 0000000000..0658ba939a --- /dev/null +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__fixtures__/workflow_parallel_async.ts @@ -0,0 
+1,76 @@ +import { Modules } from "@medusajs/framework/utils" +import { + createStep, + createWorkflow, + parallelize, + StepResponse, +} from "@medusajs/framework/workflows-sdk" + +const step_2 = createStep( + { + name: "step_2", + async: true, + }, + async (_, { container }) => { + const we = container.resolve(Modules.WORKFLOW_ENGINE) + + await we.run("workflow_sub_workflow", { + throwOnError: true, + }) + } +) + +const parallelStep2Invoke = jest.fn(() => { + throw new Error("Error in parallel step") +}) +const step_2_sub = createStep( + { + name: "step_2", + async: true, + }, + parallelStep2Invoke +) + +const subFlow = createWorkflow( + { + name: "workflow_sub_workflow", + retentionTime: 1000, + }, + function (input) { + step_2_sub() + } +) + +const step_1 = createStep( + { + name: "step_1", + async: true, + }, + jest.fn(() => { + return new StepResponse("step_1") + }) +) + +const parallelStep3Invoke = jest.fn(() => { + return new StepResponse({ + done: true, + }) +}) + +const step_3 = createStep( + { + name: "step_3", + async: true, + }, + parallelStep3Invoke +) + +createWorkflow( + { + name: "workflow_parallel_async", + retentionTime: 1000, + }, + function (input) { + parallelize(step_1(), step_2(), step_3()) + } +) diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts index 20e7094fd4..f1f9aa8ef4 100644 --- a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts @@ -1,3 +1,4 @@ +import { MedusaContainer } from "@medusajs/framework" import { DistributedTransactionType, TransactionState, @@ -43,7 +44,6 @@ import { workflowEventGroupIdStep2Mock, } from "../__fixtures__/workflow_event_group_id" import { createScheduled } from "../__fixtures__/workflow_scheduled" -import { container, MedusaContainer } from 
"@medusajs/framework" jest.setTimeout(60000) @@ -143,7 +143,7 @@ moduleIntegrationTestRunner({ }) describe("Cancel transaction", function () { - it("should cancel an ongoing execution with async unfinished yet step", async () => { + it("should cancel an ongoing execution with async unfinished yet step", (done) => { const transactionId = "transaction-to-cancel-id" const step1 = createStep("step1", async () => { return new StepResponse("step1") @@ -168,25 +168,39 @@ moduleIntegrationTestRunner({ return new WorkflowResponse("finished") }) - await workflowOrcModule.run(workflowId, { - input: {}, - transactionId, - }) + workflowOrcModule + .run(workflowId, { + input: {}, + transactionId, + }) + .then(async () => { + await setTimeoutPromise(100) - await setTimeoutPromise(100) + await workflowOrcModule.cancel(workflowId, { + transactionId, + }) - await workflowOrcModule.cancel(workflowId, { - transactionId, - }) + workflowOrcModule.subscribe({ + workflowId, + transactionId, + subscriber: async (event) => { + if (event.eventType === "onFinish") { + const execution = + await workflowOrcModule.listWorkflowExecutions({ + transaction_id: transactionId, + }) - await setTimeoutPromise(1000) + expect(execution.length).toEqual(1) + expect(execution[0].state).toEqual( + TransactionState.REVERTED + ) + done() + } + }, + }) + }) - const execution = await workflowOrcModule.listWorkflowExecutions({ - transaction_id: transactionId, - }) - - expect(execution.length).toEqual(1) - expect(execution[0].state).toEqual(TransactionState.REVERTED) + failTrap(done) }) it("should cancel a complete execution with a sync workflow running as async", async () => { @@ -898,7 +912,6 @@ moduleIntegrationTestRunner({ expect(spy).toHaveBeenCalledTimes(1) - console.log(spy.mock.results) expect(spy).toHaveReturnedWith( expect.objectContaining({ output: { testValue: "test" } }) ) @@ -944,6 +957,35 @@ moduleIntegrationTestRunner({ expect(executionsList).toHaveLength(1) 
expect(executionsListAfter).toHaveLength(1) }) + + it("should display error when multple async steps are running in parallel", (done) => { + void workflowOrcModule.run("workflow_parallel_async", { + input: {}, + throwOnError: false, + }) + + void workflowOrcModule.subscribe({ + workflowId: "workflow_parallel_async", + subscriber: (event) => { + if (event.eventType === "onFinish") { + done() + expect(event.errors).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + action: "step_2", + handlerType: "invoke", + error: expect.objectContaining({ + message: "Error in parallel step", + }), + }), + ]) + ) + } + }, + }) + + failTrap(done) + }) }) describe("Cleaner job", function () { diff --git a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts index 221a244ec9..e74cb52bfb 100644 --- a/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-redis/integration-tests/__tests__/index.spec.ts @@ -151,7 +151,7 @@ moduleIntegrationTestRunner({ describe("Testing basic workflow", function () { describe("Cancel transaction", function () { - it("should cancel an ongoing execution with async unfinished yet step", async () => { + it("should cancel an ongoing execution with async unfinished yet step", (done) => { const transactionId = "transaction-to-cancel-id" const step1 = createStep("step1", async () => { return new StepResponse("step1") @@ -179,25 +179,42 @@ moduleIntegrationTestRunner({ } ) - await workflowOrcModule.run(workflowId, { - input: {}, - transactionId, - }) + workflowOrcModule + .run(workflowId, { + input: {}, + transactionId, + }) + .then(async () => { + await setTimeout(100) - await setTimeout(100) + await workflowOrcModule.cancel(workflowId, { + transactionId, + }) - await workflowOrcModule.cancel(workflowId, { - transactionId, - }) + workflowOrcModule.subscribe({ + workflowId, + 
transactionId, + subscriber: async (event) => { + if (event.eventType === "onFinish") { + const execution = + await workflowOrcModule.listWorkflowExecutions({ + transaction_id: transactionId, + }) - await setTimeout(1000) + expect(execution.length).toEqual(1) + expect(execution[0].state).toEqual( + TransactionState.REVERTED + ) + done() + } + }, + }) + }) - const execution = await workflowOrcModule.listWorkflowExecutions({ - transaction_id: transactionId, - }) - - expect(execution.length).toEqual(1) - expect(execution[0].state).toEqual(TransactionState.REVERTED) + failTrap( + done, + "should cancel an ongoing execution with async unfinished yet step" + ) }) it("should cancel a complete execution with a sync workflow running as async", async () => {