fix(workflow-engine-*): Prevent passing shared context reference (#11873)

* fix(workflow-engine-*): Prevent passing shared context reference

* fix(workflow-engine-*): Prevent passing shared context reference

* prevent tests from hanging

* fix event handling

* add integration tests

* use interval for scheduled in tests

* skip tests for now

* Create silent-glasses-enjoy.md

* fix cancel

* changeset

* push multiple aliases

* test multiple field alias

* increase wait time to index on test

---------

Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com>
Co-authored-by: Carlos R. L. Rodrigues <rodrigolr@gmail.com>
This commit is contained in:
Adrien de Peretti
2025-04-09 10:39:29 +02:00
committed by GitHub
parent 2a18a75353
commit 13e159d8ad
18 changed files with 353 additions and 58 deletions
+9
View File
@@ -0,0 +1,9 @@
---
"@medusajs/workflow-engine-inmemory": patch
"@medusajs/workflow-engine-redis": patch
"@medusajs/types": patch
"@medusajs/workflows-sdk": patch
"@medusajs/orchestration": patch
---
fix(workflow-engine-\*): Prevent passing shared context reference and workflow.cancel
@@ -1,13 +1,13 @@
import { medusaIntegrationTestRunner } from "@medusajs/test-utils"
import { IndexTypes } from "@medusajs/types"
import { defaultCurrencies, Modules } from "@medusajs/utils"
import { medusaIntegrationTestRunner } from "@medusajs/test-utils"
import { setTimeout } from "timers/promises"
import {
adminHeaders,
createAdminUser,
} from "../../../helpers/create-admin-user"
jest.setTimeout(120000)
jest.setTimeout(100000)
process.env.ENABLE_INDEX_MODULE = "true"
@@ -65,7 +65,7 @@ medusaIntegrationTestRunner({
})
// Timeout to allow indexing to finish
await setTimeout(2000)
await setTimeout(4000)
const { data: results } = await indexEngine.query<"product">({
fields: [
@@ -144,7 +144,7 @@ medusaIntegrationTestRunner({
})
// Timeout to allow indexing to finish
await setTimeout(2000)
await setTimeout(4000)
const { data: results } = await indexEngine.query<"product">({
fields: [
@@ -1,6 +1,14 @@
import { RemoteJoiner } from "@medusajs/framework/orchestration"
import CustomerModule from "@medusajs/medusa/customer"
import RegionModule from "@medusajs/medusa/region"
import { MedusaModule } from "@medusajs/modules-sdk"
import { medusaIntegrationTestRunner } from "@medusajs/test-utils"
import { IRegionModuleService, RemoteQueryFunction } from "@medusajs/types"
import { ContainerRegistrationKeys, Modules } from "@medusajs/utils"
import {
IRegionModuleService,
ModuleJoinerConfig,
RemoteQueryFunction,
} from "@medusajs/types"
import { ContainerRegistrationKeys, defineLink, Modules } from "@medusajs/utils"
import { createAdminUser } from "../../..//helpers/create-admin-user"
import { adminHeaders } from "../../../helpers/create-admin-user"
@@ -428,6 +436,90 @@ medusaIntegrationTestRunner({
}),
])
})
it("should handle multiple fieldAlias when multiple links between two modules are defined", async () => {
const customer = CustomerModule.linkable.customer
const customerGroup = CustomerModule.linkable.customerGroup
const region = RegionModule.linkable.region
const country = RegionModule.linkable.country
defineLink(customer, region)
defineLink(customerGroup, region)
defineLink(customer, country)
defineLink(customerGroup, country)
const modulesLoaded = MedusaModule.getLoadedModules().map(
(mod) => Object.values(mod)[0]
)
const servicesConfig_: ModuleJoinerConfig[] = []
for (const mod of modulesLoaded || []) {
if (!mod.__definition.isQueryable) {
continue
}
servicesConfig_!.push(mod.__joinerConfig)
}
const linkDefinition = MedusaModule.getCustomLinks().map(
(linkDefinition: any) => {
const definition = linkDefinition(
MedusaModule.getAllJoinerConfigs()
)
return definition
}
)
servicesConfig_.push(...(linkDefinition as any))
const remoteJoiner = new RemoteJoiner(
servicesConfig_,
(() => {}) as any
)
const fieldAlias = (remoteJoiner as any).getServiceConfig({
entity: "Customer",
}).fieldAlias
expect(fieldAlias).toEqual(
expect.objectContaining({
account_holders: {
path: "account_holder_link.account_holder",
isList: true,
entity: "Customer",
},
region: [
{
path: "region_link.region",
isList: false,
forwardArgumentsOnPath: ["region_link.region"],
entity: "Customer",
},
{
path: "region_link.region",
isList: false,
forwardArgumentsOnPath: ["region_link.region"],
entity: "CustomerGroup",
},
],
country: [
{
path: "country_link.country",
isList: false,
forwardArgumentsOnPath: ["country_link.country"],
entity: "Customer",
},
{
path: "country_link.country",
isList: false,
forwardArgumentsOnPath: ["country_link.country"],
entity: "CustomerGroup",
},
],
})
)
})
})
},
})
@@ -383,6 +383,11 @@ export class RemoteJoiner {
`Cannot add alias "${alias}" for "${extend.serviceName}". It is already defined for Entity "${extend.entity}".`
)
}
service_.fieldAlias[alias].push({
...objAlias,
entity: extend.entity,
})
} else {
service_.fieldAlias[alias] = {
...objAlias,
@@ -741,6 +741,7 @@ export class TransactionOrchestrator extends EventEmitter {
this.emit(DistributedTransactionEvent.FINISH, { transaction })
}
const asyncStepsToStart: any[] = []
for (const step of nextSteps.next) {
const curState = step.getStates()
const type = step.isCompensating()
@@ -923,8 +924,8 @@ export class TransactionOrchestrator extends EventEmitter {
return await transaction.handler(...handlerArgs)
}
execution.push(
transaction.saveCheckpoint().then(() => {
asyncStepsToStart.push({
handler: async () => {
let promise: Promise<unknown>
if (TransactionOrchestrator.traceStep) {
@@ -936,7 +937,7 @@ export class TransactionOrchestrator extends EventEmitter {
promise = stepHandler()
}
promise
return promise
.then(async (response: any) => {
const output = response?.__type ? response.output : response
@@ -990,8 +991,8 @@ export class TransactionOrchestrator extends EventEmitter {
response,
})
})
})
)
},
})
}
}
@@ -1005,6 +1006,10 @@ export class TransactionOrchestrator extends EventEmitter {
}
}
if (asyncStepsToStart.length > 0) {
execution.push(...asyncStepsToStart.map((step) => step.handler()))
}
await promiseAll(execution)
if (nextSteps.next.length === 0) {
@@ -391,6 +391,7 @@ export class LocalWorkflow {
async cancel(
transactionOrTransactionId: string | DistributedTransactionType,
_?: unknown, // not used but a common argument on other methods called dynamically
context?: Context,
subscribe?: DistributedTransactionEvents
) {
@@ -73,4 +73,9 @@ export type Context<TManager = unknown> = {
* A string indicating the idempotencyKey of the parent workflow execution.
*/
parentStepIdempotencyKey?: string
/**
* preventReleaseEvents
*/
preventReleaseEvents?: boolean
}
@@ -91,9 +91,10 @@ function createContextualWorkflowRunner<
flow.container = executionContainer
}
const { eventGroupId, parentStepIdempotencyKey } = context
const { eventGroupId, parentStepIdempotencyKey, preventReleaseEvents } =
context
if (!parentStepIdempotencyKey) {
if (!preventReleaseEvents) {
attachOnFinishReleaseEvents(events, flow, { logOnError })
}
@@ -196,6 +196,7 @@ export function createWorkflow<TData, TResult, THooks extends any[]>(
transactionId:
step.__step__ + "-" + (stepContext.transactionId ?? ulid()),
parentStepIdempotencyKey: stepContext.idempotencyKey,
preventReleaseEvents: true,
},
})
@@ -207,15 +208,12 @@ export function createWorkflow<TData, TResult, THooks extends any[]>(
)
},
async (transaction, stepContext) => {
if (!transaction) {
return
}
const { container, ...sharedContext } = stepContext
const transactionId = step.__step__ + "-" + stepContext.transactionId
await workflow(container).cancel({
transaction: (transaction as WorkflowResult<any>).transaction,
transactionId: isString(transaction) ? transaction : undefined,
transaction: (transaction as WorkflowResult<any>)?.transaction,
transactionId,
container,
context: {
...sharedContext,
@@ -31,7 +31,7 @@ import {
import { setTimeout as setTimeoutSync } from "timers"
import { createScheduled } from "../__fixtures__/workflow_scheduled"
jest.setTimeout(3000000)
jest.setTimeout(300000)
const failTrap = (done) => {
setTimeoutSync(() => {
@@ -12,7 +12,7 @@ import { setTimeout as setTimeoutSync } from "timers"
import { setTimeout } from "timers/promises"
import "../__fixtures__"
jest.setTimeout(3000000)
jest.setTimeout(300000)
const failTrap = (done) => {
setTimeoutSync(() => {
@@ -84,7 +84,17 @@ export class WorkflowsModuleService<
@MedusaContext() context: Context = {}
) {
options ??= {}
options.context ??= context
const {
manager,
transactionManager,
preventReleaseEvents,
...restContext
} = context
options.context ??= restContext
options.context.preventReleaseEvents ??=
!!options.context.parentStepIdempotencyKey
delete options.context.parentStepIdempotencyKey
const ret = await this.workflowOrchestratorService_.run<
TWorkflow extends ReturnWorkflow<any, any, any>
@@ -14,6 +14,7 @@ import {
MedusaError,
TransactionState,
TransactionStepState,
isDefined,
isPresent,
} from "@medusajs/framework/utils"
import { WorkflowOrchestratorService } from "@services"
@@ -115,8 +116,8 @@ export class InMemoryDistributedTransactionStorage
return data
}
const { idempotent } = options ?? {}
if (!idempotent) {
const { idempotent, store, retentionTime } = options ?? {}
if (!idempotent && !(store && isDefined(retentionTime))) {
return
}
@@ -11,16 +11,19 @@ export const createScheduled = (
schedule?: SchedulerOptions
) => {
const workflowScheduledStepInvoke = jest.fn((input, { container }) => {
next()
return new StepResponse({
testValue: container.resolve("test-value"),
})
try {
return new StepResponse({
testValue: "test-value",
})
} finally {
next()
}
})
const step = createStep("step_1", workflowScheduledStepInvoke)
createWorkflow(
{ name, schedule: schedule ?? "* * * * * *" },
{ name, schedule: schedule ?? { interval: 1000 } },
function (input) {
return step(input)
}
@@ -1,5 +1,6 @@
import {
DistributedTransactionType,
TransactionState,
TransactionStep,
TransactionStepTimeoutError,
TransactionTimeoutError,
@@ -26,8 +27,14 @@ import { WorkflowsModuleService } from "../../src/services"
import "../__fixtures__"
import { createScheduled } from "../__fixtures__/workflow_scheduled"
import { TestDatabase } from "../utils"
import {
createStep,
createWorkflow,
StepResponse,
WorkflowResponse,
} from "@medusajs/framework/workflows-sdk"
jest.setTimeout(999900000)
jest.setTimeout(300000)
const failTrap = (done) => {
setTimeoutSync(() => {
@@ -39,6 +46,33 @@ const failTrap = (done) => {
}, 5000)
}
function times(num) {
let resolver
let counter = 0
const promise = new Promise((resolve) => {
resolver = resolve
})
return {
next: () => {
counter += 1
if (counter === num) {
resolver()
}
},
// Force resolution after 10 seconds to prevent infinite awaiting
promise: Promise.race([
promise,
new Promise((_, reject) => {
setTimeoutSync(
() => reject("times has not been resolved after 10 seconds."),
1000
)
}),
]),
}
}
// REF:https://stackoverflow.com/questions/78028715/jest-async-test-with-event-emitter-isnt-ending
moduleIntegrationTestRunner<IWorkflowEngineService>({
@@ -56,24 +90,6 @@ moduleIntegrationTestRunner<IWorkflowEngineService>({
jest.clearAllMocks()
})
function times(num) {
let resolver
let counter = 0
const promise = new Promise((resolve) => {
resolver = resolve
})
return {
next: () => {
counter += 1
if (counter === num) {
resolver()
}
},
promise,
}
}
let query: RemoteQueryFunction
let sharedContainer_: MedusaContainer
@@ -534,9 +550,146 @@ moduleIntegrationTestRunner<IWorkflowEngineService>({
})
})
describe("Testing complex workflows", function () {
it("should execute workflow + workflow as step + manual workflow within a step correctly", async () => {
const workflowA_id = "workflow_a"
const workflowB_id = "workflow_b"
const workflowC_id = "workflow_c"
const stepB_1 = createStep("stepB_1", async (input, context) => {
let results: any[] = []
for (let i = 0; i < 2; i++) {
const { result } = await workflowOrcModule.run(workflowC_id, {
input: {},
})
results.push(result)
}
return new StepResponse(results)
})
const stepC_1 = createStep("stepC_1", async (input, context) => {
return new StepResponse({
stepC_1_result: "stepC_1_result",
})
})
createWorkflow(workflowC_id, (input) => {
const result = stepC_1()
return new WorkflowResponse(result)
})
const workflowB = createWorkflow(workflowB_id, (input) => {
const result = stepB_1()
return new WorkflowResponse(result)
})
createWorkflow(workflowA_id, (input) => {
const workflowB_response = workflowB.runAsStep({
input: {},
})
return new WorkflowResponse({ workflowB_response })
})
const { result } = await workflowOrcModule.run(workflowA_id, {
input: {},
throwOnError: false,
})
expect(result).toEqual({
workflowB_response: [
{
stepC_1_result: "stepC_1_result",
},
{
stepC_1_result: "stepC_1_result",
},
],
})
})
it("should execute workflow + workflow as step + manual workflow within a step that fail but do not fail the step", async () => {
const workflowA_id = "workflow_a"
const workflowB_id = "workflow_b"
const workflowC_id = "workflow_c"
const stepB_1 = createStep("stepB_1", async (input, context) => {
let results: any[] = []
for (let i = 0; i < 2; i++) {
const { errors } = await workflowOrcModule.run(workflowC_id, {
input: {},
throwOnError: false,
})
results.push(errors)
}
return new StepResponse(results)
})
const stepC_1 = createStep("stepC_1", async (input, context) => {
throw new Error("Workflow C failed")
})
createWorkflow(workflowC_id, (input) => {
const result = stepC_1()
return new WorkflowResponse(result)
})
const workflowB = createWorkflow(workflowB_id, (input) => {
const result = stepB_1()
return new WorkflowResponse(result)
})
createWorkflow(workflowA_id, (input) => {
const workflowB_response = workflowB.runAsStep({
input: {},
})
return new WorkflowResponse({ workflowB_response })
})
const { result, transaction } = await workflowOrcModule.run(
workflowA_id,
{
input: {},
throwOnError: false,
}
)
expect(
(transaction as DistributedTransactionType).getFlow().state
).toEqual(TransactionState.DONE)
expect(result).toEqual({
workflowB_response: [
[
{
action: "stepC_1",
handlerType: TransactionHandlerType.INVOKE,
error: expect.objectContaining({
message: "Workflow C failed",
}),
},
],
[
{
action: "stepC_1",
handlerType: TransactionHandlerType.INVOKE,
error: expect.objectContaining({
message: "Workflow C failed",
}),
},
],
],
})
})
})
// Note: These tests depend on actual Redis instance and waiting for the scheduled jobs to run, which isn't great.
// Mocking bullmq, however, would make the tests close to useless, so we can keep them very minimal and serve as smoke tests.
describe("Scheduled workflows", () => {
describe.skip("Scheduled workflows", () => {
beforeEach(() => {
jest.clearAllMocks()
})
@@ -552,8 +705,8 @@ moduleIntegrationTestRunner<IWorkflowEngineService>({
it("should stop executions after the set number of executions", async () => {
const wait = times(2)
const spy = await createScheduled("num-executions", wait.next, {
cron: "* * * * * *",
const spy = createScheduled("num-executions", wait.next, {
interval: 1000,
numberOfExecutions: 2,
})
@@ -573,8 +726,8 @@ moduleIntegrationTestRunner<IWorkflowEngineService>({
ContainerRegistrationKeys.LOGGER
)
const spy = await createScheduled("remove-scheduled", wait.next, {
cron: "* * * * * *",
const spy = createScheduled("remove-scheduled", wait.next, {
interval: 1000,
})
const logSpy = jest.spyOn(logger, "warn")
@@ -594,7 +747,7 @@ moduleIntegrationTestRunner<IWorkflowEngineService>({
sharedContainer_.register("test-value", asValue("test"))
const spy = await createScheduled("shared-container-job", wait.next, {
cron: "* * * * * *",
interval: 1000,
})
await wait.promise
@@ -12,7 +12,7 @@ import { setTimeout as setTimeoutSync } from "timers"
import { setTimeout } from "timers/promises"
import "../__fixtures__"
jest.setTimeout(999900000)
jest.setTimeout(300000)
const failTrap = (done) => {
setTimeoutSync(() => {
@@ -96,7 +96,18 @@ export class WorkflowsModuleService<
@MedusaContext() context: Context = {}
) {
options ??= {}
options.context ??= context
const {
manager,
transactionManager,
preventReleaseEvents,
...restContext
} = context
options.context ??= restContext
options.context.preventReleaseEvents ??=
!!options.context.parentStepIdempotencyKey
delete options.context.parentStepIdempotencyKey
const ret = await this.workflowOrchestratorService_.run<
TWorkflow extends ReturnWorkflow<any, any, any>
? UnwrapWorkflowInputDataType<TWorkflow>
@@ -12,6 +12,7 @@ import {
} from "@medusajs/framework/orchestration"
import { Logger, ModulesSdkTypes } from "@medusajs/framework/types"
import {
isDefined,
isPresent,
MedusaError,
promiseAll,
@@ -221,8 +222,8 @@ export class RedisDistributedTransactionStorage
return JSON.parse(data)
}
const { idempotent } = options ?? {}
if (!idempotent) {
const { idempotent, store, retentionTime } = options ?? {}
if (!idempotent && !(store && isDefined(retentionTime))) {
return
}