diff --git a/.changeset/breezy-beds-invent.md b/.changeset/breezy-beds-invent.md new file mode 100644 index 0000000000..4e3e8198be --- /dev/null +++ b/.changeset/breezy-beds-invent.md @@ -0,0 +1,8 @@ +--- +"@medusajs/workflow-engine-inmemory": patch +"@medusajs/workflow-engine-redis": patch +"@medusajs/orchestration": patch +"@medusajs/workflows-sdk": patch +--- + +chore(workflow-engine-\*): cleanup and improvements diff --git a/packages/core/orchestration/src/transaction/distributed-transaction.ts b/packages/core/orchestration/src/transaction/distributed-transaction.ts index d5442f6b74..adfdd5bb73 100644 --- a/packages/core/orchestration/src/transaction/distributed-transaction.ts +++ b/packages/core/orchestration/src/transaction/distributed-transaction.ts @@ -55,6 +55,9 @@ const mergeStep = ( } } +const getErrorSignature = (err: TransactionStepError) => + `${err.action}:${err.handlerType}:${err.error?.message}` + /** * @typedef TransactionMetadata * @property model_id - The id of the model_id that created the transaction (modelId). @@ -105,7 +108,52 @@ const stateFlowOrder = [ TransactionState.FAILED, ] +const stateFlowOrderMap = new Map( + stateFlowOrder.map((state, index) => [state, index]) +) + +const finishedStatesSet = new Set([ + TransactionState.DONE, + TransactionState.REVERTED, + TransactionState.FAILED, +]) + export class TransactionCheckpoint { + static readonly #ALLOWED_STATE_TRANSITIONS = { + [TransactionStepState.DORMANT]: [TransactionStepState.NOT_STARTED], + [TransactionStepState.NOT_STARTED]: [ + TransactionStepState.INVOKING, + TransactionStepState.COMPENSATING, + TransactionStepState.FAILED, + TransactionStepState.SKIPPED, + TransactionStepState.SKIPPED_FAILURE, + ], + [TransactionStepState.INVOKING]: [ + TransactionStepState.FAILED, + TransactionStepState.DONE, + TransactionStepState.TIMEOUT, + TransactionStepState.SKIPPED, + ], + [TransactionStepState.COMPENSATING]: [ + TransactionStepState.REVERTED, + TransactionStepState.FAILED, + ], + [TransactionStepState.DONE]: [TransactionStepState.COMPENSATING], + } as const + + static readonly #ALLOWED_STATUS_TRANSITIONS = { + [TransactionStepStatus.WAITING]: [ + TransactionStepStatus.OK, + TransactionStepStatus.TEMPORARY_FAILURE, + TransactionStepStatus.PERMANENT_FAILURE, + ], + [TransactionStepStatus.TEMPORARY_FAILURE]: [ + TransactionStepStatus.IDLE, + TransactionStepStatus.PERMANENT_FAILURE, + ], + [TransactionStepStatus.PERMANENT_FAILURE]: [TransactionStepStatus.IDLE], + } as const + constructor( public flow: TransactionFlow, public context: TransactionContext, @@ -165,17 +213,15 @@ export class TransactionCheckpoint { currentTransactionData.flow[prop] ?? 0 ) } else if (prop === "state") { - const curState = stateFlowOrder.findIndex( - (state) => state === currentTransactionData.flow.state - ) - const storedState = stateFlowOrder.findIndex( - (state) => state === storedData.flow.state - ) + const currentStateIndex = + stateFlowOrderMap.get(currentTransactionData.flow.state) ?? -1 + const storedStateIndex = + stateFlowOrderMap.get(storedData.flow.state) ?? -1 - if (storedState > curState) { + if (storedStateIndex > currentStateIndex) { currentTransactionData.flow.state = storedData.flow.state } else if ( - curState < storedState && + currentStateIndex < storedStateIndex && currentTransactionData.flow.state !== TransactionState.WAITING_TO_COMPENSATE ) { @@ -265,43 +311,6 @@ export class TransactionCheckpoint { status: TransactionStepStatus } ): boolean { - // Define allowed state transitions - const allowedStateTransitions = { - [TransactionStepState.DORMANT]: [TransactionStepState.NOT_STARTED], - [TransactionStepState.NOT_STARTED]: [ - TransactionStepState.INVOKING, - TransactionStepState.COMPENSATING, - TransactionStepState.FAILED, - TransactionStepState.SKIPPED, - TransactionStepState.SKIPPED_FAILURE, - ], - [TransactionStepState.INVOKING]: [ - TransactionStepState.FAILED, - TransactionStepState.DONE, - TransactionStepState.TIMEOUT, - TransactionStepState.SKIPPED, - ], - [TransactionStepState.COMPENSATING]: [ - TransactionStepState.REVERTED, - TransactionStepState.FAILED, - ], - [TransactionStepState.DONE]: [TransactionStepState.COMPENSATING], - } - - // Define allowed status transitions - const allowedStatusTransitions = { - [TransactionStepStatus.WAITING]: [ - TransactionStepStatus.OK, - TransactionStepStatus.TEMPORARY_FAILURE, - TransactionStepStatus.PERMANENT_FAILURE, - ], - [TransactionStepStatus.TEMPORARY_FAILURE]: [ - TransactionStepStatus.IDLE, - TransactionStepStatus.PERMANENT_FAILURE, - ], - [TransactionStepStatus.PERMANENT_FAILURE]: [TransactionStepStatus.IDLE], - } - if ( currentStepState.state === storedStepState.state && currentStepState.status === storedStepState.status @@ -311,7 +320,9 @@ export class TransactionCheckpoint { // Check if state transition from stored to current is allowed const allowedStatesFromCurrent = - allowedStateTransitions[currentStepState.state] || [] + TransactionCheckpoint.#ALLOWED_STATE_TRANSITIONS[ + currentStepState.state + ] || [] const isStateTransitionValid = allowedStatesFromCurrent.includes( storedStepState.state ) @@ -328,7 +339,9 @@ export class TransactionCheckpoint { // Check if status transition from stored to current is allowed const allowedStatusesFromCurrent = - allowedStatusTransitions[currentStepState.status] || [] + TransactionCheckpoint.#ALLOWED_STATUS_TRANSITIONS[ + currentStepState.status + ] || [] return allowedStatusesFromCurrent.includes(storedStepState.status) } @@ -338,15 +351,13 @@ export class TransactionCheckpoint { incomingErrors: TransactionStepError[] ): void { const existingErrorSignatures = new Set( - currentErrors.map( - (err) => `${err.action}:${err.handlerType}:${err.error?.message}` - ) + currentErrors.map(getErrorSignature) ) for (const error of incomingErrors) { - const signature = `${error.action}:${error.handlerType}:${error.error?.message}` - if (!existingErrorSignatures.has(signature)) { + if (!existingErrorSignatures.has(getErrorSignature(error))) { currentErrors.push(error) + existingErrorSignatures.add(getErrorSignature(error)) } } } @@ -451,11 +462,7 @@ class DistributedTransaction extends EventEmitter { } public hasFinished(): boolean { - return [ - TransactionState.DONE, - TransactionState.REVERTED, - TransactionState.FAILED, - ].includes(this.getState()) + return finishedStatesSet.has(this.getState()) } public getState(): TransactionState { @@ -520,7 +527,7 @@ class DistributedTransaction extends EventEmitter { this.transactionId ) - let checkpoint + let checkpoint: TransactionCheckpoint let retries = 0 let backoffMs = 50 @@ -553,7 +560,7 @@ class DistributedTransaction extends EventEmitter { await setTimeoutPromise(backoffMs + jitter) - backoffMs = Math.min(backoffMs * 2, 1000) + backoffMs = Math.min(backoffMs * 2, 500) const lastCheckpoint = await DistributedTransaction.loadTransaction( this.modelId, @@ -697,65 +704,45 @@ class DistributedTransaction extends EventEmitter { * @returns */ #serializeCheckpointData() { - const data = new TransactionCheckpoint( - this.getFlow(), - this.getContext(), - this.getErrors() - ) - - const isSerializable = (obj) => { - try { - JSON.parse(JSON.stringify(obj)) - return true - } catch { - return false - } - } - - let rawData try { - rawData = JSON.parse(JSON.stringify(data)) - } catch (e) { - if (!isSerializable(this.context)) { - // This is a safe guard, we should never reach this point - // If we do, it means that the context is not serializable - // and we should throw an error - throw new NonSerializableCheckPointError( - "Unable to serialize context object. Please make sure the workflow input and steps response are serializable." - ) - } - - if (!isSerializable(this.errors)) { - const nonSerializableErrors: TransactionStepError[] = [] - for (const error of this.errors) { - if (!isSerializable(error.error)) { - error.error = { - name: error.error.name, - message: error.error.message, - stack: error.error.stack, - } - nonSerializableErrors.push({ - ...error, - error: e, - }) - } - } - - if (nonSerializableErrors.length) { - this.errors.push(...nonSerializableErrors) - } - } - - const data = new TransactionCheckpoint( - this.getFlow(), - this.getContext(), - this.getErrors() + JSON.stringify(this.context) + } catch { + throw new NonSerializableCheckPointError( + "Unable to serialize context object. Please make sure the workflow input and steps response are serializable." ) - - rawData = JSON.parse(JSON.stringify(data)) } - return rawData + let errorsToUse = this.getErrors() + try { + JSON.stringify(errorsToUse) + } catch { + // Sanitize non-serializable errors + const sanitizedErrors: TransactionStepError[] = [] + for (const error of this.errors) { + try { + JSON.stringify(error) + sanitizedErrors.push(error) + } catch { + sanitizedErrors.push({ + action: error.action, + handlerType: error.handlerType, + error: { + name: error.error?.name || "Error", + message: error.error?.message || String(error.error), + stack: error.error?.stack, + }, + }) + } + } + errorsToUse = sanitizedErrors + this.errors = sanitizedErrors + } + + return new TransactionCheckpoint( + JSON.parse(JSON.stringify(this.getFlow())), + this.getContext(), + [...errorsToUse] + ) } } diff --git a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts index 31b0381278..7f032e0744 100644 --- a/packages/core/orchestration/src/transaction/transaction-orchestrator.ts +++ b/packages/core/orchestration/src/transaction/transaction-orchestrator.ts @@ -40,6 +40,33 @@ import { TransactionTimeoutError, } from "./errors" +const canMoveForwardStates = new Set([ + TransactionStepState.DONE, + TransactionStepState.FAILED, + TransactionStepState.TIMEOUT, + TransactionStepState.SKIPPED, + TransactionStepState.SKIPPED_FAILURE, +]) + +const canMoveBackwardStates = new Set([ + TransactionStepState.DONE, + TransactionStepState.REVERTED, + TransactionStepState.FAILED, + TransactionStepState.DORMANT, + TransactionStepState.SKIPPED, +]) + +const flagStepsToRevertStates = new Set([ + TransactionStepState.DONE, + TransactionStepState.TIMEOUT, +]) + +const setStepTimeoutSkipStates = new Set([ + TransactionStepState.TIMEOUT, + TransactionStepState.DONE, + TransactionStepState.REVERTED, +]) + /** * @class TransactionOrchestrator is responsible for managing and executing distributed transactions. * It is based on a single transaction definition, which is used to execute all the transaction steps @@ -184,14 +211,6 @@ export class TransactionOrchestrator extends EventEmitter { } private canMoveForward(flow: TransactionFlow, previousStep: TransactionStep) { - const states = [ - TransactionStepState.DONE, - TransactionStepState.FAILED, - TransactionStepState.TIMEOUT, - TransactionStepState.SKIPPED, - TransactionStepState.SKIPPED_FAILURE, - ] - const siblings = TransactionOrchestrator.getPreviousStep( flow, previousStep @@ -199,23 +218,15 @@ export class TransactionOrchestrator extends EventEmitter { return ( !!previousStep.definition.noWait || - siblings.every((sib) => states.includes(sib.invoke.state)) + siblings.every((sib) => canMoveForwardStates.has(sib.invoke.state)) ) } private canMoveBackward(flow: TransactionFlow, step: TransactionStep) { - const states = [ - TransactionStepState.DONE, - TransactionStepState.REVERTED, - TransactionStepState.FAILED, - TransactionStepState.DORMANT, - TransactionStepState.SKIPPED, - ] - const siblings = step.next.map((sib) => flow.steps[sib]) return ( siblings.length === 0 || - siblings.every((sib) => states.includes(sib.compensate.state)) + siblings.every((sib) => canMoveBackwardStates.has(sib.compensate.state)) ) } @@ -521,9 +532,7 @@ export class TransactionOrchestrator extends EventEmitter { } if ( - [TransactionStepState.DONE, TransactionStepState.TIMEOUT].includes( - curState.state - ) || + flagStepsToRevertStates.has(curState.state) || curState.status === TransactionStepStatus.PERMANENT_FAILURE ) { stepDef.beginCompensation() @@ -707,13 +716,7 @@ export class TransactionOrchestrator extends EventEmitter { step: TransactionStep, error: TransactionStepTimeoutError | TransactionTimeoutError ): Promise { - if ( - [ - TransactionStepState.TIMEOUT, - TransactionStepState.DONE, - TransactionStepState.REVERTED, - ].includes(step.getStates().state) - ) { + if (setStepTimeoutSkipStates.has(step.getStates().state)) { return } @@ -773,13 +776,10 @@ export class TransactionOrchestrator extends EventEmitter { error = serializeError(error) } else { try { - if (error?.message) { - error = JSON.parse(JSON.stringify(error)) - } else { - error = { - message: JSON.stringify(error), - } - } + const serialized = JSON.stringify(error) + error = error?.message + ? JSON.parse(serialized) + : { message: serialized } } catch (e) { error = { message: "Unknown non-serializable error", @@ -953,6 +953,18 @@ export class TransactionOrchestrator extends EventEmitter { return shouldContinueExecution }) + let asyncStepCount = 0 + for (const s of nextSteps.next) { + const stepIsAsync = s.isCompensating() + ? s.definition.compensateAsync + : s.definition.async + if (stepIsAsync) asyncStepCount++ + } + const hasMultipleAsyncSteps = asyncStepCount > 1 + const hasAsyncSteps = !!asyncStepCount + + // If there is any async step, we don't need to save the checkpoint here as it will be saved + // later down there await transaction.saveCheckpoint().catch((error) => { if (TransactionOrchestrator.isExpectedError(error)) { continueExecution = false @@ -962,11 +974,15 @@ export class TransactionOrchestrator extends EventEmitter { throw error }) + if (!continueExecution) { + break + } + const execution: Promise[] = [] const executionAsync: (() => Promise)[] = [] let i = 0 - let hasAsyncSteps = false + for (const step of nextSteps.next) { const stepIndex = i++ if (!stepsShouldContinueExecution[stepIndex]) { @@ -988,21 +1004,9 @@ export class TransactionOrchestrator extends EventEmitter { // Compute current transaction state await this.computeCurrentTransactionState(transaction) - if (!continueExecution) { - break - } const promise = this.createStepExecutionPromise(transaction, step) - const hasMultipleAsyncSteps = - nextSteps.next.filter((step) => { - const isAsync = step.isCompensating() - ? step.definition.compensateAsync - : step.definition.async - - return isAsync - }).length > 1 - const hasVersionControl = hasMultipleAsyncSteps || step.hasAwaitingRetry() @@ -1017,7 +1021,6 @@ export class TransactionOrchestrator extends EventEmitter { ) } else { // Execute async step in background as part of the next event loop cycle and continue the execution of the transaction - hasAsyncSteps = true executionAsync.push(() => this.executeAsyncStep(promise, transaction, step, nextSteps) ) @@ -1105,12 +1108,9 @@ export class TransactionOrchestrator extends EventEmitter { private createStepPayload( transaction: DistributedTransactionType, step: TransactionStep, - flow: TransactionFlow + flow: TransactionFlow, + type: TransactionHandlerType ): TransactionPayload { - const type = step.isCompensating() - ? TransactionHandlerType.COMPENSATE - : TransactionHandlerType.INVOKE - return new TransactionPayload( { model_id: flow.modelId, @@ -1136,13 +1136,9 @@ export class TransactionOrchestrator extends EventEmitter { private prepareHandlerArgs( transaction: DistributedTransactionType, step: TransactionStep, - flow: TransactionFlow, - payload: TransactionPayload + payload: TransactionPayload, + type: TransactionHandlerType ): Parameters { - const type = step.isCompensating() - ? TransactionHandlerType.COMPENSATE - : TransactionHandlerType.INVOKE - return [ step.definition.action + "", type, @@ -1164,11 +1160,13 @@ export class TransactionOrchestrator extends EventEmitter { ? TransactionHandlerType.COMPENSATE : TransactionHandlerType.INVOKE + const flow = transaction.getFlow() + const payload = this.createStepPayload(transaction, step, flow, type) const handlerArgs = this.prepareHandlerArgs( transaction, step, - transaction.getFlow(), - this.createStepPayload(transaction, step, transaction.getFlow()) + payload, + type ) const traceData = { diff --git a/packages/core/utils/src/common/parse-stringify-if-necessary.ts b/packages/core/utils/src/common/parse-stringify-if-necessary.ts index 25888292f9..cc92da9093 100644 --- a/packages/core/utils/src/common/parse-stringify-if-necessary.ts +++ b/packages/core/utils/src/common/parse-stringify-if-necessary.ts @@ -1,18 +1,49 @@ import { isDefined } from "./is-defined" /** - * Only apply JSON.parse JSON.stringify when we have objects, arrays, dates, etc.. + * Creates a deep copy of the input, ensuring it's JSON-serializable. + * - Breaks all reference sharing (creates new objects/arrays) + * - Removes non-serializable values (functions, symbols, undefined properties) + * - Normalizes special types (Date -> string) + * - Only stringifies special objects, not entire tree (optimization) * @param result - * @returns + * @returns A deep copy with no shared references, guaranteed to be JSON-serializable */ -export function parseStringifyIfNecessary(result: unknown) { - if (typeof result == null || typeof result !== "object") { +export function parseStringifyIfNecessary(result: unknown): any { + if (result == null || typeof result !== "object") { return result } - const strResult = JSON.stringify(result) - if (isDefined(strResult)) { - return JSON.parse(strResult) + if (Array.isArray(result)) { + return result.map((item) => parseStringifyIfNecessary(item)) } - return result + + const isPlainObject = + result.constructor === Object || result.constructor === undefined + + if (!isPlainObject) { + const strResult = JSON.stringify(result) + if (isDefined(strResult)) { + return JSON.parse(strResult) + } + return undefined + } + + const copy: any = {} + for (const key in result) { + if (result.hasOwnProperty(key)) { + const value = (result as any)[key] + + if (typeof value === "function" || typeof value === "symbol") { + continue + } + + const copiedValue = parseStringifyIfNecessary(value) + + if (copiedValue !== undefined) { + copy[key] = copiedValue + } + } + } + return copy } diff --git a/packages/core/workflows-sdk/package.json b/packages/core/workflows-sdk/package.json index c3a5bd20c4..bb6806e14c 100644 --- a/packages/core/workflows-sdk/package.json +++ b/packages/core/workflows-sdk/package.json @@ -49,7 +49,7 @@ "scripts": { "build": "rimraf dist && tsc --build", "watch": "tsc --build --watch", - "test": "jest --bail --forceExit", + "test": "jest --bail --forceExit -- src/**/__tests__/**/*.spec.ts", "test:run": "node ./dist/utils/_playground.js" } } diff --git a/packages/core/workflows-sdk/src/utils/composer/__tests__/compose.ts b/packages/core/workflows-sdk/src/utils/composer/__tests__/compose.spec.ts similarity index 100% rename from packages/core/workflows-sdk/src/utils/composer/__tests__/compose.ts rename to packages/core/workflows-sdk/src/utils/composer/__tests__/compose.spec.ts diff --git a/packages/core/workflows-sdk/src/utils/composer/helpers/resolve-value.ts b/packages/core/workflows-sdk/src/utils/composer/helpers/resolve-value.ts index a7b2b950fd..f19658c4dd 100644 --- a/packages/core/workflows-sdk/src/utils/composer/helpers/resolve-value.ts +++ b/packages/core/workflows-sdk/src/utils/composer/helpers/resolve-value.ts @@ -1,5 +1,4 @@ import { - deepCopy, isObject, OrchestrationUtils, parseStringifyIfNecessary, @@ -10,10 +9,10 @@ import * as util from "node:util" type InputPrimitive = string | Symbol type InputObject = object & { __type?: string | Symbol; output?: any } -function resolveProperty(property, transactionContext) { +function resolveProperty(property: any, transactionContext: any) { const { invoke: invokeRes } = transactionContext - let res + let res: any if (property.__type === OrchestrationUtils.SymbolInputReference) { res = transactionContext.payload @@ -132,7 +131,7 @@ function unwrapInput({ if (result != null && typeof result === "object") { const unwrapped = unwrapInput({ inputTOUnwrap: result, - parentRef: parentRef[key] || {}, + parentRef: {}, transactionContext, }) if (unwrapped instanceof Promise) { @@ -161,7 +160,7 @@ function unwrapInput({ if (resolved != null && typeof resolved === "object") { const unwrapped = unwrapInput({ inputTOUnwrap: resolved, - parentRef: parentRef[key] || {}, + parentRef: {}, transactionContext, }) if (unwrapped instanceof Promise) { @@ -184,18 +183,17 @@ function unwrapInput({ export function resolveValue( input: InputPrimitive | InputObject | unknown | undefined, - transactionContext + transactionContext: any ): Promise | any { if (input == null || typeof input !== "object") { return input } - const input_ = deepCopy( + const input_ = (input as InputObject)?.__type === - OrchestrationUtils.SymbolWorkflowWorkflowData + OrchestrationUtils.SymbolWorkflowWorkflowData ? (input as InputObject).output : input - ) let result: any diff --git a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts index 6e862eab30..9300e2c3c4 100644 --- a/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts +++ b/packages/modules/workflow-engine-inmemory/integration-tests/__tests__/index.spec.ts @@ -1102,7 +1102,7 @@ moduleIntegrationTestRunner({ expect(executionsListAfter).toHaveLength(1) }) - it("should display error when multple async steps are running in parallel", async () => { + it("should display error when multiple async steps are running in parallel", async () => { let errors: Error[] = [] const onFinishPromise = new Promise((resolve) => { void workflowOrcModule.subscribe({ diff --git a/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts b/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts index 57f8b94afe..f9db3bd756 100644 --- a/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts +++ b/packages/modules/workflow-engine-inmemory/src/utils/workflow-orchestrator-storage.ts @@ -31,22 +31,25 @@ import { WorkflowExecution } from "../models/workflow-execution" const THIRTY_MINUTES_IN_MS = 1000 * 60 * 30 -const doneStates = [ +const doneStates = new Set([ TransactionStepState.DONE, TransactionStepState.REVERTED, TransactionStepState.FAILED, TransactionStepState.SKIPPED, TransactionStepState.SKIPPED_FAILURE, TransactionStepState.TIMEOUT, -] +]) -const finishedStates = [ +const finishedStates = new Set([ TransactionState.DONE, TransactionState.FAILED, TransactionState.REVERTED, -] +]) -const failedStates = [TransactionState.FAILED, TransactionState.REVERTED] +const failedStates = new Set([ + TransactionState.FAILED, + TransactionState.REVERTED, +]) function calculateDelayFromExpression(expression: CronExpression): number { const nextTime = expression.next().getTime() @@ -179,7 +182,7 @@ export class InMemoryDistributedTransactionStorage private async saveToDb(data: TransactionCheckpoint, retentionTime?: number) { const isNotStarted = data.flow.state === TransactionState.NOT_STARTED const asyncVersion = data.flow._v - const isFinished = finishedStates.includes(data.flow.state) + const isFinished = finishedStates.has(data.flow.state) const isWaitingToCompensate = data.flow.state === TransactionState.WAITING_TO_COMPENSATE @@ -187,16 +190,16 @@ export class InMemoryDistributedTransactionStorage const stepsArray = Object.values(data.flow.steps) as TransactionStep[] let currentStep!: TransactionStep + let currentStepsIsAsync = false const targetStates = isFlowInvoking - ? [ + ? new Set([ TransactionStepState.INVOKING, TransactionStepState.DONE, TransactionStepState.FAILED, - ] - : [TransactionStepState.COMPENSATING] + ]) + : new Set([TransactionStepState.COMPENSATING]) - // Find the current step from the end for (let i = stepsArray.length - 1; i >= 0; i--) { const step = stepsArray[i] @@ -204,20 +207,29 @@ export class InMemoryDistributedTransactionStorage break } - const isTargetState = targetStates.includes(step.invoke?.state) + const isTargetState = targetStates.has(step.invoke?.state) - if (isTargetState) { + if (isTargetState && !currentStep) { currentStep = step break } } - const currentStepsIsAsync = currentStep - ? stepsArray.some( - (step) => - step?.definition?.async === true && step.depth === currentStep.depth - ) - : false + if (currentStep) { + for (const step of stepsArray) { + if (step.id === "_root") { + continue + } + + if ( + step.depth === currentStep.depth && + step?.definition?.async === true + ) { + currentStepsIsAsync = true + break + } + } + } if ( !(isNotStarted || isFinished || isWaitingToCompensate) && @@ -284,7 +296,7 @@ export class InMemoryDistributedTransactionStorage const execution = trx.execution as TransactionFlow if (!idempotent) { - const isFailedOrReverted = failedStates.includes(execution.state) + const isFailedOrReverted = failedStates.has(execution.state) const isDone = execution.state === TransactionState.DONE @@ -331,21 +343,12 @@ export class InMemoryDistributedTransactionStorage */ const { retentionTime } = options ?? {} - const hasFinished = finishedStates.includes(data.flow.state) - - let cachedCheckpoint: TransactionCheckpoint | undefined - const getCheckpoint = async (options?: TransactionOptions) => { - if (!cachedCheckpoint) { - cachedCheckpoint = await this.get(key, options) - } - return cachedCheckpoint - } + const hasFinished = finishedStates.has(data.flow.state) await this.#preventRaceConditionExecutionIfNecessary({ data, key, options, - getCheckpoint, }) // Only store retention time if it's provided @@ -377,12 +380,12 @@ export class InMemoryDistributedTransactionStorage TransactionCheckpoint.mergeCheckpoints(data, storedData) } - const { flow, errors } = data + const { flow, context, errors } = data this.storage[key] = { - flow, - context: {} as TransactionContext, - errors, + flow: JSON.parse(JSON.stringify(flow)), + context: JSON.parse(JSON.stringify(context)), + errors: [...errors], } as TransactionCheckpoint // Optimize DB operations - only perform when necessary @@ -412,14 +415,10 @@ export class InMemoryDistributedTransactionStorage data, key, options, - getCheckpoint, }: { data: TransactionCheckpoint key: string options?: TransactionOptions - getCheckpoint: ( - options: TransactionOptions - ) => Promise }) { const isInitialCheckpoint = [TransactionState.NOT_STARTED].includes( data.flow.state @@ -441,7 +440,7 @@ export class InMemoryDistributedTransactionStorage } as Parameters[1] data_ = - (await getCheckpoint(getOptions as TransactionOptions)) ?? + (await this.get(key, getOptions as TransactionOptions)) ?? ({ flow: {} } as TransactionCheckpoint) } @@ -457,7 +456,7 @@ export class InMemoryDistributedTransactionStorage ? latestStep.compensate?.state : latestStep.invoke?.state - const shouldSkip = doneStates.includes(latestState) + const shouldSkip = doneStates.has(latestState) if (shouldSkip) { throw new SkipStepAlreadyFinishedError( @@ -750,7 +749,7 @@ export class InMemoryDistributedTransactionStorage }, updated_at: { $lte: raw( - (alias) => + (_alias) => `CURRENT_TIMESTAMP - (INTERVAL '1 second' * "retention_time")` ), }, diff --git a/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts b/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts index 6cf1295894..3e4584571c 100644 --- a/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts +++ b/packages/modules/workflow-engine-redis/src/utils/workflow-orchestrator-storage.ts @@ -37,22 +37,25 @@ enum JobType { const THIRTY_MINUTES_IN_MS = 1000 * 60 * 30 const REPEATABLE_CLEARER_JOB_ID = "clear-expired-executions" -const doneStates = [ +const doneStates = new Set([ TransactionStepState.DONE, TransactionStepState.REVERTED, TransactionStepState.FAILED, TransactionStepState.SKIPPED, TransactionStepState.SKIPPED_FAILURE, TransactionStepState.TIMEOUT, -] +]) -const finishedStates = [ +const finishedStates = new Set([ TransactionState.DONE, TransactionState.FAILED, TransactionState.REVERTED, -] +]) -const failedStates = [TransactionState.FAILED, TransactionState.REVERTED] +const failedStates = new Set([ + TransactionState.FAILED, + TransactionState.REVERTED, +]) export class RedisDistributedTransactionStorage implements IDistributedTransactionStorage, IDistributedSchedulerStorage { @@ -280,7 +283,7 @@ export class RedisDistributedTransactionStorage const isNotStarted = data.flow.state === TransactionState.NOT_STARTED const asyncVersion = data.flow._v - const isFinished = finishedStates.includes(data.flow.state) + const isFinished = finishedStates.has(data.flow.state) const isWaitingToCompensate = data.flow.state === TransactionState.WAITING_TO_COMPENSATE @@ -288,16 +291,16 @@ export class RedisDistributedTransactionStorage const stepsArray = Object.values(data.flow.steps) as TransactionStep[] let currentStep!: TransactionStep + let currentStepsIsAsync = false const targetStates = isFlowInvoking - ? [ + ? new Set([ TransactionStepState.INVOKING, TransactionStepState.DONE, TransactionStepState.FAILED, - ] - : [TransactionStepState.COMPENSATING] + ]) + : new Set([TransactionStepState.COMPENSATING]) - // Find the current step from the end for (let i = stepsArray.length - 1; i >= 0; i--) { const step = stepsArray[i] @@ -305,20 +308,29 @@ export class RedisDistributedTransactionStorage break } - const isTargetState = targetStates.includes(step.invoke?.state) + const isTargetState = targetStates.has(step.invoke?.state) - if (isTargetState) { + if (isTargetState && !currentStep) { currentStep = step break } } - const currentStepsIsAsync = currentStep - ? stepsArray.some( - (step) => - step?.definition?.async === true && step.depth === currentStep.depth - ) - : false + if (currentStep) { + for (const step of stepsArray) { + if (step.id === "_root") { + continue + } + + if ( + step.depth === currentStep.depth && + step?.definition?.async === true + ) { + currentStepsIsAsync = true + break + } + } + } if ( !(isNotStarted || isFinished || isWaitingToCompensate) && @@ -395,29 +407,36 @@ export class RedisDistributedTransactionStorage async get( key: string, - options?: TransactionOptions & { isCancelling?: boolean } + options?: TransactionOptions & { + isCancelling?: boolean + _cachedRawData?: string | null + } ): Promise { const [_, workflowId, transactionId] = key.split(":") - const trx = await this.workflowExecutionService_ - .list( - { - workflow_id: workflowId, - transaction_id: transactionId, - }, - { - select: ["execution", "context"], - order: { - id: "desc", + + const [trx, rawData] = await promiseAll([ + this.workflowExecutionService_ + .list( + { + workflow_id: workflowId, + transaction_id: transactionId, }, - take: 1, - } - ) - .then((trx) => trx[0]) - .catch(() => undefined) + { + select: ["execution", "context"], + order: { + id: "desc", + }, + take: 1, + } + ) + .then((trx) => trx[0]) + .catch(() => undefined), + options?._cachedRawData !== undefined + ? Promise.resolve(options._cachedRawData) + : this.redisClient.get(key), + ]) if (trx) { - const rawData = await this.redisClient.get(key) - let flow!: TransactionFlow, errors!: TransactionStepError[] if (rawData) { const data = JSON.parse(rawData) @@ -429,7 +448,7 @@ export class RedisDistributedTransactionStorage const execution = trx.execution as TransactionFlow if (!idempotent) { - const isFailedOrReverted = failedStates.includes(execution.state) + const isFailedOrReverted = failedStates.has(execution.state) const isDone = execution.state === TransactionState.DONE @@ -470,6 +489,8 @@ export class RedisDistributedTransactionStorage let lockAcquired = false + let storedData: TransactionCheckpoint | undefined + if (data.flow._v) { lockAcquired = await this.#acquireLock(key) @@ -477,7 +498,7 @@ export class RedisDistributedTransactionStorage throw new Error("Lock not acquired") } - const storedData = await this.get(key, { + storedData = await this.get(key, { isCancelling: !!data.flow.cancelledAt, } as any) @@ -485,21 +506,13 @@ export class RedisDistributedTransactionStorage } try { - const hasFinished = finishedStates.includes(data.flow.state) - - let cachedCheckpoint: TransactionCheckpoint | undefined - const getCheckpoint = async (options?: TransactionOptions) => { - if (!cachedCheckpoint) { - cachedCheckpoint = await this.get(key, options) - } - return cachedCheckpoint - } + const hasFinished = finishedStates.has(data.flow.state) await this.#preventRaceConditionExecutionIfNecessary({ data: data, key, options, - getCheckpoint, + storedData, }) // Only set if not exists @@ -514,11 +527,10 @@ export class RedisDistributedTransactionStorage } const execPipeline = () => { - const lightData_ = { + const stringifiedData = JSON.stringify({ errors: data.errors, flow: data.flow, - } - const stringifiedData = JSON.stringify(lightData_) + }) const pipeline = this.redisClient.pipeline() @@ -558,17 +570,15 @@ export class RedisDistributedTransactionStorage }) } + // Parallelize DB and Redis operations for better performance if (hasFinished && !retentionTime) { if (!data.flow.metadata?.parentStepIdempotencyKey) { - await this.deleteFromDb(data) - await execPipeline() + await promiseAll([this.deleteFromDb(data), execPipeline()]) } else { - await this.saveToDb(data, retentionTime) - await execPipeline() + await promiseAll([this.saveToDb(data, retentionTime), execPipeline()]) } } else { - await this.saveToDb(data, retentionTime) - await execPipeline() + await promiseAll([this.saveToDb(data, retentionTime), execPipeline()]) } return data as TransactionCheckpoint @@ -801,14 +811,12 @@ export class RedisDistributedTransactionStorage data, key, options, - getCheckpoint, + storedData, }: { data: TransactionCheckpoint key: string options?: TransactionOptions - getCheckpoint: ( - options: TransactionOptions - ) => Promise + storedData?: TransactionCheckpoint }) { const isInitialCheckpoint = [TransactionState.NOT_STARTED].includes( data.flow.state @@ -819,19 +827,24 @@ export class RedisDistributedTransactionStorage */ const currentFlow = data.flow - const rawData = await this.redisClient.get(key) - let data_ = {} as TransactionCheckpoint - if (rawData) { - data_ = JSON.parse(rawData) - } else { - const getOptions = { - ...options, - isCancelling: !!data.flow.cancelledAt, - } as Parameters[1] + let data_ = storedData ?? ({} as TransactionCheckpoint) - data_ = - (await getCheckpoint(getOptions as TransactionOptions)) ?? - ({ flow: {} } as TransactionCheckpoint) + if (!storedData) { + const rawData = await this.redisClient.get(key) + if (rawData) { + data_ = JSON.parse(rawData) + } else { + // Pass cached raw data to avoid redundant Redis fetch + const getOptions = { + ...options, + isCancelling: !!data.flow.cancelledAt, + _cachedRawData: rawData, + } as Parameters[1] + + data_ = + (await this.get(key, getOptions as TransactionOptions)) ?? + ({ flow: {} } as TransactionCheckpoint) + } } const { flow: latestUpdatedFlow } = data_ @@ -846,7 +859,7 @@ export class RedisDistributedTransactionStorage ? latestStep.compensate?.state : latestStep.invoke?.state - const shouldSkip = doneStates.includes(latestState) + const shouldSkip = doneStates.has(latestState) if (shouldSkip) { throw new SkipStepAlreadyFinishedError( @@ -891,7 +904,7 @@ export class RedisDistributedTransactionStorage }, updated_at: { $lte: raw( - (alias) => + (_alias) => `CURRENT_TIMESTAMP - (INTERVAL '1 second' * "retention_time")` ), },