Feat(): distributed caching (#13435)

RESOLVES CORE-1153

**What**
- This pr mainly lay the foundation the caching layer. It comes with a modules (built in memory cache) and a redis provider.
- Apply caching to few touch point to test

Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com>
This commit is contained in:
Adrien de Peretti
2025-09-30 18:19:06 +02:00
committed by GitHub
parent 5b135a41fe
commit b9d6f73320
117 changed files with 5741 additions and 530 deletions
+16
View File
@@ -0,0 +1,16 @@
---
"@medusajs/medusa": patch
"@medusajs/framework": patch
"@medusajs/modules-sdk": patch
"@medusajs/types": patch
"@medusajs/utils": patch
"@medusajs/caching": patch
"@medusajs/event-bus-local": patch
"@medusajs/event-bus-redis": patch
"@medusajs/caching-redis": patch
"@medusajs/core-flows": patch
"@medusajs/pricing": patch
"@medusajs/draft-order": patch
---
Feat(): distributed caching
+22 -21
View File
@@ -13,31 +13,32 @@ packages/*
!packages/admin-next/dashboard
!packages/medusa-payment-stripe
!packages/medusa-payment-paypal
!packages/event-bus-redis
!packages/event-bus-local
!packages/modules/event-bus-redis
!packages/modules/event-bus-local
!packages/medusa-plugin-meilisearch
!packages/medusa-plugin-algolia
!packages/inventory
!packages/stock-location
!packages/cache-redis
!packages/cache-inmemory
!packages/modules/inventory
!packages/modules/stock-location
!packages/modules/cache-redis
!packages/modules/cache-inmemory
!packages/modules/caching
!packages/modules/providers/caching-redis
!packages/create-medusa-app
!packages/product
!packages/locking
!packages/orchestration
!packages/workflows-sdk
!packages/core-flows
!packages/types
!packages/modules/product
!packages/modules/locking
!packages/core/orchestration
!packages/core/workflows-sdk
!packages/core/core-flows
!packages/core/types
!packages/medusa-react
!packages/workflow-engine-redis
!packages/workflow-engine-inmemory
!packages/fulfillment
!packages/fulfillment-manual
!packages/locking-postgres
!packages/locking-redis
!packages/index
!packages/framework
!packages/modules/workflow-engine-redis
!packages/modules/workflow-engine-inmemory
!packages/modules/fulfillment
!packages/modules/providers/fulfillment-manual
!packages/modules/providers/locking-postgres
!packages/modules/providers/locking-redis
!packages/modules/index
!packages/core/framework
**/models/*
+2
View File
@@ -109,6 +109,7 @@ module.exports = {
"./packages/modules/event-bus-redis/tsconfig.spec.json",
"./packages/modules/cache-redis/tsconfig.spec.json",
"./packages/modules/cache-inmemory/tsconfig.spec.json",
"./packages/modules/caching/tsconfig.spec.json",
"./packages/modules/workflow-engine-redis/tsconfig.spec.json",
"./packages/modules/workflow-engine-inmemory/tsconfig.spec.json",
"./packages/modules/fulfillment/tsconfig.spec.json",
@@ -141,6 +142,7 @@ module.exports = {
"./packages/modules/providers/payment-stripe/tsconfig.spec.json",
"./packages/modules/providers/locking-postgres/tsconfig.spec.json",
"./packages/modules/providers/locking-redis/tsconfig.spec.json",
"./packages/modules/providers/caching-redis/tsconfig.spec.json",
"./packages/framework/tsconfig.json",
],
+1
View File
@@ -13,6 +13,7 @@
"@medusajs/api-key": "workspace:^",
"@medusajs/auth": "workspace:*",
"@medusajs/cache-inmemory": "workspace:*",
"@medusajs/caching": "workspace:*",
"@medusajs/core-flows": "workspace:^",
"@medusajs/currency": "workspace:^",
"@medusajs/customer": "workspace:^",
@@ -1,8 +1,9 @@
import {
IRegionModuleService,
IStoreModuleService,
MedusaContainer,
} from "@medusajs/framework/types"
import { MedusaError, Modules } from "@medusajs/framework/utils"
import { MedusaError, Modules, useCache } from "@medusajs/framework/utils"
import { StepResponse, createStep } from "@medusajs/framework/workflows-sdk"
/**
@@ -15,6 +16,48 @@ export type FindOneOrAnyRegionStepInput = {
regionId?: string
}
async function fetchRegionById(regionId: string, container: MedusaContainer) {
const service = container.resolve<IRegionModuleService>(Modules.REGION)
const args = [
regionId,
{
relations: ["countries"],
},
] as Parameters<IRegionModuleService["retrieveRegion"]>
return await useCache(async () => service.retrieveRegion(...args), {
container,
key: args,
})
}
async function fetchDefaultStore(container: MedusaContainer) {
const storeModule = container.resolve<IStoreModuleService>(Modules.STORE)
return await useCache(async () => storeModule.listStores(), {
container,
key: "find-one-or-any-region-default-store",
})
}
async function fetchDefaultRegion(
defaultRegionId: string,
container: MedusaContainer
) {
const service = container.resolve<IRegionModuleService>(Modules.REGION)
const args = [
{ id: defaultRegionId },
{ relations: ["countries"] },
] as Parameters<IRegionModuleService["listRegions"]>
return await useCache(async () => service.listRegions(...args), {
container,
key: args,
})
}
export const findOneOrAnyRegionStepId = "find-one-or-any-region"
/**
* This step retrieves a region either by the provided ID or the first region in the first store.
@@ -22,32 +65,24 @@ export const findOneOrAnyRegionStepId = "find-one-or-any-region"
export const findOneOrAnyRegionStep = createStep(
findOneOrAnyRegionStepId,
async (data: FindOneOrAnyRegionStepInput, { container }) => {
const service = container.resolve<IRegionModuleService>(Modules.REGION)
const storeModule = container.resolve<IStoreModuleService>(Modules.STORE)
if (data.regionId) {
try {
const region = await service.retrieveRegion(data.regionId, {
relations: ["countries"],
})
const region = await fetchRegionById(data.regionId, container)
return new StepResponse(region)
} catch (error) {
return new StepResponse(null)
}
}
const [store] = await storeModule.listStores()
const [store] = await fetchDefaultStore(container)
if (!store) {
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Store not found")
}
const [region] = await service.listRegions(
{
id: store.default_region_id,
},
{ relations: ["countries"] }
const [region] = await fetchDefaultRegion(
store.default_region_id!,
container
)
if (!region) {
@@ -1,8 +1,14 @@
import type {
CustomerDTO,
ICustomerModuleService,
MedusaContainer,
} from "@medusajs/framework/types"
import { isDefined, Modules, validateEmail } from "@medusajs/framework/utils"
import {
isDefined,
Modules,
useCache,
validateEmail,
} from "@medusajs/framework/utils"
import { createStep, StepResponse } from "@medusajs/framework/workflows-sdk"
/**
@@ -39,6 +45,40 @@ interface StepCompensateInput {
customerWasCreated: boolean
}
async function fetchCustomerById(
customerId: string,
container: MedusaContainer
): Promise<CustomerDTO> {
const service = container.resolve<ICustomerModuleService>(Modules.CUSTOMER)
return await useCache<CustomerDTO>(
async () => service.retrieveCustomer(customerId),
{
container,
key: ["find-or-create-customer-by-id", customerId],
}
)
}
async function fetchCustomersByEmail(
email: string,
container: MedusaContainer,
hasAccount?: boolean
): Promise<CustomerDTO[]> {
const service = container.resolve<ICustomerModuleService>(Modules.CUSTOMER)
const filters =
hasAccount !== undefined ? { email, has_account: hasAccount } : { email }
return await useCache<CustomerDTO[]>(
async () => service.listCustomers(filters),
{
container,
key: ["find-or-create-customer-by-email", filters],
}
)
}
export const findOrCreateCustomerStepId = "find-or-create-customer"
/**
* This step finds or creates a customer based on the provided ID or email. It prioritizes finding the customer by ID, then by email.
@@ -75,7 +115,7 @@ export const findOrCreateCustomerStep = createStep(
let customerWasCreated = false
if (data.customerId) {
originalCustomer = await service.retrieveCustomer(data.customerId)
originalCustomer = await fetchCustomerById(data.customerId, container)
customerData.customer = originalCustomer
customerData.email = originalCustomer.email
}
@@ -85,9 +125,7 @@ export const findOrCreateCustomerStep = createStep(
let [customer] = originalCustomer
? [originalCustomer]
: await service.listCustomers({
email: validatedEmail,
})
: await fetchCustomersByEmail(validatedEmail, container)
// if NOT a guest customer, return it
if (customer?.has_account) {
@@ -100,10 +138,11 @@ export const findOrCreateCustomerStep = createStep(
}
if (customer && customer.email !== validatedEmail) {
;[customer] = await service.listCustomers({
email: validatedEmail,
has_account: false,
})
;[customer] = await fetchCustomersByEmail(
validatedEmail,
container,
false
)
}
if (!customer) {
@@ -1,9 +1,15 @@
import {
ISalesChannelModuleService,
IStoreModuleService,
MedusaContainer,
SalesChannelDTO,
} from "@medusajs/framework/types"
import { MedusaError, Modules, isDefined } from "@medusajs/framework/utils"
import {
MedusaError,
Modules,
isDefined,
useCache,
} from "@medusajs/framework/utils"
import { StepResponse, createStep } from "@medusajs/framework/workflows-sdk"
/**
@@ -16,6 +22,34 @@ export interface FindSalesChannelStepInput {
salesChannelId?: string | null
}
async function fetchSalesChannel(
salesChannelId: string,
container: MedusaContainer
) {
const salesChannelService = container.resolve<ISalesChannelModuleService>(
Modules.SALES_CHANNEL
)
return await useCache<
Awaited<ReturnType<typeof salesChannelService.retrieveSalesChannel>>
>(async () => salesChannelService.retrieveSalesChannel(salesChannelId), {
container,
key: ["find-sales-channel", salesChannelId],
})
}
async function fetchStore(container: MedusaContainer) {
const storeModule = container.resolve<IStoreModuleService>(Modules.STORE)
return await useCache<Awaited<ReturnType<typeof storeModule.listStores>>>(
async () =>
storeModule.listStores(
{},
{ select: ["id", "default_sales_channel_id"] }
),
{ key: "find-sales-channel-default-store", container }
)
}
export const findSalesChannelStepId = "find-sales-channel"
/**
* This step either retrieves a sales channel either using the ID provided as an input, or, if no ID
@@ -24,26 +58,17 @@ export const findSalesChannelStepId = "find-sales-channel"
export const findSalesChannelStep = createStep(
findSalesChannelStepId,
async (data: FindSalesChannelStepInput, { container }) => {
const salesChannelService = container.resolve<ISalesChannelModuleService>(
Modules.SALES_CHANNEL
)
const storeModule = container.resolve<IStoreModuleService>(Modules.STORE)
let salesChannel: SalesChannelDTO | undefined
if (data.salesChannelId) {
salesChannel = await salesChannelService.retrieveSalesChannel(
data.salesChannelId
)
salesChannel = await fetchSalesChannel(data.salesChannelId, container)
} else if (!isDefined(data.salesChannelId)) {
const [store] = await storeModule.listStores(
{},
{ select: ["default_sales_channel_id"] }
)
const [store] = await fetchStore(container)
if (store?.default_sales_channel_id) {
salesChannel = await salesChannelService.retrieveSalesChannel(
store.default_sales_channel_id
salesChannel = await fetchSalesChannel(
store.default_sales_channel_id,
container
)
}
}
@@ -1,7 +1,6 @@
import type { IPromotionModuleService } from "@medusajs/framework/types"
import {
ContainerRegistrationKeys,
MedusaError,
Modules,
PromotionActions,
} from "@medusajs/framework/utils"
import { createStep, StepResponse } from "@medusajs/framework/workflows-sdk"
@@ -72,9 +71,6 @@ export const getPromotionCodesToApply = createStep(
async (data: GetPromotionCodesToApplyStepInput, { container }) => {
const { promo_codes = [], cart, action = PromotionActions.ADD } = data
const { items = [], shipping_methods = [] } = cart
const promotionService = container.resolve<IPromotionModuleService>(
Modules.PROMOTION
)
const adjustmentCodes: string[] = []
items.concat(shipping_methods).forEach((object) => {
@@ -99,14 +95,23 @@ export const getPromotionCodesToApply = createStep(
action === PromotionActions.ADD ||
action === PromotionActions.REPLACE
) {
const query = container.resolve(ContainerRegistrationKeys.QUERY)
const validPromoCodes: Set<string> = new Set(
promo_codes.length
? (
await promotionService.listPromotions(
{ code: promo_codes },
{ select: ["code"] }
await query.graph(
{
entity: "promotion",
fields: ["id", "code"],
filters: { code: promo_codes },
},
{
cache: {
enable: true,
},
}
)
).map((p) => p.code!)
).data.map((p) => p.code!)
: []
)
@@ -1,4 +1,4 @@
import { Query } from "@medusajs/framework"
import { MedusaContainer, Query } from "@medusajs/framework"
import {
CalculatedPriceSet,
IPricingModuleService,
@@ -75,11 +75,18 @@ async function fetchVariantPriceSets(
variantIds: string[]
): Promise<VariantPriceSetData[]> {
return (
await query.graph({
entity: "variant",
fields: ["id", "price_set.id"],
filters: { id: variantIds },
})
await query.graph(
{
entity: "variant",
fields: ["id", "price_set.id"],
filters: { id: variantIds },
},
{
cache: {
enable: true,
},
}
)
).data
}
@@ -108,7 +115,8 @@ function validateVariantPriceSets(
*/
async function processVariantPriceSets(
pricingService: IPricingModuleService,
items: PriceCalculationItem[]
items: PriceCalculationItem[],
container: MedusaContainer
): Promise<GetVariantPriceSetsStepOutput> {
const result: GetVariantPriceSetsStepOutput = {}
@@ -298,7 +306,8 @@ export const getVariantPriceSetsStep = createStep(
// Use unified processing logic for both input types
const result = await processVariantPriceSets(
pricingModuleService,
calculationItems
calculationItems,
container
)
return new StepResponse(result)
@@ -1,4 +1,3 @@
import type { IPromotionModuleService } from "@medusajs/framework/types"
import {
ContainerRegistrationKeys,
Modules,
@@ -40,9 +39,6 @@ export const updateCartPromotionsStep = createStep(
const remoteQuery = container.resolve(
ContainerRegistrationKeys.REMOTE_QUERY
)
const promotionService = container.resolve<IPromotionModuleService>(
Modules.PROMOTION
)
const existingCartPromotionLinks = await remoteQuery({
entryPoint: "cart_promotion",
@@ -60,9 +56,18 @@ export const updateCartPromotionsStep = createStep(
const linksToDismiss: any[] = []
if (promo_codes?.length) {
const promotions = await promotionService.listPromotions(
{ code: promo_codes },
{ select: ["id"] }
const query = container.resolve(ContainerRegistrationKeys.QUERY)
const { data: promotions } = await query.graph(
{
entity: "promotion",
fields: ["id", "code"],
filters: { code: promo_codes },
},
{
cache: {
enable: true,
},
}
)
for (const promotion of promotions) {
@@ -205,6 +205,11 @@ export const addToCartWorkflow = createWorkflow(
filters: {
id: variantIds,
},
options: {
cache: {
enable: true,
},
},
}).config({ name: "fetch-variants" })
})
@@ -142,6 +142,11 @@ export const getVariantsAndItemsWithPrices = createWorkflow(
filters: {
id: variantIds,
},
options: {
cache: {
enable: true,
},
},
}).config({ name: "fetch-variants" })
const calculatedPriceSets = getVariantPriceSetsStep({
@@ -75,26 +75,26 @@ export const listShippingOptionsForCartWithPricingWorkflowId =
* @summary
*
* List a cart's shipping options with prices.
*
*
* @property hooks.setShippingOptionsContext - This hook is executed after the cart is retrieved and before the shipping options are queried. You can consume this hook to return any custom context useful for the shipping options retrieval.
*
* For example, you can consume the hook to add the customer Id to the context:
*
*
* ```ts
* import { listShippingOptionsForCartWithPricingWorkflow } from "@medusajs/medusa/core-flows"
* import { StepResponse } from "@medusajs/workflows-sdk"
*
*
* listShippingOptionsForCartWithPricingWorkflow.hooks.setShippingOptionsContext(
* async ({ cart }, { container }) => {
*
*
* if (cart.customer_id) {
* return new StepResponse({
* customer_id: cart.customer_id,
* })
* }
*
*
* const query = container.resolve("query")
*
*
* const { data: carts } = await query.graph({
* entity: "cart",
* filters: {
@@ -102,16 +102,16 @@ export const listShippingOptionsForCartWithPricingWorkflowId =
* },
* fields: ["customer_id"],
* })
*
*
* return new StepResponse({
* customer_id: carts[0].customer_id,
* })
* }
* )
* ```
*
*
* The `customer_id` property will be added to the context along with other properties such as `is_return` and `enabled_in_store`.
*
*
* :::note
*
* You should also consume the `setShippingOptionsContext` hook in the {@link listShippingOptionsForCartWorkflow} workflow to ensure that the context is consistent when listing shipping options across workflows.
@@ -120,7 +120,11 @@ export const listShippingOptionsForCartWithPricingWorkflowId =
*/
export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow(
listShippingOptionsForCartWithPricingWorkflowId,
(input: WorkflowData<ListShippingOptionsForCartWithPricingWorkflowInput & AdditionalData>) => {
(
input: WorkflowData<
ListShippingOptionsForCartWithPricingWorkflowInput & AdditionalData
>
) => {
const optionIds = transform({ input }, ({ input }) =>
(input.options ?? []).map(({ id }) => id)
)
@@ -155,6 +159,11 @@ export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow(
"stock_locations.address.*",
"stock_locations.fulfillment_sets.id",
],
options: {
cache: {
enable: true,
},
},
}).config({ name: "sales_channels-fulfillment-query" })
const scFulfillmentSets = transform(
@@ -193,13 +202,21 @@ export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow(
resultValidator: shippingOptionsContextResult,
}
)
const setShippingOptionsContextResult = setShippingOptionsContext.getResult()
const setShippingOptionsContextResult =
setShippingOptionsContext.getResult()
const commonOptions = transform(
{ input, cart, fulfillmentSetIds, setShippingOptionsContextResult },
({ input, cart, fulfillmentSetIds, setShippingOptionsContextResult }) => ({
({
input,
cart,
fulfillmentSetIds,
setShippingOptionsContextResult,
}) => ({
context: {
...(setShippingOptionsContextResult ? setShippingOptionsContextResult : {}),
...(setShippingOptionsContextResult
? setShippingOptionsContextResult
: {}),
is_return: input.is_return ? "true" : "false",
enabled_in_store: !isDefined(input.enabled_in_store)
? "true"
@@ -168,6 +168,11 @@ export const listShippingOptionsForCartWorkflow = createWorkflow(
"stock_locations.name",
"stock_locations.address.*",
],
options: {
cache: {
enable: true,
},
},
}).config({ name: "sales_channels-fulfillment-query" })
const scFulfillmentSets = transform(
@@ -148,6 +148,9 @@ export const updateCartWorkflow = createWorkflow(
options: {
throwIfKeyNotFound: true,
isList: false,
cache: {
enable: true,
},
},
}).config({ name: "get-region" })
@@ -247,6 +247,11 @@ export const createOrderWorkflow = createWorkflow(
filters: {
id: variantIdsWithoutCalculatedPrice,
},
options: {
cache: {
enable: true,
},
},
}).config({ name: "query-variants-without-calculated-price" })
/**
+1
View File
@@ -16,6 +16,7 @@
"exports": {
".": "./dist/index.js",
"./config": "./dist/config/index.js",
"./caching": "./dist/caching/index.js",
"./logger": "./dist/logger/index.js",
"./database": "./dist/database/index.js",
"./subscribers": "./dist/subscribers/index.js",
@@ -30,20 +30,38 @@ export async function ensurePublishableApiKeyMiddleware(
const query = req.scope.resolve(ContainerRegistrationKeys.QUERY)
try {
const { data } = await query.graph({
entity: "api_key",
fields: ["id", "token", "sales_channels_link.sales_channel_id"],
filters: {
token: publishableApiKey,
type: ApiKeyType.PUBLISHABLE,
$or: [
{ revoked_at: { $eq: null } },
{ revoked_at: { $gt: new Date() } },
// Cache API key data and check revocation in memory
const { data } = await query.graph(
{
entity: "api_key",
fields: [
"id",
"token",
"revoked_at",
"sales_channels_link.sales_channel_id",
],
filters: {
token: publishableApiKey,
type: ApiKeyType.PUBLISHABLE,
},
},
})
{
cache: {
enable: true,
},
}
)
apiKey = data[0]
if (data.length) {
const now = new Date()
const cachedApiKey = data[0]
const isRevoked =
!!cachedApiKey.revoked_at && new Date(cachedApiKey.revoked_at) <= now
if (!isRevoked) {
apiKey = cachedApiKey
}
}
} catch (e) {
return next(e)
}
@@ -1,49 +1,88 @@
import { MedusaContainer } from "@medusajs/types"
import {
ContainerRegistrationKeys,
isString,
remoteQueryObjectFromString,
} from "@medusajs/utils"
import { MedusaRequest } from "../types"
import type {
GraphResultSet,
MedusaContainer,
RemoteJoinerOptions,
RemoteQueryEntryPoints,
RemoteQueryFunctionReturnPagination,
} from "../../types"
import { ContainerRegistrationKeys, isString } from "../../utils"
import type { MedusaRequest } from "../types"
export const refetchEntities = async (
entryPoint: string,
idOrFilter: string | object,
scope: MedusaContainer,
fields: string[],
pagination?: MedusaRequest["queryConfig"]["pagination"],
export const refetchEntities = async <TEntry extends string>({
entity,
idOrFilter,
scope,
fields,
pagination,
withDeleted,
options,
}: {
entity: TEntry
idOrFilter?: string | object
scope: MedusaContainer
fields?: string[]
pagination?: MedusaRequest["queryConfig"]["pagination"]
withDeleted?: boolean
) => {
const remoteQuery = scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY)
const filters = isString(idOrFilter) ? { id: idOrFilter } : idOrFilter
let context: object = {}
options?: RemoteJoinerOptions
}): Promise<
Omit<GraphResultSet<TEntry>, "metadata"> & {
metadata: RemoteQueryFunctionReturnPagination
}
> => {
const query = scope.resolve(ContainerRegistrationKeys.QUERY)
let filters = isString(idOrFilter) ? { id: idOrFilter } : idOrFilter
let context!: Record<string, unknown>
if ("context" in filters) {
if (filters.context) {
context = filters.context!
if (filters && "context" in filters) {
const { context: context_, ...rest } = filters
if (context_) {
context = context_! as Record<string, unknown>
}
delete filters.context
filters = rest
}
const variables = { filters, ...context, ...pagination, withDeleted }
const graphOptions: Parameters<typeof query.graph>[0] = {
entity,
fields: fields ?? [],
filters,
pagination,
withDeleted,
context: context,
}
const queryObject = remoteQueryObjectFromString({
entryPoint,
variables,
const result = await query.graph(graphOptions, options)
return {
data: result.data as TEntry extends keyof RemoteQueryEntryPoints
? RemoteQueryEntryPoints[TEntry][]
: any[],
metadata: result.metadata ?? ({} as RemoteQueryFunctionReturnPagination),
}
}
export const refetchEntity = async <TEntry extends string>({
entity,
idOrFilter,
scope,
fields,
options,
}: {
entity: TEntry & string
idOrFilter: string | object
scope: MedusaContainer
fields: string[]
options?: RemoteJoinerOptions
}): Promise<
TEntry extends keyof RemoteQueryEntryPoints
? RemoteQueryEntryPoints[TEntry]
: any
> => {
const { data } = await refetchEntities<TEntry>({
entity,
idOrFilter,
scope,
fields,
options,
})
return await remoteQuery(queryObject)
}
export const refetchEntity = async (
entryPoint: string,
idOrFilter: string | object,
scope: MedusaContainer,
fields: string[]
) => {
const [entity] = await refetchEntities(entryPoint, idOrFilter, scope, fields)
return entity
return Array.isArray(data) ? data[0] : data
}
@@ -5,6 +5,7 @@ import {
IApiKeyModuleService,
IAuthModuleService,
ICacheService,
ICachingModuleService,
ICartModuleService,
ICurrencyModuleService,
ICustomerModuleService,
@@ -76,6 +77,7 @@ declare module "@medusajs/types" {
[Modules.NOTIFICATION]: INotificationModuleService
[Modules.LOCKING]: ILockingModule
[Modules.SETTINGS]: ISettingsModuleService
[Modules.CACHING]: ICachingModuleService
}
}
@@ -244,6 +244,7 @@ describe("load internal", () => {
expect(generatedJoinerConfig).toEqual({
serviceName: "module-without-joiner-config",
primaryKeys: ["id"],
idPrefixToEntityName: {},
linkableKeys: {
entity2_id: "Entity2",
entity_model_id: "EntityModel",
@@ -322,6 +323,7 @@ describe("load internal", () => {
expect(generatedJoinerConfig).toEqual({
serviceName: "module-service",
primaryKeys: ["id"],
idPrefixToEntityName: {},
linkableKeys: {},
schema: "",
alias: [
@@ -231,7 +231,8 @@ export async function loadInternalModule(args: {
ContainerRegistrationKeys.CONFIG_MODULE,
ContainerRegistrationKeys.LOGGER,
ContainerRegistrationKeys.PG_CONNECTION,
Modules.EVENT_BUS
Modules.EVENT_BUS,
Modules.CACHING
)
for (const dependency of dependencies) {
@@ -609,6 +609,7 @@ async function MedusaApp_({
query: createQuery({
remoteQuery,
indexModule,
container: sharedContainer_,
}) as any, // TODO: rm any once we remove the old RemoteQueryFunction and rely on the Query object instead,
entitiesMap,
gqlSchema: schema,
@@ -1,6 +1,7 @@
import {
GraphResultSet,
IIndexService,
MedusaContainer,
RemoteJoinerOptions,
RemoteJoinerQuery,
RemoteQueryFilters,
@@ -11,6 +12,7 @@ import {
RemoteQueryObjectFromStringResult,
} from "@medusajs/types"
import {
Cached,
MedusaError,
isObject,
remoteQueryObjectFromString,
@@ -19,12 +21,62 @@ import {
import { RemoteQuery } from "./remote-query"
import { toRemoteQuery } from "./to-remote-query"
function extractCacheOptions(option: string) {
return function extractKey(args: any[]) {
return args[1]?.cache?.[option]
}
}
function isCacheEnabled(args: any[]) {
const isEnabled = extractCacheOptions("enable")(args)
if (isEnabled === false) {
return false
}
return (
isEnabled === true ||
extractCacheOptions("key")(args) ||
extractCacheOptions("ttl")(args) ||
extractCacheOptions("tags")(args) ||
extractCacheOptions("autoInvalidate")(args) ||
extractCacheOptions("providers")(args)
)
}
const cacheDecoratorOptions = {
enable: isCacheEnabled,
key: async (args, cachingModule) => {
const key = extractCacheOptions("key")(args)
if (key) {
return key
}
const queryOptions = args[0]
const remoteJoinerOptions = args[1] ?? {}
const { initialData, cache, ...restOptions } = remoteJoinerOptions
const keyInput = {
queryOptions,
options: restOptions,
}
return await cachingModule.computeKey(keyInput)
},
ttl: extractCacheOptions("ttl"),
tags: extractCacheOptions("tags"),
autoInvalidate: extractCacheOptions("autoInvalidate"),
providers: extractCacheOptions("providers"),
container: function (this: Query) {
return this.container
},
}
/**
* API wrapper around the remoteQuery
*/
export class Query {
#remoteQuery: RemoteQuery
#indexModule: IIndexService
protected container: MedusaContainer
/**
* Method to wrap execution of the graph query for instrumentation
@@ -61,12 +113,15 @@ export class Query {
constructor({
remoteQuery,
indexModule,
container,
}: {
remoteQuery: RemoteQuery
indexModule: IIndexService
container: MedusaContainer
}) {
this.#remoteQuery = remoteQuery
this.#indexModule = indexModule
this.container = container
}
#unwrapQueryConfig(
@@ -151,6 +206,7 @@ export class Query {
* Graph function uses the remoteQuery under the hood and
* returns a result set
*/
@Cached(cacheDecoratorOptions)
async graph<const TEntry extends string>(
queryOptions: RemoteQueryInput<TEntry>,
options?: RemoteJoinerOptions
@@ -189,6 +245,7 @@ export class Query {
* Index function uses the Index module to query and hydrates the data with query.graph
* returns a result set
*/
@Cached(cacheDecoratorOptions)
async index<const TEntry extends string>(
queryOptions: RemoteQueryInput<TEntry> & {
joinFilters?: RemoteQueryFilters<TEntry>
@@ -266,13 +323,16 @@ export class Query {
export function createQuery({
remoteQuery,
indexModule,
container,
}: {
remoteQuery: RemoteQuery
indexModule: IIndexService
container: MedusaContainer
}) {
const query = new Query({
remoteQuery,
indexModule,
container,
})
function backwardCompatibleQuery(...args: any[]) {
@@ -69,8 +69,7 @@ export function toRemoteQuery<const TEntity extends string>(
}
if (QueryContext.isQueryContext(src)) {
const normalizedFilters = { ...src } as any
delete normalizedFilters.__type
const { __type, ...normalizedFilters } = src as any
const prop = "context"
@@ -100,7 +99,7 @@ export function toRemoteQuery<const TEntity extends string>(
}
// Process filters and context recursively
processNestedObjects(joinerQuery[entity], context)
processNestedObjects(joinerQuery[entity], context, true)
for (const field of fields) {
const fieldAsString = field as string
+1
View File
@@ -3,6 +3,7 @@ export * as AnalyticsTypes from "./analytics"
export * as ApiKeyTypes from "./api-key"
export * as AuthTypes from "./auth"
export * as CacheTypes from "./cache"
export * as CachingTypes from "./caching"
export * as CartTypes from "./cart"
export * as CommonTypes from "./common"
export * as CurrencyTypes from "./currency"
+161
View File
@@ -0,0 +1,161 @@
import { IModuleService, ModuleJoinerConfig } from "../modules-sdk"
type Providers = string[] | { id: string; ttl?: number }[]
export interface ICachingModuleService extends IModuleService {
// Static trace methods
// traceGet: (
// cacheGetFn: () => Promise<any>,
// key: string,
// tags: string[]
// ) => Promise<any>
// traceSet?: (
// cacheSetFn: () => Promise<any>,
// key: string,
// tags: string[],
// options: { autoInvalidate?: boolean }
// ) => Promise<any>
// traceClear?: (
// cacheClearFn: () => Promise<any>,
// key: string,
// tags: string[],
// options: { autoInvalidate?: boolean }
// ) => Promise<any>
/**
* This method retrieves data from the cache.
*
* @param key - The key of the item to retrieve.
* @param tags - The tags of the items to retrieve.
* @param providers - Array of providers to check in order of priority. If not provided,
* only the default provider will be used.
*
* @returns The item(s) that was stored in the cache. If the item(s) was not found, null will
* be returned.
*
*/
get<T>({
key,
tags,
providers,
}: {
key?: string
tags?: string[]
providers?: string[]
}): Promise<T | null>
/**
* This method stores data in the cache.
*
* @param key - The key of the item to store.
* @param data - The data to store in the cache.
* @param ttl - The time-to-live (TTL in seconds) value in seconds. If not provided, the default TTL value
* is used. The default value is based on the used Cache Module.
* @param tags - The tags of the items to store. can be used for cross invalidation.
* @param options - if specified, will be stored with the item(s).
* @param providers - The providers from which to store the item(s).
*
*/
set({
key,
data,
ttl,
tags,
options,
providers,
}: {
key: string
data: object
ttl?: number
tags?: string[]
options?: {
autoInvalidate?: boolean
}
providers?: Providers
}): Promise<void>
/**
* This method clears data from the cache.
*
* @param key - The key of the item to clear.
* @param tags - The tags of the items to clear.
* @param options - if specified, invalidate the item(s) that has the value of the given
* options stored. e.g you can invalidate the tags X if their options.autoInvalidate is false or not present.
* @param providers - The providers from which to clear the item(s).
*
*/
clear({
key,
tags,
options,
providers,
}: {
key?: string
tags?: string[]
options?: {
autoInvalidate?: boolean
}
providers?: string[]
}): Promise<void>
computeKey(input: object): Promise<string>
computeTags(input: object, options?: Record<string, any>): Promise<string[]>
}
export interface ICachingProviderService {
get({ key, tags }: { key?: string; tags?: string[] }): Promise<any>
set({
key,
data,
ttl,
tags,
options,
}: {
key: string
data: object
ttl?: number
tags?: string[]
options?: { autoInvalidate?: boolean }
}): Promise<void>
clear({
key,
tags,
options,
}: {
key?: string
tags?: string[]
options?: { autoInvalidate?: boolean }
}): Promise<void>
}
export interface EntityReference {
type: string
id: string | number
field?: string
}
export interface ICachingStrategy {
/**
* This method is called when the application starts. It can be useful to set up some auto
* invalidation logic that reacts to something.
*
* @param container MedusaContainer
* @param schema GraphQLSchema
* @param cacheModule ICachingModuleService
*/
onApplicationStart?(
schema: any,
joinerConfigs: ModuleJoinerConfig[]
): Promise<void>
onApplicationPrepareShutdown?(): Promise<void>
onApplicationShutdown?(): Promise<void>
computeKey(input: object): Promise<string>
computeTags(input: object, options?: Record<string, any>): Promise<string[]>
}
@@ -58,3 +58,8 @@ export type RawMessageFormat<TData = any> = {
context?: Pick<Context, "eventGroupId">
options?: Record<string, any>
}
export type InterceptorSubscriber<T = unknown> = (
message: Message<T>,
context?: { isGrouped?: boolean; eventGroupId?: string }
) => Promise<void> | void
@@ -1,4 +1,9 @@
import { Message, Subscriber, SubscriberContext } from "./common"
import {
InterceptorSubscriber,
Message,
Subscriber,
SubscriberContext,
} from "./common"
export interface IEventBusModuleService {
/**
@@ -86,4 +91,31 @@ export interface IEventBusModuleService {
eventNames?: string[]
}
): Promise<void>
/**
* This method adds an interceptor to the event bus. This means that the interceptor will be
* called before the event is emitted.
*
* @param interceptor - The interceptor to add.
* @returns The instance of the Event Module
*
* @example
* eventModuleService.addInterceptor((message, context) => {
* console.log("Interceptor", message, context)
* })
*/
addInterceptor?(interceptor: InterceptorSubscriber): this
/**
* This method removes an interceptor from the event bus.
*
* @param interceptor - The interceptor to remove.
* @returns The instance of the Event Module
*
* @example
* eventModuleService.removeInterceptor((message, context) => {
* console.log("Interceptor", message, context)
* })
*/
removeInterceptor?(interceptor: InterceptorSubscriber): this
}
+1
View File
@@ -5,6 +5,7 @@ export * from "./api-key"
export * from "./auth"
export * from "./bundles"
export * from "./cache"
export * from "./caching"
export * from "./cart"
export * from "./common"
export * from "./currency"
+36
View File
@@ -1,3 +1,5 @@
import { ICachingModuleService } from "../caching"
export type JoinerRelationship = {
alias: string
foreignKey: string
@@ -92,6 +94,40 @@ export interface RemoteJoinerOptions {
throwIfRelationNotFound?: boolean | string[]
initialData?: object | object[]
initialDataOnly?: boolean
cache?: {
/**
* Whether to enable the cache. This is only useful if you want to enable without providing any
* other options or if you want to enable/disable the cache based on the arguments.
*/
enable?: boolean | ((args: any[]) => boolean | undefined)
/**
* The key to use for the cache.
* If a function is provided, it will be called with the arguments as the first argument and the
* container as the second argument.
*/
key?:
| string
| ((
args: any[],
cachingModule: ICachingModuleService
) => string | Promise<string>)
/**
* The tags to use for the cache.
*/
tags?: string[] | ((args: any[]) => string[] | undefined)
/**
* The time-to-live (TTL) value in seconds.
*/
ttl?: number | ((args: any[]) => number | undefined)
/**
* Whether to auto invalidate the cache whenever it is possible.
*/
autoInvalidate?: boolean | ((args: any[]) => boolean | undefined)
/**
* The providers to use for the cache.
*/
providers?: string[] | ((args: any[]) => string[] | undefined)
}
}
export interface RemoteNestedExpands {
@@ -196,6 +196,7 @@ export type ModuleJoinerConfig = Omit<
* GraphQL schema for the all module's available entities and fields
*/
schema?: string
idPrefixToEntityName?: Record<string, string>
relationships?: ModuleJoinerRelationship[]
extends?: {
serviceName: string
@@ -48,9 +48,10 @@ export type QueryGraphFunction = {
* a normalized/consistent output.
*/
export type QueryIndexFunction = {
<const TEntry extends string>(queryOptions: IndexQueryInput<TEntry>): Promise<
Prettify<QueryResultSet<TEntry>>
>
<const TEntry extends string>(
queryOptions: IndexQueryInput<TEntry>,
options?: RemoteJoinerOptions
): Promise<Prettify<QueryResultSet<TEntry>>>
}
/*export type RemoteQueryReturnedData<TEntry extends string> =
@@ -132,7 +132,6 @@ export interface FilterablePriceRuleProps
* The IDs to filter the price rule's associated price set.
*/
price_set_id?: string | string[] | OperatorMap<string | string[]>
/**
* The IDs to filter the price rule's associated price.
*/
+1
View File
@@ -18,3 +18,4 @@ export * as PromotionUtils from "./promotion"
export * as SearchUtils from "./search"
export * as ShippingProfileUtils from "./shipping"
export * as UserUtils from "./user"
export * as CachingUtils from "./caching"
+249
View File
@@ -0,0 +1,249 @@
import { ICachingModuleService, Logger, MedusaContainer } from "@medusajs/types"
import { MedusaContextType, Modules } from "../modules-sdk"
import { FeatureFlag } from "../feature-flags"
import { ContainerRegistrationKeys, isObject } from "../common"
/**
* This function is used to cache the result of a function call.
*
* @param cb - The callback to execute.
* @param options - The options for the cache.
* @returns The result of the callback.
*/
export async function useCache<T>(
cb: (...args: any[]) => Promise<T>,
options: {
enable?: boolean
key: string | any[]
tags?: string[]
ttl?: number
/**
* Whethere the default strategy should auto invalidate the cache whenever it is possible.
*/
autoInvalidate?: boolean
providers?: string[]
container: MedusaContainer
}
): Promise<T> {
const cachingModule = options.container.resolve<ICachingModuleService>(
Modules.CACHING,
{
allowUnregistered: true,
}
)
if (
!options.enable ||
!FeatureFlag.isFeatureEnabled("caching") ||
!cachingModule
) {
return await cb()
}
let key: string
if (typeof options.key === "string") {
key = options.key
} else {
key = await cachingModule.computeKey(options.key)
}
const data = await cachingModule.get({
key,
tags: options.tags,
providers: options.providers,
})
if (data) {
return data as T
}
const result = await cb()
void cachingModule
.set({
key,
tags: options.tags,
ttl: options.ttl,
data: result as object,
options: { autoInvalidate: options.autoInvalidate },
providers: options.providers,
})
.catch((e) => {
const logger =
options.container.resolve<Logger>(ContainerRegistrationKeys.LOGGER, {
allowUnregistered: true,
}) ?? (console as unknown as Logger)
logger.error(
`An error occured while setting cache for key: ${key}\n${e.message}\n${e.stack}`
)
})
return result
}
type TargetMethodArgs<Target, PropertyKey> = Target[PropertyKey &
keyof Target] extends (...args: any[]) => any
? Parameters<Target[PropertyKey & keyof Target]>
: never
/**
* This function is used to cache the result of a method call.
*
* @param options - The options for the cache.
* @returns The original method with the cache applied.
*/
export function Cached<
const Target extends object,
const PropertyKey extends keyof Target
>(options: {
/**
* The key to use for the cache.
* If a function is provided, it will be called with the arguments as the first argument and the
* container as the second argument.
*/
key?:
| string
| ((
args: TargetMethodArgs<Target, PropertyKey>,
cachingModule: ICachingModuleService
) => string | Promise<string> | Promise<any[]> | any[])
/**
* Whether to enable the cache. This is only useful if you want to enable without providing any
* other options.
*/
enable?:
| boolean
| ((args: TargetMethodArgs<Target, PropertyKey>) => boolean | undefined)
/**
* The tags to use for the cache.
*/
tags?:
| string[]
| ((args: TargetMethodArgs<Target, PropertyKey>) => string[] | undefined)
/**
* The time-to-live (TTL) value in seconds.
*/
ttl?:
| number
| ((args: TargetMethodArgs<Target, PropertyKey>) => number | undefined)
/**
* Whether to auto invalidate the cache whenever it is possible.
*/
autoInvalidate?:
| boolean
| ((args: TargetMethodArgs<Target, PropertyKey>) => boolean | undefined)
/**
* The providers to use for the cache.
*/
providers?:
| string[]
| ((args: TargetMethodArgs<Target, PropertyKey>) => string[] | undefined)
container: MedusaContainer | ((this: Target) => MedusaContainer)
}) {
return function (
target: Target,
propertyKey: PropertyKey,
descriptor: PropertyDescriptor
) {
const originalMethod = descriptor.value
if (typeof originalMethod !== "function") {
throw new Error("@cached can only be applied to methods")
}
descriptor.value = async function (
...args: Target[PropertyKey & keyof Target] extends (
...args: any[]
) => any
? Parameters<Target[PropertyKey & keyof Target]>
: never
) {
const container: MedusaContainer =
typeof options.container === "function"
? options.container.call(this)
: options.container
const cachingModule = container.resolve<ICachingModuleService>(
Modules.CACHING,
{
allowUnregistered: true,
}
)
if (!FeatureFlag.isFeatureEnabled("caching") || !cachingModule) {
return await originalMethod.apply(this, args)
}
if (!options.key) {
options.key = await cachingModule.computeKey(
args
.map((arg) => {
if (isObject(arg)) {
// Prevent any container, manager, transactionManager, etc from being included in the key
const {
container,
manager,
transactionManager,
__type,
...rest
} = arg as any
if (__type === MedusaContextType) {
return
}
return rest
}
return arg
})
.filter(Boolean)
)
}
const resolvableKeys = [
"enable",
"key",
"tags",
"ttl",
"autoInvalidate",
"providers",
]
const cacheOptions = {} as Parameters<typeof useCache>[1]
const promises: Promise<any>[] = []
for (const key of resolvableKeys) {
if (typeof options[key] === "function") {
const res = options[key](args, cachingModule)
if (res instanceof Promise) {
promises.push(
res.then((value) => {
cacheOptions[key] = value
})
)
} else {
cacheOptions[key] = res
}
} else {
cacheOptions[key] = options[key]
}
}
await Promise.all(promises)
if (!cacheOptions.enable) {
return await originalMethod.apply(this, args)
}
Object.assign(cacheOptions, {
container,
})
return await useCache(
() => originalMethod.apply(this, args),
cacheOptions as Parameters<typeof useCache>[1]
)
}
return descriptor
}
}
@@ -2,7 +2,22 @@ import { ContainerLike } from "@medusajs/types"
export function createContainerLike(obj): ContainerLike {
return {
resolve(key: string) {
resolve(
key: string,
{
allowUnregistered = false,
}: {
allowUnregistered?: boolean
} = {}
) {
if (allowUnregistered) {
try {
return obj[key]
} catch (error) {
return undefined
}
}
return obj[key]
},
}
+1 -1
View File
@@ -14,7 +14,7 @@ export class IdProperty extends BaseProperty<string> {
return !!value?.[IsIdProperty] || value?.dataType?.name === "id"
}
protected dataType: {
dataType: {
name: "id"
options: {
prefix?: string
@@ -24,6 +24,10 @@ export class PrimaryKeyModifier<T, Schema extends PropertyType<T>>
*/
#schema: Schema
get schema() {
return this.#schema
}
constructor(schema: Schema) {
this.#schema = schema
}
+50 -1
View File
@@ -1,4 +1,8 @@
import { EventBusTypes, InternalModuleDeclaration } from "@medusajs/types"
import {
EventBusTypes,
InterceptorSubscriber,
InternalModuleDeclaration,
} from "@medusajs/types"
import { ulid } from "ulid"
export abstract class AbstractEventBusModuleService
@@ -11,6 +15,8 @@ export abstract class AbstractEventBusModuleService
EventBusTypes.SubscriberDescriptor[]
> = new Map()
protected interceptorSubscribers_: Set<InterceptorSubscriber> = new Set()
public get eventToSubscribersMap(): Map<
string | symbol,
EventBusTypes.SubscriberDescriptor[]
@@ -134,6 +140,49 @@ export abstract class AbstractEventBusModuleService
return this
}
/**
* Add an interceptor subscriber that receives all messages before they are emitted
*
* @param interceptor - Function that receives messages before emission
* @returns this for chaining
*/
public addInterceptor(interceptor: InterceptorSubscriber): this {
this.interceptorSubscribers_.add(interceptor)
return this
}
/**
* Remove an interceptor subscriber
*
* @param interceptor - Function to remove from interceptors
* @returns this for chaining
*/
public removeInterceptor(interceptor: InterceptorSubscriber): this {
this.interceptorSubscribers_.delete(interceptor)
return this
}
/**
* Call all interceptor subscribers with the message before emission
* This should be called by implementations before emitting events
*
* @param message - The message to be intercepted
* @param context - Optional context about the emission
*/
protected async callInterceptors<T = unknown>(
message: EventBusTypes.Message<T>,
context?: { isGrouped?: boolean; eventGroupId?: string }
): Promise<void> {
Array.from(this.interceptorSubscribers_).map(async (interceptor) => {
try {
await interceptor(message, context)
} catch (error) {
// Log error but don't stop other interceptors or the emission
console.error("Error in event bus interceptor:", error)
}
})
}
}
export * from "./build-event-messages"
+1
View File
@@ -29,6 +29,7 @@ export * from "./shipping"
export * from "./totals"
export * from "./totals/big-number"
export * from "./user"
export * from "./caching"
export const MedusaModuleType = Symbol.for("MedusaModule")
export const MedusaModuleProviderType = Symbol.for("MedusaModuleProvider")
@@ -48,6 +48,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: "",
idPrefixToEntityName: {},
linkableKeys: {
fulfillment_set_id: FulfillmentSet.name,
shipping_option_id: ShippingOption.name,
@@ -136,6 +137,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: "",
idPrefixToEntityName: {},
linkableKeys: {},
alias: [
{
@@ -176,6 +178,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: "",
idPrefixToEntityName: {},
linkableKeys: {
fulfillment_set_id: FulfillmentSet.name,
shipping_option_id: ShippingOption.name,
@@ -269,6 +272,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: "",
idPrefixToEntityName: {},
linkableKeys: {},
alias: [
{
@@ -300,6 +304,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: "",
idPrefixToEntityName: {},
linkableKeys: {
fulfillment_set_id: FulfillmentSet.name,
},
@@ -335,6 +340,7 @@ describe("joiner-config-builder", () => {
serviceName: Modules.FULFILLMENT,
primaryKeys: ["id"],
schema: expect.any(String),
idPrefixToEntityName: {},
linkableKeys: {
fulfillment_set_id: FulfillmentSet.name,
shipping_option_id: ShippingOption.name,
@@ -27,6 +27,7 @@ export const Modules = {
INDEX: "index",
LOCKING: "locking",
SETTINGS: "settings",
CACHING: "caching",
} as const
export const MODULE_PACKAGE_NAMES = {
@@ -58,6 +59,7 @@ export const MODULE_PACKAGE_NAMES = {
[Modules.INDEX]: "@medusajs/medusa/index-module",
[Modules.LOCKING]: "@medusajs/medusa/locking",
[Modules.SETTINGS]: "@medusajs/medusa/settings",
[Modules.CACHING]: "@medusajs/caching",
}
export const REVERSED_MODULE_PACKAGE_NAMES = Object.entries(
@@ -46,8 +46,10 @@ export function defineJoinerConfig(
models,
linkableKeys,
primaryKeys,
idPrefixToEntityName,
}: {
alias?: JoinerServiceConfigAlias[]
idPrefixToEntityName?: Record<string, string>
schema?: string
models?: DmlEntity<any, any>[] | { name: string }[]
linkableKeys?: ModuleJoinerConfig["linkableKeys"]
@@ -150,6 +152,12 @@ export function defineJoinerConfig(
schema = toGraphQLSchema([...modelDefinitions.values()])
}
if (!idPrefixToEntityName) {
idPrefixToEntityName = buildIdPrefixToEntityNameFromDmlObjects([
...modelDefinitions.values(),
])
}
const linkableKeysFromDml = buildLinkableKeysFromDmlObjects([
...modelDefinitions.values(),
])
@@ -199,6 +207,7 @@ export function defineJoinerConfig(
serviceName,
primaryKeys,
schema,
idPrefixToEntityName,
linkableKeys: linkableKeys,
alias: [
...[...(alias ?? ([] as any))].map((alias) => ({
@@ -230,6 +239,31 @@ export function defineJoinerConfig(
}
}
/**
* Build the id prefix to entity name map from the DML objects
* @param models
*/
export function buildIdPrefixToEntityNameFromDmlObjects(
models: DmlEntity<any, any>[]
): Record<string, string> {
return models.reduce((acc, model) => {
const id = model.parse().schema.id as
| IdProperty
| PrimaryKeyModifier<any, any>
if (
PrimaryKeyModifier.isPrimaryKeyModifier(id) &&
id.schema.dataType.options.prefix
) {
acc[id.schema.dataType.options.prefix] = model.name
} else if (IdProperty.isIdProperty(id) && id.dataType.options.prefix) {
acc[id.dataType.options.prefix] = model.name
}
return acc
}, {})
}
/**
* From a set of DML objects, build the linkable keys
*
@@ -1,6 +1,7 @@
import { Constructor, IDmlEntity, ModuleExports } from "@medusajs/types"
import { DmlEntity } from "../dml"
import {
buildIdPrefixToEntityNameFromDmlObjects,
buildLinkConfigFromLinkableKeys,
buildLinkConfigFromModelObjects,
defineJoinerConfig,
@@ -53,6 +54,10 @@ export function Module<
// TODO: Add support for non linkable modifier DML object to be skipped from the linkable generation
const linkableKeys = service.prototype.__joinerConfig().linkableKeys
service.prototype.__joinerConfig().idPrefixToEntityName =
buildIdPrefixToEntityNameFromDmlObjects(
dmlObjects.map(([, model]) => model) as DmlEntity<any, any>[]
)
if (dmlObjects.length) {
linkable = buildLinkConfigFromModelObjects<ServiceName, ModelObjects>(
@@ -155,18 +155,25 @@ const getDataForComputation = async (
query: Omit<RemoteQueryFunction, symbol>,
data: { variant_ids: string[]; sales_channel_id?: string }
) => {
const { data: variantInventoryItems } = await query.graph({
entity: "product_variant_inventory_items",
fields: [
"variant_id",
"required_quantity",
"variant.manage_inventory",
"variant.allow_backorder",
"inventory.*",
"inventory.location_levels.*",
],
filters: { variant_id: data.variant_ids },
})
const { data: variantInventoryItems } = await query.graph(
{
entity: "product_variant_inventory_items",
fields: [
"variant_id",
"required_quantity",
"variant.manage_inventory",
"variant.allow_backorder",
"inventory.*",
"inventory.location_levels.*",
],
filters: { variant_id: data.variant_ids },
},
{
cache: {
enable: true,
},
}
)
const variantInventoriesMap = new Map()
variantInventoryItems.forEach((link) => {
@@ -177,11 +184,21 @@ const getDataForComputation = async (
const locationIds = new Set<string>()
if (data.sales_channel_id) {
const { data: channelLocations } = await query.graph({
entity: "sales_channel_locations",
fields: ["stock_location_id"],
filters: { sales_channel_id: data.sales_channel_id },
})
const { data: channelLocations } = await query.graph(
{
entity: "sales_channel_locations",
fields: ["stock_location_id"],
filters: { sales_channel_id: data.sales_channel_id },
},
{
cache: {
tags: [
`SalesChannel:${data.sales_channel_id}`,
"StockLocation:list:*",
],
},
}
)
channelLocations.forEach((loc) => locationIds.add(loc.stock_location_id))
}
+2
View File
@@ -80,6 +80,8 @@
"@medusajs/auth-google": "2.10.3",
"@medusajs/cache-inmemory": "2.10.3",
"@medusajs/cache-redis": "2.10.3",
"@medusajs/caching": "2.10.3",
"@medusajs/caching-redis": "2.10.3",
"@medusajs/cart": "2.10.3",
"@medusajs/core-flows": "2.10.3",
"@medusajs/currency": "2.10.3",
@@ -67,10 +67,12 @@ export const DELETE = async (
) => {
const { id, action_id } = req.params
const claim = await refetchEntity("order_claim", id, req.scope, [
"id",
"return_id",
])
const claim = await refetchEntity({
entity: "order_claim",
idOrFilter: id,
scope: req.scope,
fields: ["id", "return_id"],
})
const { result: orderPreview } = await removeItemReturnActionWorkflow(
req.scope
@@ -81,15 +83,15 @@ export const DELETE = async (
},
})
const orderReturn = await refetchEntity(
"return",
{
const orderReturn = await refetchEntity({
entity: "return",
idOrFilter: {
...req.filterableFields,
id,
},
req.scope,
defaultAdminDetailsReturnFields
)
scope: req.scope,
fields: defaultAdminDetailsReturnFields,
})
res.json({
order_preview: orderPreview as unknown as HttpTypes.AdminOrderPreview,
@@ -10,12 +10,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest,
res: MedusaResponse<AdminClaimResponse>
) => {
const claim = await refetchEntity(
"order_claim",
req.params.id,
req.scope,
req.queryConfig.fields
)
const claim = await refetchEntity({
entity: "order_claim",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
if (!claim) {
throw new MedusaError(
@@ -69,9 +69,12 @@ export const DELETE = async (
const { id, action_id } = req.params
const exchange = await refetchEntity("order_exchange", id, req.scope, [
"return_id",
])
const exchange = await refetchEntity({
entity: "order_exchange",
idOrFilter: id,
scope: req.scope,
fields: ["return_id"],
})
const { result: orderPreview } = await removeItemReturnActionWorkflow(
req.scope
@@ -10,12 +10,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest,
res: MedusaResponse<AdminExchangeResponse>
) => {
const exchange = await refetchEntity(
"order_exchange",
req.params.id,
req.scope,
req.queryConfig.fields
)
const exchange = await refetchEntity({
entity: "order_exchange",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
if (!exchange) {
throw new MedusaError(
@@ -10,11 +10,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest<AdminGetNotificationParamsType>,
res: MedusaResponse<HttpTypes.AdminNotificationResponse>
) => {
const notification = await refetchEntity(
"notification",
req.params.id,
req.scope,
req.queryConfig.fields
)
const notification = await refetchEntity({
entity: "notification",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ notification })
}
@@ -9,13 +9,13 @@ export const GET = async (
req: AuthenticatedMedusaRequest<HttpTypes.AdminNotificationListParams>,
res: MedusaResponse<HttpTypes.AdminNotificationListResponse>
) => {
const { rows: notifications, metadata } = await refetchEntities(
"notification",
req.filterableFields,
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: notifications, metadata } = await refetchEntities({
entity: "notification",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
notifications,
@@ -16,12 +16,12 @@ export const POST = async (
input: { orderId, fulfillmentId },
})
const order = await refetchEntity(
"order",
orderId,
req.scope,
req.queryConfig.fields
)
const order = await refetchEntity({
entity: "order",
idOrFilter: orderId,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ order })
}
@@ -21,12 +21,12 @@ export const POST = async (
},
})
const paymentCollection = await refetchEntity(
"payment_collection",
id,
req.scope,
req.queryConfig.fields
)
const paymentCollection = await refetchEntity({
entity: "payment_collection",
idOrFilter: id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ payment_collection: paymentCollection })
}
@@ -15,12 +15,12 @@ export const POST = async (
input: req.body,
})
const paymentCollection = await refetchEntity(
"payment_collection",
result[0].id,
req.scope,
req.queryConfig.fields
)
const paymentCollection = await refetchEntity({
entity: "payment_collection",
idOrFilter: result[0].id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ payment_collection: paymentCollection })
}
@@ -14,12 +14,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest,
res: MedusaResponse<HttpTypes.AdminPricePreferenceResponse>
) => {
const price_preference = await refetchEntity(
"price_preference",
req.params.id,
req.scope,
req.queryConfig.fields
)
const price_preference = await refetchEntity({
entity: "price_preference",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ price_preference })
}
@@ -35,12 +35,12 @@ export const POST = async (
input: { selector: { id: [id] }, update: req.body },
})
const price_preference = await refetchEntity(
"price_preference",
id,
req.scope,
req.queryConfig.fields
)
const price_preference = await refetchEntity({
entity: "price_preference",
idOrFilter: id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ price_preference })
}
@@ -11,13 +11,14 @@ export const GET = async (
req: AuthenticatedMedusaRequest<HttpTypes.AdminPricePreferenceListParams>,
res: MedusaResponse<HttpTypes.AdminPricePreferenceListResponse>
) => {
const { rows: price_preferences, metadata } = await refetchEntities(
"price_preference",
req.filterableFields,
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: price_preferences, metadata } = await refetchEntities({
entity: "price_preference",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
price_preferences: price_preferences,
count: metadata.count,
@@ -35,12 +36,12 @@ export const POST = async (
input: [req.validatedBody],
})
const price_preference = await refetchEntity(
"price_preference",
result[0].id,
req.scope,
req.queryConfig.fields
)
const price_preference = await refetchEntity({
entity: "price_preference",
idOrFilter: result[0].id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ price_preference })
}
@@ -19,12 +19,12 @@ export const POST = async (
input: { id, ...req.validatedBody },
})
const category = await refetchEntity(
"product_category",
id,
req.scope,
req.queryConfig.fields
)
const category = await refetchEntity({
entity: "product_category",
idOrFilter: id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ product_category: category })
}
@@ -21,12 +21,15 @@ export const GET = async (
req: AuthenticatedMedusaRequest<AdminProductCategoryParamsType>,
res: MedusaResponse<AdminProductCategoryResponse>
) => {
const [category] = await refetchEntities(
"product_category",
{ id: req.params.id, ...req.filterableFields },
req.scope,
req.queryConfig.fields
)
const {
data: [category],
} = await refetchEntities({
entity: "product_category",
idOrFilter: { id: req.params.id, ...req.filterableFields },
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
if (!category) {
throw new MedusaError(
@@ -48,12 +51,15 @@ export const POST = async (
input: { selector: { id }, update: req.validatedBody },
})
const [category] = await refetchEntities(
"product_category",
{ id, ...req.filterableFields },
req.scope,
req.queryConfig.fields
)
const {
data: [category],
} = await refetchEntities({
entity: "product_category",
idOrFilter: { id, ...req.filterableFields },
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.status(200).json({ product_category: category })
}
@@ -10,13 +10,13 @@ export const GET = async (
req: AuthenticatedMedusaRequest<HttpTypes.AdminProductCategoryListParams>,
res: MedusaResponse<HttpTypes.AdminProductCategoryListResponse>
) => {
const { rows: product_categories, metadata } = await refetchEntities(
"product_category",
req.filterableFields,
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: product_categories, metadata } = await refetchEntities({
entity: "product_category",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
product_categories,
@@ -34,12 +34,15 @@ export const POST = async (
input: { product_categories: [req.validatedBody] },
})
const [category] = await refetchEntities(
"product_category",
{ id: result[0].id, ...req.filterableFields },
req.scope,
req.queryConfig.fields
)
const {
data: [category],
} = await refetchEntities({
entity: "product_category",
idOrFilter: { id: result[0].id, ...req.filterableFields },
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.status(200).json({ product_category: category })
}
@@ -19,12 +19,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest<AdminGetProductTagParamsType>,
res: MedusaResponse<HttpTypes.AdminProductTagResponse>
) => {
const productTag = await refetchEntity(
"product_tag",
req.params.id,
req.scope,
req.queryConfig.fields
)
const productTag = await refetchEntity({
entity: "product_tag",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ product_tag: productTag })
}
@@ -33,12 +33,12 @@ export const POST = async (
req: AuthenticatedMedusaRequest<AdminUpdateProductTagType>,
res: MedusaResponse<HttpTypes.AdminProductTagResponse>
) => {
const existingProductTag = await refetchEntity(
"product_tag",
req.params.id,
req.scope,
["id"]
)
const existingProductTag = await refetchEntity({
entity: "product_tag",
idOrFilter: req.params.id,
scope: req.scope,
fields: ["id"],
})
if (!existingProductTag) {
throw new MedusaError(
@@ -54,12 +54,12 @@ export const POST = async (
},
})
const productTag = await refetchEntity(
"product_tag",
result[0].id,
req.scope,
req.queryConfig.fields
)
const productTag = await refetchEntity({
entity: "product_tag",
idOrFilter: result[0].id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ product_tag: productTag })
}
@@ -12,13 +12,13 @@ export const GET = async (
req: AuthenticatedMedusaRequest<HttpTypes.AdminProductTagListParams>,
res: MedusaResponse<HttpTypes.AdminProductTagListResponse>
) => {
const { rows: product_tags, metadata } = await refetchEntities(
"product_tag",
req.filterableFields,
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: product_tags, metadata } = await refetchEntities({
entity: "product_tag",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
product_tags: product_tags,
@@ -38,12 +38,12 @@ export const POST = async (
input: { product_tags: input },
})
const productTag = await refetchEntity(
"product_tag",
result[0].id,
req.scope,
req.queryConfig.fields
)
const productTag = await refetchEntity({
entity: "product_tag",
idOrFilter: result[0].id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ product_tag: productTag })
}
@@ -21,13 +21,13 @@ export const GET = async (
)
}
const { rows: variants, metadata } = await refetchEntities(
"variant",
{ ...req.filterableFields },
req.scope,
remapKeysForVariant(req.queryConfig.fields ?? []),
req.queryConfig.pagination
)
const { data: variants, metadata } = await refetchEntities({
entity: "variant",
idOrFilter: { ...req.filterableFields },
scope: req.scope,
fields: remapKeysForVariant(req.queryConfig.fields ?? []),
pagination: req.queryConfig.pagination,
})
if (withInventoryQuantity) {
await wrapVariantsWithTotalInventoryQuantity(req, variants || [])
@@ -17,12 +17,12 @@ export const GET = async (
) => {
const productId = req.params.id
const optionId = req.params.option_id
const productOption = await refetchEntity(
"product_option",
{ id: optionId, product_id: productId },
req.scope,
req.queryConfig.fields
)
const productOption = await refetchEntity({
entity: "product_option",
idOrFilter: { id: optionId, product_id: productId },
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ product_option: productOption })
}
@@ -45,12 +45,12 @@ export const POST = async (
},
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
@@ -67,12 +67,12 @@ export const DELETE = async (
input: { ids: [optionId] /* product_id: productId */ },
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({
id: optionId,
@@ -14,13 +14,13 @@ export const GET = async (
res: MedusaResponse<HttpTypes.AdminProductOptionListResponse>
) => {
const productId = req.params.id
const { rows: product_options, metadata } = await refetchEntities(
"product_option",
{ ...req.filterableFields, product_id: productId },
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: product_options, metadata } = await refetchEntities({
entity: "product_option",
idOrFilter: { ...req.filterableFields, product_id: productId },
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
product_options,
@@ -51,11 +51,11 @@ export const POST = async (
},
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
@@ -16,12 +16,12 @@ export const GET = async (
res: MedusaResponse<HttpTypes.AdminProductResponse>
) => {
const selectFields = remapKeysForProduct(req.queryConfig.fields ?? [])
const product = await refetchEntity(
"product",
req.params.id,
req.scope,
selectFields
)
const product = await refetchEntity({
entity: "product",
idOrFilter: req.params.id,
scope: req.scope,
fields: selectFields,
})
if (!product) {
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Product not found")
@@ -38,12 +38,12 @@ export const POST = async (
) => {
const { additional_data, ...update } = req.validatedBody
const existingProduct = await refetchEntity(
"product",
req.params.id,
req.scope,
["id"]
)
const existingProduct = await refetchEntity({
entity: "product",
idOrFilter: req.params.id,
scope: req.scope,
fields: ["id"],
})
/**
* Check if the product exists with the id or not before calling the workflow.
*/
@@ -62,12 +62,12 @@ export const POST = async (
},
})
const product = await refetchEntity(
"product",
result[0].id,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: result[0].id,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
@@ -24,12 +24,12 @@ export const GET = async (
const variantId = req.params.variant_id
const variables = { id: variantId, product_id: productId }
const variant = await refetchEntity(
"variant",
variables,
req.scope,
remapKeysForVariant(req.queryConfig.fields ?? [])
)
const variant = await refetchEntity({
entity: "variant",
idOrFilter: variables,
scope: req.scope,
fields: remapKeysForVariant(req.queryConfig.fields ?? []),
})
res.status(200).json({ variant: remapVariantResponse(variant) })
}
@@ -52,12 +52,12 @@ export const POST = async (
},
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
@@ -74,12 +74,12 @@ export const DELETE = async (
input: { ids: [variantId] /* product_id: productId */ },
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({
id: variantId,
@@ -29,13 +29,13 @@ export const GET = async (
)
}
const { rows: variants, metadata } = await refetchEntities(
"variant",
{ ...req.filterableFields, product_id: productId },
req.scope,
remapKeysForVariant(req.queryConfig.fields ?? []),
req.queryConfig.pagination
)
const { data: variants, metadata } = await refetchEntities({
entity: "variant",
idOrFilter: { ...req.filterableFields, product_id: productId },
scope: req.scope,
fields: remapKeysForVariant(req.queryConfig.fields ?? []),
pagination: req.queryConfig.pagination,
})
if (withInventoryQuantity) {
await wrapVariantsWithTotalInventoryQuantity(req, variants || [])
@@ -69,12 +69,12 @@ export const POST = async (
input: { product_variants: input, additional_data },
})
const product = await refetchEntity(
"product",
productId,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: productId,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
+14 -14
View File
@@ -41,14 +41,14 @@ async function getProducts(
) {
const selectFields = remapKeysForProduct(req.queryConfig.fields ?? [])
const { rows: products, metadata } = await refetchEntities(
"product",
req.filterableFields,
req.scope,
selectFields,
req.queryConfig.pagination,
req.queryConfig.withDeleted
)
const { data: products, metadata } = await refetchEntities({
entity: "product",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: selectFields,
pagination: req.queryConfig.pagination,
withDeleted: req.queryConfig.withDeleted,
})
res.json({
products: products.map(remapProductResponse),
@@ -103,12 +103,12 @@ export const POST = async (
input: { products: [products], additional_data },
})
const product = await refetchEntity(
"product",
result[0].id,
req.scope,
remapKeysForProduct(req.queryConfig.fields ?? [])
)
const product = await refetchEntity({
entity: "product",
idOrFilter: result[0].id,
scope: req.scope,
fields: remapKeysForProduct(req.queryConfig.fields ?? []),
})
res.status(200).json({ product: remapProductResponse(product) })
}
@@ -14,12 +14,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest,
res: MedusaResponse<RefundReasonResponse>
) => {
const refund_reason = await refetchEntity(
"refund_reason",
req.params.id,
req.scope,
req.queryConfig.fields
)
const refund_reason = await refetchEntity({
entity: "refund_reason",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.json({ refund_reason })
}
@@ -39,12 +39,12 @@ export const POST = async (
],
})
const refund_reason = await refetchEntity(
"refund_reason",
req.params.id,
req.scope,
req.queryConfig.fields
)
const refund_reason = await refetchEntity({
entity: "refund_reason",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.json({ refund_reason })
}
@@ -17,13 +17,13 @@ export const GET = async (
req: AuthenticatedMedusaRequest<HttpTypes.RefundReasonFilters>,
res: MedusaResponse<PaginatedResponse<RefundReasonsResponse>>
) => {
const { rows: refund_reasons, metadata } = await refetchEntities(
"refund_reasons",
req.filterableFields,
req.scope,
req.queryConfig.fields,
req.queryConfig.pagination
)
const { data: refund_reasons, metadata } = await refetchEntities({
entity: "refund_reasons",
idOrFilter: req.filterableFields,
scope: req.scope,
fields: req.queryConfig.fields,
pagination: req.queryConfig.pagination,
})
res.json({
refund_reasons,
@@ -43,12 +43,12 @@ export const POST = async (
input: { data: [req.validatedBody] },
})
const refund_reason = await refetchEntity(
"refund_reason",
refundReason.id,
req.scope,
req.queryConfig.fields
)
const refund_reason = await refetchEntity({
entity: "refund_reason",
idOrFilter: refundReason.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
res.status(200).json({ refund_reason })
}
@@ -18,12 +18,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest,
res: MedusaResponse<AdminReturnReasonResponse>
) => {
const return_reason = await refetchEntity(
"return_reason",
req.params.id,
req.scope,
req.queryConfig.fields
)
const return_reason = await refetchEntity({
entity: "return_reason",
idOrFilter: req.params.id,
scope: req.scope,
fields: req.queryConfig.fields,
})
if (!return_reason) {
throw new MedusaError(
@@ -6,5 +6,5 @@ export const refetchOrder = async (
scope: MedusaContainer,
fields: string[]
) => {
return await refetchEntity("order", idOrFilter, scope, fields)
return await refetchEntity({ entity: "order", idOrFilter, scope, fields })
}
@@ -9,5 +9,10 @@ export const refetchPaymentCollection = async (
scope: MedusaContainer,
fields: string[]
): Promise<PaymentCollectionDTO> => {
return refetchEntity("payment_collection", id, scope, fields)
return refetchEntity({
entity: "payment_collection",
idOrFilter: id,
scope,
fields,
})
}
@@ -11,12 +11,12 @@ export const GET = async (
req: AuthenticatedMedusaRequest<StoreProductCategoryParamsType>,
res: MedusaResponse<StoreProductCategoryResponse>
) => {
const category = await refetchEntity(
"product_category",
{ id: req.params.id, ...req.filterableFields },
req.scope,
req.queryConfig.fields
)
const category = await refetchEntity({
entity: "product_category",
idOrFilter: { id: req.params.id, ...req.filterableFields },
scope: req.scope,
fields: req.queryConfig.fields,
})
if (!category) {
throw new MedusaError(
@@ -1,5 +1,6 @@
import { isPresent, MedusaError } from "@medusajs/framework/utils"
import { MedusaResponse } from "@medusajs/framework/http"
import { HttpTypes } from "@medusajs/framework/types"
import { isPresent, MedusaError, QueryContext } from "@medusajs/framework/utils"
import { wrapVariantsWithInventoryQuantityForSalesChannel } from "../../../utils/middlewares"
import {
filterOutInternalProductCategories,
@@ -7,7 +8,6 @@ import {
RequestWithContext,
wrapProductsWithTaxPrices,
} from "../helpers"
import { HttpTypes } from "@medusajs/framework/types"
export const GET = async (
req: RequestWithContext<HttpTypes.StoreProductParams>,
@@ -29,9 +29,11 @@ export const GET = async (
}
if (isPresent(req.pricingContext)) {
filters["context"] = {
"variants.calculated_price": { context: req.pricingContext },
}
filters["context"] ??= {}
filters["context"]["variants"] ??= {}
filters["context"]["variants"]["calculated_price"] ??= QueryContext(
req.pricingContext!
)
}
const includesCategoriesField = req.queryConfig.fields.some((field) =>
@@ -25,7 +25,7 @@ export const refetchProduct = async (
scope: MedusaContainer,
fields: string[]
) => {
return await refetchEntity("product", idOrFilter, scope, fields)
return await refetchEntity({ entity: "product", idOrFilter, scope, fields })
}
export const filterOutInternalProductCategories = (
+37 -25
View File
@@ -5,7 +5,6 @@ import {
FeatureFlag,
isPresent,
QueryContext,
remoteQueryObjectFromString,
} from "@medusajs/framework/utils"
import IndexEngineFeatureFlag from "../../../feature-flags/index-engine"
import { wrapVariantsWithInventoryQuantityForSalesChannel } from "../../utils/middlewares"
@@ -49,7 +48,9 @@ async function getProductsWithIndexEngine(
if (isPresent(req.pricingContext)) {
context["variants"] ??= {}
context["variants"]["calculated_price"] = QueryContext(req.pricingContext!)
context["variants"]["calculated_price"] ??= QueryContext(
req.pricingContext!
)
}
const filters: Record<string, any> = req.filterableFields
@@ -62,13 +63,20 @@ async function getProductsWithIndexEngine(
delete filters.sales_channel_id
}
const { data: products = [], metadata } = await query.index({
entity: "product",
fields: req.queryConfig.fields,
filters,
pagination: req.queryConfig.pagination,
context,
})
const { data: products = [], metadata } = await query.index(
{
entity: "product",
fields: req.queryConfig.fields,
filters,
pagination: req.queryConfig.pagination,
context,
},
{
cache: {
enable: true,
},
}
)
if (withInventoryQuantity) {
await wrapVariantsWithInventoryQuantityForSalesChannel(
@@ -91,7 +99,7 @@ async function getProducts(
req: RequestWithContext<HttpTypes.StoreProductListParams>,
res: MedusaResponse<HttpTypes.StoreProductListResponse>
) {
const remoteQuery = req.scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY)
const query = req.scope.resolve(ContainerRegistrationKeys.QUERY)
const context: object = {}
const withInventoryQuantity = req.queryConfig.fields.some((field) =>
field.includes("variants.inventory_quantity")
@@ -104,22 +112,26 @@ async function getProducts(
}
if (isPresent(req.pricingContext)) {
context["variants.calculated_price"] = {
context: req.pricingContext,
}
context["variants"] ??= {}
context["variants"]["calculated_price"] ??= QueryContext(
req.pricingContext!
)
}
const queryObject = remoteQueryObjectFromString({
entryPoint: "product",
variables: {
const { data: products = [], metadata } = await query.graph(
{
entity: "product",
fields: req.queryConfig.fields,
filters: req.filterableFields,
...req.queryConfig.pagination,
...context,
pagination: req.queryConfig.pagination,
context,
},
fields: req.queryConfig.fields,
})
const { rows: products, metadata } = await remoteQuery(queryObject)
{
cache: {
enable: true,
},
}
)
if (withInventoryQuantity) {
await wrapVariantsWithInventoryQuantityForSalesChannel(
@@ -131,8 +143,8 @@ async function getProducts(
await wrapProductsWithTaxPrices(req, products)
res.json({
products,
count: metadata.count,
offset: metadata.skip,
limit: metadata.take,
count: metadata!.count,
offset: metadata!.skip,
limit: metadata!.take,
})
}
@@ -39,12 +39,12 @@ export function normalizeDataForContext() {
// If the cart is passed, get the information from it
if (req.filterableFields.cart_id) {
const cart = await refetchEntity(
"cart",
req.filterableFields.cart_id,
req.scope,
["region_id", "shipping_address.*"]
)
const cart = await refetchEntity({
entity: "cart",
idOrFilter: req.filterableFields.cart_id,
scope: req.scope,
fields: ["region_id", "shipping_address.*"],
})
if (cart?.region_id) {
regionId = cart.region_id
@@ -58,9 +58,16 @@ export function normalizeDataForContext() {
// Finally, try to get it from the store defaults if not available
if (!regionId) {
const stores = await refetchEntities("store", {}, req.scope, [
"default_region_id",
])
const stores = await refetchEntities({
entity: "store",
scope: req.scope,
fields: ["id", "default_region_id"],
options: {
cache: {
enable: true,
},
},
})
regionId = stores[0]?.default_region_id
}
@@ -17,12 +17,17 @@ export function setPricingContext() {
}
// We validate the region ID in the previous middleware
const region = await refetchEntity(
"region",
req.filterableFields.region_id!,
req.scope,
["id", "currency_code"]
)
const region = await refetchEntity({
entity: "region",
idOrFilter: req.filterableFields.region_id!,
scope: req.scope,
fields: ["id", "currency_code"],
options: {
cache: {
enable: true,
},
},
})
if (!region) {
try {
@@ -42,12 +47,12 @@ export function setPricingContext() {
// Find all the customer groups the customer is a part of and set
if (req.auth_context?.actor_id) {
const customerGroups = await refetchEntities(
"customer_group",
{ customers: { id: req.auth_context.actor_id } },
req.scope,
["id"]
)
const { data: customerGroups } = await refetchEntities({
entity: "customer_group",
idOrFilter: { customers: { id: req.auth_context.actor_id } },
scope: req.scope,
fields: ["id"],
})
pricingContext.customer = { groups: [] }
customerGroups.map((cg) =>
@@ -38,12 +38,12 @@ export function setTaxContext() {
}
const getTaxInclusivityInfo = async (req: MedusaRequest) => {
const region = await refetchEntity(
"region",
req.filterableFields.region_id as string,
req.scope,
["automatic_taxes"]
)
const region = await refetchEntity({
entity: "region",
idOrFilter: req.filterableFields.region_id as string,
scope: req.scope,
fields: ["automatic_taxes"],
})
if (!region) {
throw new MedusaError(
@@ -0,0 +1,10 @@
import { FlagSettings } from "@medusajs/framework/feature-flags"
const CachingFeatureFlag: FlagSettings = {
key: "caching",
default_val: false,
env_key: "MEDUSA_FF_CACHING",
description: "[WIP] Enable core caching where applicable",
}
export default CachingFeatureFlag
+116 -4
View File
@@ -1,4 +1,3 @@
import { snakeCase } from "lodash"
import {
MedusaNextFunction,
MedusaRequest,
@@ -6,10 +5,15 @@ import {
Query,
} from "@medusajs/framework"
import { ApiLoader } from "@medusajs/framework/http"
import { Tracer } from "@medusajs/framework/telemetry"
import type { SpanExporter } from "@opentelemetry/sdk-trace-node"
import type { NodeSDKConfiguration } from "@opentelemetry/sdk-node"
import { TransactionOrchestrator } from "@medusajs/framework/orchestration"
import { Tracer } from "@medusajs/framework/telemetry"
import { FeatureFlag } from "@medusajs/framework/utils"
import { SpanStatusCode } from "@opentelemetry/api"
import type { NodeSDKConfiguration } from "@opentelemetry/sdk-node"
import type { SpanExporter } from "@opentelemetry/sdk-trace-node"
import { snakeCase } from "lodash"
import CacheModule from "../modules/caching"
import { ICachingModuleService } from "@medusajs/framework/types"
const EXCLUDED_RESOURCES = [".vite", "virtual:"]
@@ -261,6 +265,110 @@ export function instrumentWorkflows() {
}
}
export function instrumentCache() {
if (!FeatureFlag.isFeatureEnabled("caching")) {
return
}
const CacheTracer = new Tracer("@medusajs/caching", "2.0.0")
const cacheModule_ = CacheModule as unknown as {
service: ICachingModuleService & {
traceGet: (
cacheGetFn: () => Promise<any>,
key: string,
tags: string[]
) => Promise<any>
traceSet: (
cacheSetFn: () => Promise<any>,
key: string,
tags: string[],
options: { autoInvalidate?: boolean }
) => Promise<any>
traceClear: (
cacheClearFn: () => Promise<any>,
key: string,
tags: string[],
options: { autoInvalidate?: boolean }
) => Promise<any>
}
}
cacheModule_.service.traceGet = async function (cacheGetFn, key, tags) {
return await CacheTracer.trace(`cache.get`, async (span) => {
span.setAttributes({
"cache.key": key,
"cache.tags": tags,
})
try {
return await cacheGetFn()
} catch (error) {
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
})
throw error
} finally {
span.end()
}
})
}
cacheModule_.service.traceSet = async function (
cacheSetFn,
key,
tags,
options = {}
) {
return await CacheTracer.trace(`cache.set`, async (span) => {
span.setAttributes({
"cache.key": key,
"cache.tags": tags,
"cache.options": JSON.stringify(options),
})
try {
return await cacheSetFn()
} catch (error) {
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
})
throw error
} finally {
span.end()
}
})
}
cacheModule_.service.traceClear = async function (
cacheClearFn,
key,
tags,
options = {}
) {
return await CacheTracer.trace(`cache.clear`, async (span) => {
span.setAttributes({
"cache.key": key,
"cache.tags": tags,
"cache.options": JSON.stringify(options),
})
try {
return await cacheClearFn()
} catch (error) {
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
})
throw error
} finally {
span.end()
}
})
}
}
/**
* A helper function to configure the OpenTelemetry SDK with some defaults.
* For better/more control, please configure the SDK manually.
@@ -283,6 +391,7 @@ export function registerOtel(
query: boolean
workflows: boolean
db: boolean
cache: boolean
}>
}
) {
@@ -318,6 +427,9 @@ export function registerOtel(
if (instrument.workflows) {
instrumentWorkflows()
}
if (instrument.cache) {
instrumentCache()
}
const sdk = new NodeSDK({
serviceName,
@@ -0,0 +1,6 @@
import RedisCachingProvider from "@medusajs/caching-redis"
export * from "@medusajs/caching-redis"
export default RedisCachingProvider
export const discoveryPath = require.resolve("@medusajs/caching-redis")
+6
View File
@@ -0,0 +1,6 @@
import CacheModule from "@medusajs/caching"
export * from "@medusajs/caching"
export default CacheModule
export const discoveryPath = require.resolve("@medusajs/caching")
+6
View File
@@ -0,0 +1,6 @@
/dist
node_modules
.DS_store
.env*
.env
*.sql
+1
View File
@@ -0,0 +1 @@
# @medusajs/caching
@@ -0,0 +1,51 @@
import {
EventBusTypes,
IEventBusModuleService,
Message,
Subscriber,
} from "@medusajs/types"
export class EventBusServiceMock implements IEventBusModuleService {
protected readonly subscribers_: Map<string | symbol, Set<Subscriber>> =
new Map()
async emit<T>(
messages: Message<T> | Message<T>[],
options?: Record<string, unknown>
): Promise<void> {
const messages_ = Array.isArray(messages) ? messages : [messages]
for (const message of messages_) {
const subscribers = this.subscribers_.get(message.name)
const starSubscribers = this.subscribers_.get("*")
for (const subscriber of [
...(subscribers ?? []),
...(starSubscribers ?? []),
]) {
const { options, ...payload } = message
await subscriber(payload)
}
}
}
subscribe(event: string | symbol, subscriber: Subscriber): this {
this.subscribers_.set(event, new Set([subscriber]))
return this
}
unsubscribe(
event: string | symbol,
subscriber: Subscriber,
context?: EventBusTypes.SubscriberContext
): this {
return this
}
releaseGroupedEvents(eventGroupId: string): Promise<void> {
throw new Error("Method not implemented.")
}
clearGroupedEvents(eventGroupId: string): Promise<void> {
throw new Error("Method not implemented.")
}
}
@@ -0,0 +1,336 @@
import { Modules } from "@medusajs/framework/utils"
import { moduleIntegrationTestRunner } from "@medusajs/test-utils"
import { ICachingModuleService } from "@medusajs/framework/types"
import { MedusaModule } from "@medusajs/framework/modules-sdk"
jest.setTimeout(10000)
jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([
{
schema: `
type Product {
id: ID
title: String
handle: String
status: String
type_id: String
collection_id: String
is_giftcard: Boolean
external_id: String
created_at: DateTime
updated_at: DateTime
variants: [ProductVariant]
sales_channels: [SalesChannel]
}
type ProductVariant {
id: ID
product_id: String
sku: String
prices: [Price]
}
type Price {
id: ID
amount: Float
currency_code: String
}
type SalesChannel {
id: ID
is_disabled: Boolean
}
`,
},
])
moduleIntegrationTestRunner<ICachingModuleService>({
moduleName: Modules.CACHING,
testSuite: ({ service }) => {
describe("Caching Module Service", () => {
beforeEach(async () => {
await service.clear({ tags: ["*"] }).catch(() => {})
})
describe("Basic Cache Operations", () => {
it("should set and get cache data with default memory provider", async () => {
const testData = { id: "test-id", name: "Test Item" }
await service.set({
key: "test-key",
data: testData,
ttl: 3600,
})
const result = await service.get({ key: "test-key" })
expect(result).toEqual(testData)
})
it("should return null for non-existent keys", async () => {
const result = await service.get({ key: "non-existent" })
expect(result).toBeNull()
})
it("should handle tags-based storage and retrieval", async () => {
const testData1 = { id: "1", name: "Item 1" }
const testData2 = { id: "2", name: "Item 2" }
await service.set({
key: "item-1",
data: testData1,
tags: ["product", "active"],
})
await service.set({
key: "item-2",
data: testData2,
tags: ["product", "inactive"],
})
const productResults = await service.get({ tags: ["product"] })
expect(productResults).toHaveLength(2)
expect(productResults).toContainEqual(testData1)
expect(productResults).toContainEqual(testData2)
const activeResults = await service.get<any[]>({ tags: ["active"] })
expect(activeResults).toHaveLength(1)
expect(activeResults?.[0]).toEqual(testData1)
})
it("should clear cache by key", async () => {
await service.set({
key: "test-key",
data: { value: "test" },
})
await service.clear({ key: "test-key" })
const result = await service.get({ key: "test-key" })
expect(result).toBeNull()
})
it("should clear cache by tags", async () => {
await service.set({
key: "item-1",
data: { id: "1" },
tags: ["category-a"],
})
await service.set({
key: "item-2",
data: { id: "2" },
tags: ["category-b"],
})
await service.clear({ tags: ["category-a"] })
const result1 = await service.get({ key: "item-1" })
const result2 = await service.get({ key: "item-2" })
expect(result1).toBeNull()
expect(result2).toEqual({ id: "2" })
})
})
describe("Provider Priority", () => {
it("should check providers in order of priority when specified", async () => {
const testData = { id: "priority-test", name: "Priority Test" }
await service.set({
key: "priority-key",
data: testData,
providers: ["cache-memory"],
})
const result = await service.get({
key: "priority-key",
providers: ["cache-memory"],
})
expect(result).toEqual(testData)
})
it("should return null when providers array is empty or invalid", async () => {
const result = await service.get({
key: "test-key",
providers: [],
})
expect(result).toBeNull()
})
})
describe("Promise Deduplication", () => {
it("should deduplicate concurrent get requests with same parameters", async () => {
const testData = { id: "concurrent-test", name: "Concurrent Test" }
await service.set({
key: "concurrent-key",
data: testData,
})
const promises = Array.from({ length: 5 }, () =>
service.get<any>({ key: "concurrent-key" })
)
const results = await Promise.all(promises)
results.forEach((result) => {
expect(result).toEqual(testData)
})
})
it("should deduplicate concurrent get requests with same tags", async () => {
const testData = { id: "tag-test", name: "Tag Test" }
await service.set({
key: "tag-key",
data: testData,
tags: ["concurrent-tag"],
})
const promises = Array.from({ length: 5 }, () =>
service.get<any[]>({ tags: ["concurrent-tag"] })
)
const results = await Promise.all(promises)
results.forEach((result) => {
expect(result).toHaveLength(1)
expect(result?.[0]).toEqual(testData)
})
})
it("should deduplicate concurrent clear requests", async () => {
await service.set({
key: "clear-test-1",
data: { id: "1" },
tags: ["clear-tag"],
})
await service.set({
key: "clear-test-2",
data: { id: "2" },
tags: ["clear-tag"],
})
const promises = Array.from({ length: 3 }, () =>
service.clear({ tags: ["clear-tag"] })
)
await Promise.all(promises)
const result1 = await service.get({ key: "clear-test-1" })
const result2 = await service.get({ key: "clear-test-2" })
expect(result1).toBeNull()
expect(result2).toBeNull()
})
it("should handle concurrent requests with different parameters separately", async () => {
const testData1 = { id: "1", name: "Item 1" }
const testData2 = { id: "2", name: "Item 2" }
await service.set({ key: "key-1", data: testData1 })
await service.set({ key: "key-2", data: testData2 })
const promises = [
service.get({ key: "key-1" }),
service.get({ key: "key-1" }),
service.get({ key: "key-2" }),
service.get({ key: "key-2" }),
]
const results = await Promise.all(promises)
expect(results[0]).toEqual(testData1)
expect(results[1]).toEqual(testData1)
expect(results[2]).toEqual(testData2)
expect(results[3]).toEqual(testData2)
})
})
describe("Memory Cache Provider Integration", () => {
it("should respect TTL settings", async () => {
const testData = { id: "ttl-test", name: "TTL Test" }
await service.set({
key: "ttl-key",
data: testData,
ttl: 1,
})
let result = await service.get({ key: "ttl-key" })
expect(result).toEqual(testData)
await new Promise((resolve) => setTimeout(resolve, 1100))
result = await service.get({ key: "ttl-key" })
expect(result).toBeNull()
})
it("should handle autoInvalidate option", async () => {
const testData = { id: "no-auto-test", name: "No Auto Test" }
await service.set({
key: "no-auto-key",
data: testData,
tags: ["no-auto-tag"],
options: { autoInvalidate: false },
})
await service.clear({
tags: ["no-auto-tag"],
options: { autoInvalidate: true },
})
const result = await service.get({ key: "no-auto-key" })
expect(result).toEqual(testData)
await service.clear({
tags: ["no-auto-tag"],
})
const result2 = await service.get({ key: "no-auto-key" })
expect(result2).toBeNull()
})
it("should generate consistent cache keys", async () => {
const testInput = { userId: "123", action: "view" }
const key1 = await service.computeKey(testInput)
const key2 = await service.computeKey(testInput)
expect(key1).toBe(key2)
expect(typeof key1).toBe("string")
expect(key1.length).toBeGreaterThan(0)
})
it("should generate cache tags", async () => {
const testInput = { id: "prod_1", title: "123", description: "456" }
const tags = await service.computeTags(testInput)
expect(Array.isArray(tags)).toBe(true)
expect(tags.length).toBeGreaterThan(0)
})
})
describe("Error Handling", () => {
it("should throw error when neither key nor tags provided to get", async () => {
await expect(service.get({})).rejects.toThrow(
"Either key or tags must be provided"
)
})
it("should throw error when neither key nor tags provided to clear", async () => {
await expect(service.clear({})).rejects.toThrow(
"Either key or tags must be provided"
)
})
})
})
},
})
@@ -0,0 +1,430 @@
import { Modules } from "@medusajs/framework/utils"
import { moduleIntegrationTestRunner } from "@medusajs/test-utils"
import { ICachingModuleService } from "@medusajs/framework/types"
import { MedusaModule } from "@medusajs/framework/modules-sdk"
import { EventBusServiceMock } from "../__fixtures__/event-bus-mock"
jest.setTimeout(30000)
jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([
{
schema: `
type Product {
id: ID
title: String
handle: String
status: String
type_id: String
collection_id: String
is_giftcard: Boolean
external_id: String
created_at: DateTime
updated_at: DateTime
variants: [ProductVariant]
sales_channels: [SalesChannel]
}
type ProductVariant {
id: ID
product_id: String
sku: String
prices: [Price]
}
type Price {
id: ID
amount: Float
currency_code: String
variant_id: String
}
type SalesChannel {
id: ID
is_disabled: Boolean
}
type ProductCollection {
id: ID
title: String
handle: String
}
`,
},
])
const mockEventBus = new EventBusServiceMock()
moduleIntegrationTestRunner<ICachingModuleService>({
moduleName: Modules.CACHING,
injectedDependencies: {
[Modules.EVENT_BUS]: mockEventBus,
},
testSuite: ({ service }) => {
describe("Cache Invalidation with Entity Relationships", () => {
afterEach(async () => {
await service.clear({ tags: ["*"] }).catch(() => {})
})
describe("Single Entity Caching", () => {
it("should cache and retrieve a single product entity using computed keys", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
status: "published",
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
})
const cachedProduct = await service.get({ key: productKey })
expect(cachedProduct).toEqual(product)
})
it("should auto-invalidate single entity when strategy clears computed tags", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
})
await mockEventBus.emit(
[{ name: "product.updated", data: { id: product.id } }],
{}
)
const result = await service.get({ key: productKey })
expect(result).toBeNull()
})
it("should not auto-invalidate single entity with autoInvalidate=false", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
options: { autoInvalidate: false },
})
await mockEventBus.emit(
[{ name: "product.updated", data: { id: product.id } }],
{}
)
const result = await service.get({ key: productKey })
expect(result).toEqual(product)
})
})
describe("Entity List Caching", () => {
it("should cache and retrieve lists of entities using computed keys", async () => {
const publishedProductsQuery = {
entity: "product",
filters: { status: "published" },
fields: ["id", "title", "status"],
}
const allProductsQuery = {
entity: "product",
filters: {},
fields: ["id", "title", "status"],
}
const publishedProducts = [
{ id: "prod_1", title: "Product 1", status: "published" },
{ id: "prod_2", title: "Product 2", status: "published" },
]
const allProducts = [
...publishedProducts,
{ id: "prod_3", title: "Product 3", status: "draft" },
]
const publishedProductsKey = await service.computeKey(
publishedProductsQuery
)
const allProductsKey = await service.computeKey(allProductsQuery)
await service.set({
key: publishedProductsKey,
data: publishedProducts,
})
await service.set({
key: allProductsKey,
data: allProducts,
})
const cachedPublished = await service.get({
key: publishedProductsKey,
})
const cachedAll = await service.get({ key: allProductsKey })
expect(cachedPublished).toEqual(publishedProducts)
expect(cachedAll).toEqual(allProducts)
})
it("should invalidate related lists when individual product is updated", async () => {
const listQuery = {
entity: "product",
filters: { status: "published" },
includes: ["id", "title"],
}
const products = [
{ id: "prod_1", title: "Product 1", status: "published" },
{ id: "prod_2", title: "Product 2", status: "published" },
]
const listKey = await service.computeKey(listQuery)
await service.set({
key: listKey,
data: products,
})
await mockEventBus.emit(
[
{
name: "product.updated",
data: { id: "prod_1", title: "Updated Product 1" },
},
],
{}
)
const cachedList = await service.get({ key: listKey })
expect(cachedList).toBeNull()
})
})
describe("Nested Entity Caching", () => {
it("should cache products with nested variants and prices using computed keys", async () => {
const productWithVariants = {
id: "prod_1",
title: "Complex Product",
variants: [
{
id: "var_1",
product_id: "prod_1",
sku: "SKU-001",
prices: [
{
id: "price_1",
variant_id: "var_1",
amount: 1000,
currency_code: "USD",
},
{
id: "price_2",
variant_id: "var_1",
amount: 900,
currency_code: "EUR",
},
],
},
{
id: "var_2",
product_id: "prod_1",
sku: "SKU-002",
prices: [
{
id: "price_3",
variant_id: "var_2",
amount: 1200,
currency_code: "USD",
},
],
},
],
}
const productKey = await service.computeKey(productWithVariants)
await service.set({
key: productKey,
data: productWithVariants,
})
const cached = await service.get<typeof productWithVariants>({
key: productKey,
})
expect(cached).toEqual(productWithVariants)
expect(cached!.variants).toHaveLength(2)
expect(cached!.variants[0].prices).toHaveLength(2)
})
it("should invalidate nested product when related variant is updated", async () => {
const productWithVariants = {
id: "prod_1",
title: "Complex Product",
variants: [{ id: "var_1", product_id: "prod_1", sku: "SKU-001" }],
}
const productKey = await service.computeKey(productWithVariants)
await service.set({
key: productKey,
data: productWithVariants,
})
await mockEventBus.emit(
[
{
name: "product_variant.updated",
data: {
id: "var_1",
product_id: "prod_1",
sku: "SKU-001-UPDATED",
},
},
],
{}
)
const result = await service.get({ key: productKey })
expect(result).toBeNull()
})
it("should handle price updates affecting variant and product caches", async () => {
const price = {
id: "price_1",
variant_id: "var_1",
amount: 1000,
currency_code: "USD",
}
const variant = {
id: "var_1",
product_id: "prod_1",
sku: "SKU-001",
prices: [price],
}
const product = {
id: "prod_1",
title: "Product",
variants: [variant],
}
const priceKey = await service.computeKey(price)
const variantKey = await service.computeKey(variant)
const productKey = await service.computeKey(product)
await service.set({
key: priceKey,
data: price,
})
await service.set({
key: variantKey,
data: variant,
})
await service.set({
key: productKey,
data: product,
})
await mockEventBus.emit(
[
{
name: "price.updated",
data: { id: "price_1", variant_id: "var_1", amount: 1100 },
},
],
{}
)
const cachedPrice = await service.get({ key: priceKey })
const cachedVariant = await service.get({ key: variantKey })
const cachedProduct = await service.get({ key: productKey })
expect(cachedPrice).toBeNull()
expect(cachedVariant).toBeNull()
expect(cachedProduct).toBeNull()
})
})
describe("Complex Query Caching", () => {
it("should cache complex queries and invalidate based on entity relationships", async () => {
const complexQuery = {
entity: "product",
filters: { status: "published", collection_id: "col_1" },
includes: {
variants: {
include: {
prices: true,
},
},
collection: true,
},
pagination: { limit: 10, offset: 0 },
}
const queryResult = {
data: [
{
id: "prod_1",
title: "Product 1",
collection_id: "col_1",
variants: [
{
id: "var_1",
product_id: "prod_1",
prices: [{ id: "price_1", amount: 1000 }],
},
],
collection: { id: "col_1", title: "Collection 1" },
},
],
pagination: { total: 1, limit: 10, offset: 0 },
}
const queryKey = await service.computeKey(complexQuery)
await service.set({
key: queryKey,
data: queryResult,
})
const cached = await service.get({ key: queryKey })
expect(cached).toEqual(queryResult)
await mockEventBus.emit(
[
{
name: "price.updated",
data: { id: "price_1", amount: 1100 },
},
],
{}
)
const cachedAfterUpdate = await service.get({ key: queryKey })
expect(cachedAfterUpdate).toBeNull()
})
})
})
},
})
@@ -0,0 +1,540 @@
import { MedusaModule } from "@medusajs/framework/modules-sdk"
import { ICachingModuleService } from "@medusajs/framework/types"
import { Modules } from "@medusajs/framework/utils"
import { moduleIntegrationTestRunner } from "@medusajs/test-utils"
import { setTimeout } from "timers/promises"
import { EventBusServiceMock } from "../../__fixtures__/event-bus-mock"
jest.setTimeout(300000)
jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([
{
schema: `
type Product {
id: ID
title: String
handle: String
status: String
type_id: String
collection_id: String
is_giftcard: Boolean
external_id: String
created_at: DateTime
updated_at: DateTime
variants: [ProductVariant]
sales_channels: [SalesChannel]
}
type ProductVariant {
id: ID
product_id: String
sku: String
prices: [Price]
}
type Price {
id: ID
amount: Float
currency_code: String
variant_id: String
}
type SalesChannel {
id: ID
is_disabled: Boolean
}
type ProductCollection {
id: ID
title: String
handle: String
}
`,
},
])
const DEFAULT_WAIT_INTERVAL = 50
const mockEventBus = new EventBusServiceMock()
moduleIntegrationTestRunner<ICachingModuleService>({
moduleName: Modules.CACHING,
injectedDependencies: {
[Modules.EVENT_BUS]: mockEventBus,
},
moduleOptions: {
providers: [
{
id: "cache-redis",
resolve: require.resolve("../../../../providers/caching-redis/src"),
is_default: true,
options: {
redisUrl: "localhost:6379",
},
},
],
},
testSuite: ({ service }) => {
describe("Cache Invalidation with Entity Relationships", () => {
afterEach(async () => {
await service.clear({ tags: ["*"] }).catch(() => {})
})
describe("Single Entity Caching", () => {
it("should cache and retrieve a single product entity using computed keys", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
status: "published",
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cachedProduct = await service.get({ key: productKey })
expect(cachedProduct).toEqual(product)
})
it("should auto-invalidate single entity when strategy clears computed tags", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
await mockEventBus.emit(
[{ name: "product.updated", data: { id: product.id } }],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const result = await service.get({ key: productKey })
expect(result).toBeNull()
})
it("should not auto-invalidate single entity with autoInvalidate=false", async () => {
const product = {
id: "prod_1",
title: "Test Product",
handle: "test-product",
}
const productKey = await service.computeKey(product)
await service.set({
key: productKey,
data: product,
options: { autoInvalidate: false },
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
await mockEventBus.emit(
[{ name: "product.updated", data: { id: product.id } }],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const result = await service.get({ key: productKey })
expect(result).toEqual(product)
})
})
describe("Entity List Caching", () => {
it("should cache and retrieve lists of entities using computed keys", async () => {
const publishedProductsQuery = {
entity: "product",
filters: { status: "published" },
fields: ["id", "title", "status"],
}
const allProductsQuery = {
entity: "product",
filters: {},
fields: ["id", "title", "status"],
}
const publishedProducts = [
{ id: "prod_1", title: "Product 1", status: "published" },
{ id: "prod_2", title: "Product 2", status: "published" },
]
const allProducts = [
...publishedProducts,
{ id: "prod_3", title: "Product 3", status: "draft" },
]
const publishedProductsKey = await service.computeKey(
publishedProductsQuery
)
const allProductsKey = await service.computeKey(allProductsQuery)
await service.set({
key: publishedProductsKey,
data: publishedProducts,
})
await service.set({
key: allProductsKey,
data: allProducts,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cachedPublished = await service.get({
key: publishedProductsKey,
})
const cachedAll = await service.get({ key: allProductsKey })
expect(cachedPublished).toEqual(publishedProducts)
expect(cachedAll).toEqual(allProducts)
})
it("should invalidate related lists when individual product is updated", async () => {
const listQuery = {
entity: "product",
filters: { status: "published" },
includes: ["id", "title"],
}
const products = [
{ id: "prod_1", title: "Product 1", status: "published" },
{ id: "prod_2", title: "Product 2", status: "published" },
]
const listKey = await service.computeKey(listQuery)
await service.set({
key: listKey,
data: products,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
await mockEventBus.emit(
[
{
name: "product.updated",
data: { id: "prod_1", title: "Updated Product 1" },
},
],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cachedList = await service.get({ key: listKey })
expect(cachedList).toBeNull()
})
})
describe("Nested Entity Caching", () => {
it("should cache products with nested variants and prices using computed keys", async () => {
const productWithVariants = {
id: "prod_1",
title: "Complex Product",
variants: [
{
id: "var_1",
product_id: "prod_1",
sku: "SKU-001",
prices: [
{
id: "price_1",
variant_id: "var_1",
amount: 1000,
currency_code: "USD",
},
{
id: "price_2",
variant_id: "var_1",
amount: 900,
currency_code: "EUR",
},
],
},
{
id: "var_2",
product_id: "prod_1",
sku: "SKU-002",
prices: [
{
id: "price_3",
variant_id: "var_2",
amount: 1500,
currency_code: "USD",
},
],
},
],
}
const productKey = await service.computeKey(productWithVariants)
await service.set({
key: productKey,
data: productWithVariants,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cached = await service.get<typeof productWithVariants>({
key: productKey,
})
expect(cached).toEqual(productWithVariants)
expect(cached!.variants).toHaveLength(2)
expect(cached!.variants[0].prices).toHaveLength(2)
})
it("should invalidate nested product when related variant is updated", async () => {
const productWithVariants = {
id: "prod_1",
title: "Complex Product",
variants: [{ id: "var_1", product_id: "prod_1", sku: "SKU-001" }],
}
const productKey = await service.computeKey(productWithVariants)
await service.set({
key: productKey,
data: productWithVariants,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
await mockEventBus.emit(
[
{
name: "product_variant.updated",
data: {
id: "var_1",
product_id: "prod_1",
sku: "SKU-001-UPDATED",
},
},
],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const result = await service.get({ key: productKey })
expect(result).toBeNull()
})
it("should handle price updates affecting variant and product caches", async () => {
const price = {
id: "price_1",
variant_id: "var_1",
amount: 1000,
currency_code: "USD",
}
const variant = {
id: "var_1",
product_id: "prod_1",
sku: "SKU-001",
prices: [price],
}
const product = {
id: "prod_1",
title: "Product",
variants: [variant],
}
const priceKey = await service.computeKey(price)
const variantKey = await service.computeKey(variant)
const productKey = await service.computeKey(product)
await service.set({
key: priceKey,
data: price,
})
await service.set({
key: variantKey,
data: variant,
})
await service.set({
key: productKey,
data: product,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
await mockEventBus.emit(
[
{
name: "price.updated",
data: { id: "price_1", variant_id: "var_1", amount: 1100 },
},
],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cachedPrice = await service.get({ key: priceKey })
const cachedVariant = await service.get({ key: variantKey })
const cachedProduct = await service.get({ key: productKey })
expect(cachedPrice).toBeNull()
expect(cachedVariant).toBeNull()
expect(cachedProduct).toBeNull()
})
})
describe("Complex Query Caching", () => {
it("should cache complex queries and invalidate based on entity relationships", async () => {
const complexQuery = {
entity: "product",
filters: { status: "published", collection_id: "col_1" },
includes: {
variants: {
include: {
prices: true,
},
},
collection: true,
},
pagination: { limit: 10, offset: 0 },
}
const queryResult = {
data: [
{
id: "prod_1",
title: "Product 1",
collection_id: "col_1",
variants: [
{
id: "var_1",
product_id: "prod_1",
prices: [{ id: "price_1", amount: 1000 }],
},
],
collection: { id: "col_1", title: "Collection 1" },
},
],
pagination: { total: 1, limit: 10, offset: 0 },
}
const queryKey = await service.computeKey(complexQuery)
await service.set({
key: queryKey,
data: queryResult,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cached = await service.get({ key: queryKey })
expect(cached).toEqual(queryResult)
await mockEventBus.emit(
[
{
name: "price.updated",
data: { id: "price_1", amount: 1100 },
},
],
{}
)
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cachedAfterUpdate = await service.get({ key: queryKey })
expect(cachedAfterUpdate).toBeNull()
})
})
it("should cache complex queries and return the correct cached data", async () => {
const complexQuery = {
entity: "product",
filters: { status: "published", collection_id: "col_1" },
includes: {
variants: {
include: {
prices: true,
},
},
collection: true,
},
pagination: { limit: 10, offset: 0 },
}
const queryResult = {
data: [
{
id: "prod_1",
title: "Product 1",
collection_id: "col_1",
variants: [
{
id: "var_1",
product_id: "prod_1",
prices: [{ id: "price_1", amount: 1000 }],
},
],
collection: { id: "col_1", title: "Collection 1" },
},
],
pagination: { total: 1, limit: 10, offset: 0 },
}
const queryKey = await service.computeKey(complexQuery)
await service.set({
key: queryKey,
data: queryResult,
})
await setTimeout(DEFAULT_WAIT_INTERVAL)
const cached = await service.get({ key: queryKey })
expect(cached).toEqual(queryResult)
const complexQueryOffset = {
entity: "product",
filters: { status: "published", collection_id: "col_1" },
includes: {
variants: {
include: {
prices: true,
},
},
collection: true,
},
pagination: { limit: 10, offset: 10 },
}
const queryKeyOffset = await service.computeKey(complexQueryOffset)
const cachedOffset = await service.get({ key: queryKeyOffset })
expect(cachedOffset).toEqual(null)
})
})
},
})
+8
View File
@@ -0,0 +1,8 @@
const defineJestConfig = require("../../../define_jest_config")
module.exports = defineJestConfig({
moduleNameMapper: {
"^@services": "<rootDir>/src/services",
"^@types": "<rootDir>/src/types",
"^@utils": "<rootDir>/src/utils",
},
})
+49
View File
@@ -0,0 +1,49 @@
{
"name": "@medusajs/caching",
"version": "2.10.3",
"description": "Caching Module for Medusa",
"main": "dist/index.js",
"repository": {
"type": "git",
"url": "https://github.com/medusajs/medusa",
"directory": "packages/modules/caching"
},
"files": [
"dist",
"!dist/**/__tests__",
"!dist/**/__mocks__",
"!dist/**/__fixtures__"
],
"publishConfig": {
"access": "public"
},
"author": "Medusa",
"license": "MIT",
"scripts": {
"watch": "tsc --build --watch",
"watch:test": "tsc --build tsconfig.spec.json --watch",
"resolve:aliases": "tsc --showConfig -p tsconfig.json > tsconfig.resolved.json && tsc-alias -p tsconfig.resolved.json && rimraf tsconfig.resolved.json",
"build": "rimraf dist && tsc --build && npm run resolve:aliases",
"test": "jest --passWithNoTests --runInBand --bail --forceExit -- src/",
"test:integration": "jest --runInBand --forceExit -- integration-tests/__tests__/**/*.ts"
},
"devDependencies": {
"@medusajs/framework": "2.10.3",
"@medusajs/test-utils": "2.10.3",
"@swc/core": "^1.7.28",
"@swc/jest": "^0.2.36",
"jest": "^29.7.0",
"rimraf": "^3.0.2",
"tsc-alias": "^1.8.6",
"typescript": "^5.6.2"
},
"peerDependencies": {
"@medusajs/framework": "2.10.3",
"awilix": "^8.0.1"
},
"dependencies": {
"fast-json-stable-stringify": "^2.1.0",
"node-cache": "^5.1.2",
"xxhash-wasm": "^1.1.0"
}
}
+12
View File
@@ -0,0 +1,12 @@
import { Module, Modules } from "@medusajs/framework/utils"
import { default as loadHash } from "./loaders/hash"
import { default as loadProviders } from "./loaders/providers"
import CachingModuleService from "./services/cache-module"
export default Module(Modules.CACHING, {
service: CachingModuleService,
loaders: [loadHash, loadProviders],
})
// Module options types
export { CachingModuleOptions } from "./types"
@@ -0,0 +1,8 @@
import { asValue } from "awilix"
export default async ({ container }) => {
const xxhashhWasm = await import("xxhash-wasm")
const { h32ToString } = await xxhashhWasm.default()
container.register("hasher", asValue(h32ToString))
}
@@ -0,0 +1,94 @@
import { moduleProviderLoader } from "@medusajs/framework/modules-sdk"
import { LoaderOptions, ModulesSdkTypes } from "@medusajs/framework/types"
import {
ContainerRegistrationKeys,
getProviderRegistrationKey,
} from "@medusajs/framework/utils"
import { CachingProviderService } from "@services"
import {
CachingDefaultProvider,
CachingIdentifiersRegistrationName,
CachingModuleOptions,
CachingProviderRegistrationPrefix,
} from "@types"
import { aliasTo, asFunction, asValue, Lifetime } from "awilix"
import { MemoryCachingProvider } from "../providers/memory-cache"
import { DefaultCacheStrategy } from "../utils/strategy"
const registrationFn = async (klass, container, { id }) => {
const key = CachingProviderService.getRegistrationIdentifier(klass)
if (!id) {
throw new Error(`No "id" provided for provider ${key}`)
}
const regKey = getProviderRegistrationKey({
providerId: id,
providerIdentifier: key,
})
container.register({
[CachingProviderRegistrationPrefix + id]: aliasTo(regKey),
})
container.registerAdd(CachingIdentifiersRegistrationName, asValue(key))
}
export default async ({
container,
options,
}: LoaderOptions<
(
| ModulesSdkTypes.ModuleServiceInitializeOptions
| ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions
) &
CachingModuleOptions
>): Promise<void> => {
container.registerAdd(CachingIdentifiersRegistrationName, asValue(undefined))
const strategy = DefaultCacheStrategy // Re enable custom strategy another time
container.register("strategy", asValue(strategy))
// MemoryCachingProvider - default provider
container.register({
[CachingProviderRegistrationPrefix + MemoryCachingProvider.identifier]:
asFunction(() => new MemoryCachingProvider(), {
lifetime: Lifetime.SINGLETON,
}),
})
container.registerAdd(
CachingIdentifiersRegistrationName,
asValue(MemoryCachingProvider.identifier)
)
container.register(
CachingDefaultProvider,
asValue(MemoryCachingProvider.identifier)
)
// Load other providers
await moduleProviderLoader({
container,
providers: options?.providers || [],
registerServiceFn: registrationFn,
})
const isSingleProvider = options?.providers?.length === 1
let hasDefaultProvider = false
for (const provider of options?.providers || []) {
if (provider.is_default || isSingleProvider) {
if (provider.is_default) {
hasDefaultProvider = true
}
container.register(CachingDefaultProvider, asValue(provider.id))
}
}
const logger = container.resolve(ContainerRegistrationKeys.LOGGER)
if (!hasDefaultProvider) {
logger.warn(
`[caching-module]: No default caching provider defined. Using "${container.resolve(
CachingDefaultProvider
)}" as default.`
)
}
}
@@ -0,0 +1,228 @@
import NodeCache from "node-cache"
import type { ICachingProviderService } from "@medusajs/framework/types"
export interface MemoryCacheModuleOptions {
/**
* TTL in seconds
*/
ttl?: number
/**
* Maximum number of keys to store (see node-cache documentation)
*/
maxKeys?: number
/**
* Check period for expired keys in seconds (see node-cache documentation)
*/
checkPeriod?: number
/**
* Use clones for cached data (see node-cache documentation)
*/
useClones?: boolean
}
export class MemoryCachingProvider implements ICachingProviderService {
static identifier = "cache-memory"
protected cacheClient: NodeCache
protected tagIndex: Map<string, Set<string>> = new Map() // tag -> keys
protected keyTags: Map<string, Set<string>> = new Map() // key -> tags
protected entryOptions: Map<string, { autoInvalidate?: boolean }> = new Map() // key -> options
protected options: MemoryCacheModuleOptions
constructor() {
this.options = {
ttl: 3600,
maxKeys: 25000,
checkPeriod: 60, // 10 minutes
useClones: false, // Default to false for speed, true would be slower but safer. we can discuss
}
const cacheClient = new NodeCache({
stdTTL: this.options.ttl,
maxKeys: this.options.maxKeys,
checkperiod: this.options.checkPeriod,
useClones: this.options.useClones,
})
this.cacheClient = cacheClient
// Clean up tag indices when keys expire
this.cacheClient.on("expired", (key: string, value: any) => {
this.cleanupTagReferences(key)
})
this.cacheClient.on("del", (key: string, value: any) => {
this.cleanupTagReferences(key)
})
}
private cleanupTagReferences(key: string): void {
const tags = this.keyTags.get(key)
if (tags) {
tags.forEach((tag) => {
const keysForTag = this.tagIndex.get(tag)
if (keysForTag) {
keysForTag.delete(key)
if (keysForTag.size === 0) {
this.tagIndex.delete(tag)
}
}
})
this.keyTags.delete(key)
}
// Also clean up entry options
this.entryOptions.delete(key)
}
async get({ key, tags }: { key?: string; tags?: string[] }): Promise<any> {
if (key) {
return this.cacheClient.get(key) ?? null
}
if (tags && tags.length) {
const allKeys = new Set<string>()
tags.forEach((tag) => {
const keysForTag = this.tagIndex.get(tag)
if (keysForTag) {
keysForTag.forEach((key) => allKeys.add(key))
}
})
if (allKeys.size === 0) {
return []
}
const results: any[] = []
allKeys.forEach((key) => {
const value = this.cacheClient.get(key)
if (value !== undefined) {
results.push(value)
}
})
return results
}
return null
}
async set({
key,
data,
ttl,
tags,
options,
}: {
key: string
data: object
ttl?: number
tags?: string[]
options?: {
autoInvalidate?: boolean
}
}): Promise<void> {
// Set the cache entry
const effectiveTTL = ttl ?? this.options.ttl ?? 3600
this.cacheClient.set(key, data, effectiveTTL)
// Handle tags if provided
if (tags && tags.length) {
// Clean up any existing tag references for this key
this.cleanupTagReferences(key)
const tagSet = new Set(tags)
this.keyTags.set(key, tagSet)
// Add this key to each tag's index
tags.forEach((tag) => {
if (!this.tagIndex.has(tag)) {
this.tagIndex.set(tag, new Set())
}
this.tagIndex.get(tag)!.add(key)
})
}
// Store entry options if provided
if (
Object.keys(options ?? {}).length &&
!Object.values(options ?? {}).every((value) => value === undefined)
) {
this.entryOptions.set(key, options!)
}
}
async clear({
key,
tags,
options,
}: {
key?: string
tags?: string[]
options?: {
autoInvalidate?: boolean
}
}): Promise<void> {
if (key) {
this.cacheClient.del(key)
return
}
if (tags && tags.length) {
// Handle wildcard tag to clear all cache data
if (tags.includes("*")) {
this.cacheClient.flushAll()
this.tagIndex.clear()
this.keyTags.clear()
this.entryOptions.clear()
return
}
const allKeys = new Set<string>()
tags.forEach((tag) => {
const keysForTag = this.tagIndex.get(tag)
if (keysForTag) {
keysForTag.forEach((key) => allKeys.add(key))
}
})
if (allKeys.size) {
// If no options provided (user explicit call), clear everything
if (!options) {
const keysToDelete = Array.from(allKeys)
this.cacheClient.del(keysToDelete)
// Clean up ALL tag references for deleted keys
keysToDelete.forEach((key) => {
this.cleanupTagReferences(key)
})
return
}
// If autoInvalidate is true (strategy call), only clear entries with autoInvalidate=true (default)
if (options.autoInvalidate === true) {
const keysToDelete: string[] = []
allKeys.forEach((key) => {
const entryOptions = this.entryOptions.get(key)
// Delete if entry has autoInvalidate=true or no setting (default true)
const shouldAutoInvalidate = entryOptions?.autoInvalidate ?? true
if (shouldAutoInvalidate) {
keysToDelete.push(key)
}
})
if (keysToDelete.length) {
this.cacheClient.del(keysToDelete)
// Clean up ALL tag references for deleted keys
keysToDelete.forEach((key) => {
this.cleanupTagReferences(key)
})
}
}
}
}
}
}
@@ -0,0 +1,406 @@
import { MedusaModule } from "@medusajs/framework/modules-sdk"
import type {
ICachingModuleService,
ICachingStrategy,
Logger,
} from "@medusajs/framework/types"
import { GraphQLUtils, MedusaError } from "@medusajs/framework/utils"
import { CachingDefaultProvider, InjectedDependencies } from "@types"
import CacheProviderService from "./cache-provider"
const ONE_HOUR_IN_SECOND = 60 * 60
export default class CachingModuleService implements ICachingModuleService {
protected container: InjectedDependencies
protected providerService: CacheProviderService
protected strategyCtr: new (...args: any[]) => ICachingStrategy
protected strategy: ICachingStrategy
protected defaultProviderId: string
protected logger: Logger
protected ongoingRequests: Map<string, Promise<any>> = new Map()
protected ttl: number
static traceGet?: (
cacheGetFn: () => Promise<any>,
key: string,
tags: string[]
) => Promise<any>
static traceSet?: (
cacheSetFn: () => Promise<any>,
key: string,
tags: string[],
options: { autoInvalidate?: boolean }
) => Promise<any>
static traceClear?: (
cacheClearFn: () => Promise<any>,
key: string,
tags: string[],
options: { autoInvalidate?: boolean }
) => Promise<any>
constructor(
container: InjectedDependencies,
protected readonly moduleDeclaration:
| { options: { ttl?: number } }
| { ttl?: number }
) {
this.container = container
this.providerService = container.cacheProviderService
this.defaultProviderId = container[CachingDefaultProvider]
this.strategyCtr = container.strategy as new (
...args: any[]
) => ICachingStrategy
this.strategy = new this.strategyCtr(this.container, this)
const moduleOptions =
"options" in moduleDeclaration
? moduleDeclaration.options
: moduleDeclaration
this.ttl = moduleOptions.ttl ?? ONE_HOUR_IN_SECOND
this.logger = container.logger ?? (console as unknown as Logger)
}
__hooks = {
onApplicationStart: async () => {
this.onApplicationStart()
},
onApplicationShutdown: async () => {
this.onApplicationShutdown()
},
onApplicationPrepareShutdown: async () => {
this.onApplicationPrepareShutdown()
},
}
protected onApplicationStart() {
const loadedSchema = MedusaModule.getAllJoinerConfigs()
.map((joinerConfig) => joinerConfig?.schema ?? "")
.join("\n")
const defaultMedusaSchema = `
scalar DateTime
scalar JSON
directive @enumValue(value: String) on ENUM_VALUE
`
const { schema: cleanedSchema } = GraphQLUtils.cleanGraphQLSchema(
defaultMedusaSchema + loadedSchema
)
const mergedSchema = GraphQLUtils.mergeTypeDefs(cleanedSchema)
const schema = GraphQLUtils.makeExecutableSchema({
typeDefs: mergedSchema,
})
this.strategy.onApplicationStart?.(
schema,
MedusaModule.getAllJoinerConfigs()
)
}
protected onApplicationShutdown() {
this.strategy.onApplicationShutdown?.()
}
protected onApplicationPrepareShutdown() {
this.strategy.onApplicationPrepareShutdown?.()
}
protected static normalizeProviders(
providers: string[] | { id: string; ttl?: number }[]
): { id: string; ttl?: number }[] {
const providers_ = Array.isArray(providers) ? providers : [providers]
return providers_.map((provider) => {
return typeof provider === "string" ? { id: provider } : provider
})
}
protected getRequestKey(
key?: string,
tags?: string[],
providers?: string[]
): string {
const keyPart = key || ""
const tagsPart = tags?.sort().join(",") || ""
const providersPart = providers?.join(",") || this.defaultProviderId
return `${keyPart}|${tagsPart}|${providersPart}`
}
protected getClearRequestKey(
key?: string,
tags?: string[],
providers?: string[]
): string {
const keyPart = key || ""
const tagsPart = tags?.sort().join(",") || ""
const providersPart = providers?.join(",") || this.defaultProviderId
return `clear:${keyPart}|${tagsPart}|${providersPart}`
}
async get(options: { key?: string; tags?: string[]; providers?: string[] }) {
if (CachingModuleService.traceGet) {
return await CachingModuleService.traceGet(
() => this.get_(options),
options.key ?? "",
options.tags ?? []
)
}
return await this.get_(options)
}
private async get_({
key,
tags,
providers,
}: {
key?: string
tags?: string[]
providers?: string[]
}) {
if (!key && !tags) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
"Either key or tags must be provided"
)
}
const requestKey = this.getRequestKey(key, tags, providers)
const existingRequest = this.ongoingRequests.get(requestKey)
if (existingRequest) {
return await existingRequest
}
const requestPromise = this.performCacheGet(key, tags, providers)
this.ongoingRequests.set(requestKey, requestPromise)
try {
const result = await requestPromise
return result
} finally {
// Clean up the completed request
this.ongoingRequests.delete(requestKey)
}
}
protected async performCacheGet(
key?: string,
tags?: string[],
providers?: string[]
): Promise<any> {
const providersToCheck = providers ?? [this.defaultProviderId]
for (const providerId of providersToCheck) {
try {
const provider_ = this.providerService.retrieveProvider(providerId)
const result = await provider_.get({ key, tags })
if (result != null) {
return result
}
} catch (error) {
this.logger.warn(
`Cache provider ${providerId} failed: ${error.message}\n${error.stack}`
)
continue
}
}
return null
}
async set(options: {
key: string
data: object
ttl?: number
tags?: string[]
providers?: string[]
options?: { autoInvalidate?: boolean }
}) {
if (CachingModuleService.traceSet) {
return await CachingModuleService.traceSet(
() => this.set_(options),
options.key,
options.tags ?? [],
options.options ?? {}
)
}
return await this.set_(options)
}
private async set_({
key,
data,
ttl,
tags,
providers,
options,
}: {
key: string
data: object
tags?: string[]
ttl?: number
providers?: string[] | { id: string; ttl?: number }[]
options?: {
autoInvalidate?: boolean
}
}) {
if (!key) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
"[CachingModuleService] Key must be provided"
)
}
const key_ = key
const tags_ = tags ?? (await this.strategy.computeTags(data))
let providers_: string[] | { id: string; ttl?: number }[] = [
this.defaultProviderId,
]
providers_ = CachingModuleService.normalizeProviders(
providers ?? providers_
)
const providerIds = providers_.map((p) => p.id)
const requestKey = this.getRequestKey(key_, tags_, providerIds)
const existingRequest = this.ongoingRequests.get(requestKey)
if (existingRequest) {
return await existingRequest
}
const requestPromise = this.performCacheSet(
key_,
tags_,
data,
ttl,
providers_,
options
)
this.ongoingRequests.set(requestKey, requestPromise)
try {
await requestPromise
} finally {
// Clean up the completed request
this.ongoingRequests.delete(requestKey)
}
}
protected async performCacheSet(
key: string,
tags: string[],
data: object,
ttl?: number,
providers?: { id: string; ttl?: number }[],
options?: {
autoInvalidate?: boolean
}
): Promise<void> {
for (const providerOptions of providers || []) {
const ttl_ = providerOptions.ttl ?? ttl ?? this.ttl
const provider = this.providerService.retrieveProvider(providerOptions.id)
void provider.set({
key,
tags,
data,
ttl: ttl_,
options,
})
}
}
async clear(options: {
key?: string
tags?: string[]
options?: { autoInvalidate?: boolean }
providers?: string[]
}) {
if (CachingModuleService.traceClear) {
return await CachingModuleService.traceClear(
() => this.clear_(options),
options.key ?? "",
options.tags ?? [],
options.options ?? {}
)
}
return await this.clear_(options)
}
private async clear_({
key,
tags,
options,
providers,
}: {
key?: string
tags?: string[]
options?: {
autoInvalidate?: boolean
}
providers?: string[]
}) {
if (!key && !tags) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
"Either key or tags must be provided"
)
}
const requestKey = this.getClearRequestKey(key, tags, providers)
const existingRequest = this.ongoingRequests.get(requestKey)
if (existingRequest) {
return await existingRequest
}
const requestPromise = this.performCacheClear(key, tags, options, providers)
this.ongoingRequests.set(requestKey, requestPromise)
try {
await requestPromise
} finally {
// Clean up the completed request
this.ongoingRequests.delete(requestKey)
}
}
protected async performCacheClear(
key?: string,
tags?: string[],
options?: {
autoInvalidate?: boolean
},
providers?: string[]
): Promise<void> {
let providerIds_: string[] = [this.defaultProviderId]
if (providers) {
providerIds_ = Array.isArray(providers) ? providers : [providers]
}
for (const providerId of providerIds_) {
const provider = this.providerService.retrieveProvider(providerId)
void provider.clear({ key, tags, options })
}
}
async computeKey(input: object): Promise<string> {
return await this.strategy.computeKey(input)
}
async computeTags(
input: object,
options?: Record<string, any>
): Promise<string[]> {
return await this.strategy.computeTags(input, options)
}
}
@@ -0,0 +1,60 @@
import {
Constructor,
ICachingProviderService,
Logger,
} from "@medusajs/framework/types"
import { MedusaError } from "@medusajs/framework/utils"
import { CachingProviderRegistrationPrefix } from "../types"
type InjectedDependencies = {
[key: `cp_${string}`]: ICachingProviderService
logger?: Logger
}
export default class CacheProviderService {
#container: InjectedDependencies
#logger: Logger
constructor(container: InjectedDependencies) {
this.#container = container
this.#logger = container["logger"]
? container.logger
: (console as unknown as Logger)
}
static getRegistrationIdentifier(
providerClass: Constructor<ICachingProviderService>
) {
if (!(providerClass as any).identifier) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
`Trying to register a caching provider without an identifier.`
)
}
return `${(providerClass as any).identifier}`
}
public retrieveProvider(providerId: string): ICachingProviderService {
try {
return this.#container[
`${CachingProviderRegistrationPrefix}${providerId}`
]
} catch (err) {
if (err.name === "AwilixResolutionError") {
const errMessage = `
Unable to retrieve the caching provider with id: ${providerId}
Please make sure that the provider is registered in the container and it is configured correctly in your project configuration file.`
// Log full error for debugging
this.#logger.error(`AwilixResolutionError: ${err.message}`, err)
throw new Error(errMessage)
}
const errMessage = `Unable to retrieve the caching provider with id: ${providerId}, the following error occurred: ${err.message}`
this.#logger.error(errMessage)
throw new Error(errMessage)
}
}
}

Some files were not shown because too many files have changed in this diff Show More