diff --git a/.changeset/cold-cats-cheat.md b/.changeset/cold-cats-cheat.md new file mode 100644 index 0000000000..6367cfc500 --- /dev/null +++ b/.changeset/cold-cats-cheat.md @@ -0,0 +1,8 @@ +--- +"@medusajs/medusa": patch +"@medusajs/test-utils": patch +"@medusajs/framework": patch +"@medusajs/utils": patch +--- + +chore(): Add retry strategy to database connection diff --git a/packages/core/framework/src/database/pg-connection-loader.ts b/packages/core/framework/src/database/pg-connection-loader.ts index e4d2b1e1c2..33d2279903 100644 --- a/packages/core/framework/src/database/pg-connection-loader.ts +++ b/packages/core/framework/src/database/pg-connection-loader.ts @@ -1,13 +1,19 @@ -import { ContainerRegistrationKeys, ModulesSdkUtils } from "@medusajs/utils" +import { + ContainerRegistrationKeys, + ModulesSdkUtils, + retryExecution, + stringifyCircular, +} from "@medusajs/utils" import { asValue } from "awilix" -import { container } from "../container" import { configManager } from "../config" +import { container } from "../container" +import { logger } from "../logger" /** * Initialize a knex connection that can then be shared to any resources if needed */ -export function pgConnectionLoader(): ReturnType< - typeof ModulesSdkUtils.createPgConnection +export async function pgConnectionLoader(): Promise< + ReturnType > { if (container.hasRegistration(ContainerRegistrationKeys.PG_CONNECTION)) { return container.resolve( @@ -45,6 +51,31 @@ export function pgConnectionLoader(): ReturnType< }, }) + const maxRetries = process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES + ? parseInt(process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES) + : 5 + + const retryDelay = process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY + ? 
parseInt(process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY) + : 1000 + + await retryExecution( + async () => { + await pgConnection.raw("SELECT 1") + }, + { + maxRetries, + retryDelay, + onRetry: (error) => { + logger.warn( + `Pg connection failed to connect to the database. Retrying...\n${stringifyCircular( + error + )}` + ) + }, + } + ) + container.register( ContainerRegistrationKeys.PG_CONNECTION, asValue(pgConnection) diff --git a/packages/core/framework/src/medusa-app-loader.ts b/packages/core/framework/src/medusa-app-loader.ts index e7cfdca6c3..f074f01a1d 100644 --- a/packages/core/framework/src/medusa-app-loader.ts +++ b/packages/core/framework/src/medusa-app-loader.ts @@ -24,7 +24,6 @@ import { isPresent, upperCaseFirst, } from "@medusajs/utils" -import { pgConnectionLoader } from "./database" import type { Knex } from "@mikro-orm/knex" import { aliasTo, asValue } from "awilix" @@ -125,11 +124,8 @@ export class MedusaAppLoader { const sharedResourcesConfig: ModuleServiceInitializeOptions = { database: { clientUrl: - ( - injectedDependencies[ - ContainerRegistrationKeys.PG_CONNECTION - ] as ReturnType - )?.client?.config?.connection?.connectionString ?? + injectedDependencies[ContainerRegistrationKeys.PG_CONNECTION]?.client + ?.config?.connection?.connectionString ?? 
configManager.config.projectConfig.databaseUrl, driverOptions: configManager.config.projectConfig.databaseDriverOptions, pool: pool, diff --git a/packages/core/utils/src/common/__tests__/retry-execution.spec.ts b/packages/core/utils/src/common/__tests__/retry-execution.spec.ts new file mode 100644 index 0000000000..65b30c6a61 --- /dev/null +++ b/packages/core/utils/src/common/__tests__/retry-execution.spec.ts @@ -0,0 +1,168 @@ +import { retryExecution } from "../retry-execution" +import { setTimeout } from "timers/promises" + +// Mock setTimeout to avoid waiting in tests +jest.mock("timers/promises", () => ({ + setTimeout: jest.fn().mockResolvedValue(undefined), +})) + +describe("retryExecution", () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it("should return the result of the function on the first try", async () => { + const fn = jest.fn().mockResolvedValue("success") + const result = await retryExecution(fn, { + maxRetries: 3, + retryDelay: 100, + shouldRetry: () => true, + }) + + expect(result).toBe("success") + expect(fn).toHaveBeenCalledTimes(1) + expect(setTimeout).not.toHaveBeenCalled() + }) + + it("should retry the function and succeed", async () => { + const fn = jest + .fn() + .mockRejectedValueOnce(new Error("failure")) + .mockRejectedValueOnce(new Error("failure")) + .mockResolvedValue("success") + + const result = await retryExecution(fn, { + maxRetries: 3, + retryDelay: 100, + shouldRetry: () => true, + }) + + expect(result).toBe("success") + expect(fn).toHaveBeenCalledTimes(3) + expect(setTimeout).toHaveBeenCalledTimes(2) + expect(setTimeout).toHaveBeenCalledWith(100) + }) + + it("should throw an error after max retries", async () => { + const error = new Error("failure") + const fn = jest.fn().mockRejectedValue(error) + const maxRetries = 3 + + await expect( + retryExecution(fn, { + maxRetries, + retryDelay: 100, + shouldRetry: () => true, + }) + ).rejects.toThrow(error) + + expect(fn).toHaveBeenCalledTimes(maxRetries) + 
expect(setTimeout).toHaveBeenCalledTimes(maxRetries - 1) + }) + + it("should not retry if shouldRetry returns false", async () => { + const error = new Error("non-retryable error") + const fn = jest.fn().mockRejectedValue(error) + const shouldRetry = jest.fn().mockReturnValue(false) + + await expect( + retryExecution(fn, { + maxRetries: 3, + retryDelay: 100, + shouldRetry, + }) + ).rejects.toThrow(error) + + expect(fn).toHaveBeenCalledTimes(1) + expect(shouldRetry).toHaveBeenCalledWith(error) + expect(setTimeout).not.toHaveBeenCalled() + }) + + it("should use default options if none are provided", async () => { + const error = new Error("failure") + const fn = jest.fn().mockRejectedValue(error) + + await expect(retryExecution(fn)).rejects.toThrow(error) + + // Default maxRetries is 5 + expect(fn).toHaveBeenCalledTimes(5) + // Default retryDelay is 1000 + expect(setTimeout).toHaveBeenCalledTimes(4) + expect(setTimeout).toHaveBeenCalledWith(1000) + }) + + it("should handle async functions correctly", async () => { + const asyncFn = jest.fn(async () => { + await new Promise((resolve) => setImmediate(resolve)) + return "async success" + }) + + const result = await retryExecution(asyncFn, { + maxRetries: 3, + retryDelay: 100, + shouldRetry: () => true, + }) + + expect(result).toBe("async success") + expect(asyncFn).toHaveBeenCalledTimes(1) + }) + + it("should retry an async function and succeed", async () => { + const asyncFn = jest + .fn() + .mockRejectedValueOnce(new Error("failure")) + .mockResolvedValue("async success") + + const result = await retryExecution(asyncFn, { + maxRetries: 3, + retryDelay: 100, + shouldRetry: () => true, + }) + + expect(result).toBe("async success") + expect(asyncFn).toHaveBeenCalledTimes(2) + expect(setTimeout).toHaveBeenCalledTimes(1) + }) + + it("should use a function for retryDelay if provided", async () => { + const fn = jest + .fn() + .mockRejectedValueOnce(new Error("failure")) + .mockRejectedValueOnce(new Error("failure")) + 
const DEFAULT_MAX_RETRIES = 5
const DEFAULT_RETRY_DELAY_MS = 1000

/**
 * Execute `fn` repeatedly until it resolves, a non-retryable error is thrown,
 * or the maximum number of attempts has been made.
 *
 * `maxRetries` is the total number of attempts, including the first one: with
 * `maxRetries: 3`, `fn` is called at most three times and the delay is awaited
 * at most twice.
 *
 * @param fn - The function to be executed.
 * @param options - The options for the retry execution.
 * @param options.shouldRetry - Decides, from the thrown error, whether another
 * attempt should be made. Defaults to always retrying.
 * @param options.maxRetries - The maximum number of attempts. Defaults to 5
 * when omitted or `NaN`.
 * @param options.retryDelay - The delay between attempts in milliseconds. If a
 * function is provided, it is called with the number of failed attempts so far
 * and the maximum number of attempts and must return the delay in
 * milliseconds. Defaults to 1000 when omitted or `NaN`.
 * @param options.onRetry - Called with the error after a failed attempt, right
 * before waiting for the next one. Not called for the final failure.
 * @returns The value `fn` resolved with.
 * @throws The last error thrown by `fn` when `shouldRetry` returns false or
 * the attempts are exhausted; a generic `Error` when `maxRetries` does not
 * allow a single attempt (e.g. `0`).
 */
export async function retryExecution<T>(
  fn: () => Promise<T>,
  options: {
    shouldRetry?: (error: any) => boolean
    onRetry?: (error: any) => void
    maxRetries?: number
    retryDelay?: number | ((retries: number, maxRetries: number) => number)
  } = {}
): Promise<T> {
  // Defaults are resolved per option (not through the parameter initializer)
  // so that a partially specified options object still gets every default.
  const shouldRetry = options.shouldRetry ?? (() => true)
  const onRetry = options.onRetry ?? (() => {})

  // NaN typically comes from parseInt() on a malformed environment variable;
  // fall back to the default instead of silently skipping every attempt.
  const maxRetries =
    options.maxRetries === undefined || Number.isNaN(options.maxRetries)
      ? DEFAULT_MAX_RETRIES
      : options.maxRetries

  const configuredDelay = options.retryDelay
  const fixedDelay =
    typeof configuredDelay === "number" && !Number.isNaN(configuredDelay)
      ? configuredDelay
      : DEFAULT_RETRY_DELAY_MS
  const retryDelayFn =
    typeof configuredDelay === "function" ? configuredDelay : () => fixedDelay

  let retries = 0
  while (retries < maxRetries) {
    try {
      return await fn()
    } catch (error) {
      if (!shouldRetry(error)) {
        throw error
      }

      retries++
      if (retries === maxRetries) {
        // Attempts exhausted: surface the last error untouched.
        throw error
      }

      onRetry(error)

      await setTimeout(retryDelayFn(retries, maxRetries))
    }
  }

  // Only reached when maxRetries does not allow a single attempt (<= 0).
  throw new Error("Retry execution failed. Max retries reached.")
}
+ false + ), + }, + schemaGenerator: { + disableForeignKeys: false, + }, + pool: { + min: 2, + ...database.pool, + }, + }) + + const maxRetries = process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES + ? parseInt(process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES) + : 5 + + const retryDelay = process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY + ? parseInt(process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY) + : 1000 + + return await retryExecution( + async () => { + return await MikroORM.init(mikroOrmConfig) + }, + { + maxRetries, + retryDelay, + onRetry: (error) => { + console.warn( + `MikroORM failed to connect to the database. Retrying...\n${stringifyCircular( + error + )}` + ) }, - migrations: { - disableForeignKeys: false, - path: pathToMigrations, - snapshotName: database.snapshotName, - generator: CustomTsMigrationGenerator, - silent: !( - database.debug ?? - process.env.NODE_ENV?.startsWith("dev") ?? - false - ), - }, - schemaGenerator: { - disableForeignKeys: false, - }, - pool: { - min: 2, - ...database.pool, - }, - }) + } ) } diff --git a/packages/medusa-test-utils/src/medusa-test-runner-utils/use-db.ts b/packages/medusa-test-utils/src/medusa-test-runner-utils/use-db.ts index 8204d840d4..57704e88f6 100644 --- a/packages/medusa-test-utils/src/medusa-test-runner-utils/use-db.ts +++ b/packages/medusa-test-utils/src/medusa-test-runner-utils/use-db.ts @@ -15,7 +15,7 @@ export async function initDb() { "@medusajs/framework" ) - const pgConnection = pgConnectionLoader() + const pgConnection = await pgConnectionLoader() await featureFlagsLoader() return pgConnection diff --git a/packages/medusa/src/loaders/index.ts b/packages/medusa/src/loaders/index.ts index 9ad7390271..3ca4368ce9 100644 --- a/packages/medusa/src/loaders/index.ts +++ b/packages/medusa/src/loaders/index.ts @@ -129,7 +129,7 @@ export async function initializeContainer( [ContainerRegistrationKeys.REMOTE_QUERY]: asValue(null), }) - pgConnectionLoader() + await pgConnectionLoader() return container }