chore(): Add retry strategy to database connection (#12713)

RESOLVES FRMW-2978

**What**
Add a retry mechanism to database connection management to prevent startup failures when the server starts before the database is ready to accept connections.
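
The retry behaviour is configurable through two new environment variables read by both the pg connection loader and the MikroORM connection (see the diffs below): `__MEDUSA_DB_CONNECTION_MAX_RETRIES` (default 5 attempts) and `__MEDUSA_DB_CONNECTION_RETRY_DELAY` (default 1000 ms). A minimal sketch of tuning them before boot (the variable names and defaults come from this diff; the bootstrap context itself is illustrative):

// Illustrative bootstrap tuning; the defaults in the diff are 5 retries and 1000 ms.
process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES = "10"
process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY = "2000"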
Author: Adrien de Peretti
Date: 2025-06-13 09:18:49 +02:00
Committed by: GitHub
Parent: 44d1d18689
Commit: cbf3644eb7
9 changed files with 333 additions and 45 deletions

View File

@@ -0,0 +1,8 @@
---
"@medusajs/medusa": patch
"@medusajs/test-utils": patch
"@medusajs/framework": patch
"@medusajs/utils": patch
---
chore(): Add retry strategy to database connection

View File

@@ -1,13 +1,19 @@
import { ContainerRegistrationKeys, ModulesSdkUtils } from "@medusajs/utils"
import {
ContainerRegistrationKeys,
ModulesSdkUtils,
retryExecution,
stringifyCircular,
} from "@medusajs/utils"
import { asValue } from "awilix"
import { container } from "../container"
import { configManager } from "../config"
import { container } from "../container"
import { logger } from "../logger"
/**
* Initialize a knex connection that can then be shared to any resources if needed
*/
export function pgConnectionLoader(): ReturnType<
typeof ModulesSdkUtils.createPgConnection
export async function pgConnectionLoader(): Promise<
ReturnType<typeof ModulesSdkUtils.createPgConnection>
> {
if (container.hasRegistration(ContainerRegistrationKeys.PG_CONNECTION)) {
return container.resolve(
@@ -45,6 +51,31 @@ export function pgConnectionLoader(): ReturnType<
},
})
const maxRetries = process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES
? parseInt(process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES)
: 5
const retryDelay = process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY
? parseInt(process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY)
: 1000
await retryExecution(
async () => {
await pgConnection.raw("SELECT 1")
},
{
maxRetries,
retryDelay,
onRetry: (error) => {
logger.warn(
`Pg connection failed to connect to the database. Retrying...\n${stringifyCircular(
error
)}`
)
},
}
)
container.register(
ContainerRegistrationKeys.PG_CONNECTION,
asValue(pgConnection)
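
Because pgConnectionLoader is now async and only returns once the `SELECT 1` probe above has succeeded, call sites have to await it (the test-utils and container diffs further down make exactly that change). A minimal call-site sketch, assuming the loader is imported from `@medusajs/framework` as in the test-utils diff:

// Hypothetical call site; the import path follows the test-utils diff below.
const { pgConnectionLoader } = await import("@medusajs/framework")
const pgConnection = await pgConnectionLoader()
// The connection has already been verified with retries, so queries can run immediately.
await pgConnection.raw("SELECT NOW()")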

View File

@@ -24,7 +24,6 @@ import {
isPresent,
upperCaseFirst,
} from "@medusajs/utils"
import { pgConnectionLoader } from "./database"
import type { Knex } from "@mikro-orm/knex"
import { aliasTo, asValue } from "awilix"
@@ -125,11 +124,8 @@ export class MedusaAppLoader {
const sharedResourcesConfig: ModuleServiceInitializeOptions = {
database: {
clientUrl:
(
injectedDependencies[
ContainerRegistrationKeys.PG_CONNECTION
] as ReturnType<typeof pgConnectionLoader>
)?.client?.config?.connection?.connectionString ??
injectedDependencies[ContainerRegistrationKeys.PG_CONNECTION]?.client
?.config?.connection?.connectionString ??
configManager.config.projectConfig.databaseUrl,
driverOptions: configManager.config.projectConfig.databaseDriverOptions,
pool: pool,

View File

@@ -0,0 +1,168 @@
import { retryExecution } from "../retry-execution"
import { setTimeout } from "timers/promises"
// Mock setTimeout to avoid waiting in tests
jest.mock("timers/promises", () => ({
setTimeout: jest.fn().mockResolvedValue(undefined),
}))
describe("retryExecution", () => {
beforeEach(() => {
jest.clearAllMocks()
})
it("should return the result of the function on the first try", async () => {
const fn = jest.fn().mockResolvedValue("success")
const result = await retryExecution(fn, {
maxRetries: 3,
retryDelay: 100,
shouldRetry: () => true,
})
expect(result).toBe("success")
expect(fn).toHaveBeenCalledTimes(1)
expect(setTimeout).not.toHaveBeenCalled()
})
it("should retry the function and succeed", async () => {
const fn = jest
.fn()
.mockRejectedValueOnce(new Error("failure"))
.mockRejectedValueOnce(new Error("failure"))
.mockResolvedValue("success")
const result = await retryExecution(fn, {
maxRetries: 3,
retryDelay: 100,
shouldRetry: () => true,
})
expect(result).toBe("success")
expect(fn).toHaveBeenCalledTimes(3)
expect(setTimeout).toHaveBeenCalledTimes(2)
expect(setTimeout).toHaveBeenCalledWith(100)
})
it("should throw an error after max retries", async () => {
const error = new Error("failure")
const fn = jest.fn().mockRejectedValue(error)
const maxRetries = 3
await expect(
retryExecution(fn, {
maxRetries,
retryDelay: 100,
shouldRetry: () => true,
})
).rejects.toThrow(error)
expect(fn).toHaveBeenCalledTimes(maxRetries)
expect(setTimeout).toHaveBeenCalledTimes(maxRetries - 1)
})
it("should not retry if shouldRetry returns false", async () => {
const error = new Error("non-retryable error")
const fn = jest.fn().mockRejectedValue(error)
const shouldRetry = jest.fn().mockReturnValue(false)
await expect(
retryExecution(fn, {
maxRetries: 3,
retryDelay: 100,
shouldRetry,
})
).rejects.toThrow(error)
expect(fn).toHaveBeenCalledTimes(1)
expect(shouldRetry).toHaveBeenCalledWith(error)
expect(setTimeout).not.toHaveBeenCalled()
})
it("should use default options if none are provided", async () => {
const error = new Error("failure")
const fn = jest.fn().mockRejectedValue(error)
await expect(retryExecution(fn)).rejects.toThrow(error)
// Default maxRetries is 5
expect(fn).toHaveBeenCalledTimes(5)
// Default retryDelay is 1000
expect(setTimeout).toHaveBeenCalledTimes(4)
expect(setTimeout).toHaveBeenCalledWith(1000)
})
it("should handle async functions correctly", async () => {
const asyncFn = jest.fn(async () => {
await new Promise((resolve) => setImmediate(resolve))
return "async success"
})
const result = await retryExecution(asyncFn, {
maxRetries: 3,
retryDelay: 100,
shouldRetry: () => true,
})
expect(result).toBe("async success")
expect(asyncFn).toHaveBeenCalledTimes(1)
})
it("should retry an async function and succeed", async () => {
const asyncFn = jest
.fn()
.mockRejectedValueOnce(new Error("failure"))
.mockResolvedValue("async success")
const result = await retryExecution(asyncFn, {
maxRetries: 3,
retryDelay: 100,
shouldRetry: () => true,
})
expect(result).toBe("async success")
expect(asyncFn).toHaveBeenCalledTimes(2)
expect(setTimeout).toHaveBeenCalledTimes(1)
})
it("should use a function for retryDelay if provided", async () => {
const fn = jest
.fn()
.mockRejectedValueOnce(new Error("failure"))
.mockRejectedValueOnce(new Error("failure"))
.mockResolvedValue("success")
const retryDelayFn = jest.fn((retries, maxRetries) => {
return retries * 50
})
const result = await retryExecution(fn, {
maxRetries: 3,
retryDelay: retryDelayFn,
shouldRetry: () => true,
})
expect(result).toBe("success")
expect(fn).toHaveBeenCalledTimes(3)
expect(retryDelayFn).toHaveBeenCalledTimes(2)
expect(retryDelayFn).toHaveBeenCalledWith(1, 3)
expect(retryDelayFn).toHaveBeenCalledWith(2, 3)
expect(setTimeout).toHaveBeenCalledTimes(2)
expect(setTimeout).toHaveBeenNthCalledWith(1, 50)
expect(setTimeout).toHaveBeenNthCalledWith(2, 100)
})
it("should throw the final error if maxRetries is 0", async () => {
const fn = jest.fn().mockResolvedValue("success")
const maxRetries = 0
await expect(
retryExecution(fn, {
maxRetries,
retryDelay: 100,
shouldRetry: () => true,
})
).rejects.toThrow("Retry execution failed. Max retries reached.")
expect(fn).not.toHaveBeenCalled()
})
})

View File

@@ -89,3 +89,4 @@ export * from "./validate-handle"
export * from "./validate-module-name"
export * from "./wrap-handler"
export * from "./normalize-csv-value"
export * from "./retry-execution"

View File

@@ -0,0 +1,61 @@
import { setTimeout } from "timers/promises"
const ONE_SECOND = 1000
/**
* Retry the function to be executed until it succeeds or the max retries is reached.
*
* @param fn - The function to be executed.
* @param options - The options for the retry execution.
* @param options.shouldRetry - The function to determine if the function should be retried based on the error argument.
* @param options.maxRetries - The maximum number of retries.
* @param options.retryDelay - The delay between retries. If a function is provided, it will be called with the current retry count and the maximum number of retries and should return the delay in milliseconds.
* @param options.onRetry - The function to be called when the function fails to execute.
* @returns The result of the function.
*/
export async function retryExecution<T>(
fn: () => Promise<T>,
options: {
shouldRetry?: (error: any) => boolean
onRetry?: (error: any) => void
maxRetries?: number
retryDelay?: number | ((retries: number, maxRetries: number) => number)
} = {
shouldRetry: () => true,
onRetry: () => {},
maxRetries: 5,
retryDelay: ONE_SECOND,
}
): Promise<T> {
let { shouldRetry, onRetry, maxRetries, retryDelay } = options
shouldRetry = shouldRetry ?? (() => true)
maxRetries = maxRetries ?? 5
onRetry = onRetry ?? (() => {})
const retryDelayFn =
typeof retryDelay === "function"
? retryDelay
: (retries: number, maxRetries: number) => retryDelay
let retries = 0
while (retries < maxRetries) {
try {
return await fn()
} catch (error) {
if (!shouldRetry(error as Error)) {
throw error
}
retries++
if (retries === maxRetries) {
throw error
}
onRetry(error)
await setTimeout(retryDelayFn(retries, maxRetries))
}
}
// This should never be reached
throw new Error("Retry execution failed. Max retries reached.")
}
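
A usage sketch for the helper above, exercising the function form of `retryDelay` described in the JSDoc. The flaky operation and the error filter are hypothetical; the import path matches the framework loader diff above:

import { retryExecution } from "@medusajs/utils"

// Hypothetical transient operation: fails twice, then succeeds.
let attempts = 0
async function flakyPing(): Promise<string> {
  attempts++
  if (attempts < 3) {
    throw new Error("connection refused")
  }
  return "ok"
}

const result = await retryExecution(flakyPing, {
  maxRetries: 5,
  // Delay grows with each failed attempt: 1000 ms, then 2000 ms, ...
  retryDelay: (retries) => retries * 1000,
  // Hypothetical filter: only retry plain Error instances.
  shouldRetry: (error) => error instanceof Error,
  onRetry: (error) => console.warn("attempt failed, retrying", error),
})
// result === "ok" after two retries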

View File

@@ -1,7 +1,7 @@
import { ModuleServiceInitializeOptions } from "@medusajs/types"
import { Filter as MikroORMFilter } from "@mikro-orm/core"
import { TSMigrationGenerator } from "@mikro-orm/migrations"
import { isString } from "../../common"
import { isString, retryExecution, stringifyCircular } from "../../common"
import { normalizeMigrationSQL } from "../utils"
type FilterDef = Parameters<typeof MikroORMFilter>[0]
@@ -83,38 +83,61 @@ export async function mikroOrmCreateConnection(
}
const { MikroORM, defineConfig } = await import("@mikro-orm/postgresql")
return await MikroORM.init(
defineConfig({
discovery: { disableDynamicFileAccess: true, warnWhenNoEntities: false },
entities,
debug: database.debug ?? process.env.NODE_ENV?.startsWith("dev") ?? false,
baseDir: process.cwd(),
clientUrl,
schema,
driverOptions,
tsNode: process.env.APP_ENV === "development",
filters: database.filters ?? {},
assign: {
convertCustomTypes: true,
const mikroOrmConfig = defineConfig({
discovery: { disableDynamicFileAccess: true, warnWhenNoEntities: false },
entities,
debug: database.debug ?? process.env.NODE_ENV?.startsWith("dev") ?? false,
baseDir: process.cwd(),
clientUrl,
schema,
driverOptions,
tsNode: process.env.APP_ENV === "development",
filters: database.filters ?? {},
assign: {
convertCustomTypes: true,
},
migrations: {
disableForeignKeys: false,
path: pathToMigrations,
snapshotName: database.snapshotName,
generator: CustomTsMigrationGenerator,
silent: !(
database.debug ??
process.env.NODE_ENV?.startsWith("dev") ??
false
),
},
schemaGenerator: {
disableForeignKeys: false,
},
pool: {
min: 2,
...database.pool,
},
})
const maxRetries = process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES
? parseInt(process.env.__MEDUSA_DB_CONNECTION_MAX_RETRIES)
: 5
const retryDelay = process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY
? parseInt(process.env.__MEDUSA_DB_CONNECTION_RETRY_DELAY)
: 1000
return await retryExecution(
async () => {
return await MikroORM.init(mikroOrmConfig)
},
{
maxRetries,
retryDelay,
onRetry: (error) => {
console.warn(
`MikroORM failed to connect to the database. Retrying...\n${stringifyCircular(
error
)}`
)
},
migrations: {
disableForeignKeys: false,
path: pathToMigrations,
snapshotName: database.snapshotName,
generator: CustomTsMigrationGenerator,
silent: !(
database.debug ??
process.env.NODE_ENV?.startsWith("dev") ??
false
),
},
schemaGenerator: {
disableForeignKeys: false,
},
pool: {
min: 2,
...database.pool,
},
})
}
)
}

View File

@@ -15,7 +15,7 @@ export async function initDb() {
"@medusajs/framework"
)
const pgConnection = pgConnectionLoader()
const pgConnection = await pgConnectionLoader()
await featureFlagsLoader()
return pgConnection

View File

@@ -129,7 +129,7 @@ export async function initializeContainer(
[ContainerRegistrationKeys.REMOTE_QUERY]: asValue(null),
})
pgConnectionLoader()
await pgConnectionLoader()
return container
}