From b9d6f73320c36c53235b12fb8397b30a448917f0 Mon Sep 17 00:00:00 2001 From: Adrien de Peretti Date: Tue, 30 Sep 2025 18:19:06 +0200 Subject: [PATCH] Feat(): distributed caching (#13435) RESOLVES CORE-1153 **What** - This pr mainly lay the foundation the caching layer. It comes with a modules (built in memory cache) and a redis provider. - Apply caching to few touch point to test Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com> --- .changeset/light-lions-invent.md | 16 + .eslintignore | 43 +- .eslintrc.js | 2 + integration-tests/modules/package.json | 1 + .../src/cart/steps/find-one-or-any-region.ts | 63 +- .../src/cart/steps/find-or-create-customer.ts | 57 +- .../src/cart/steps/find-sales-channel.ts | 55 +- .../steps/get-promotion-codes-to-apply.ts | 23 +- .../src/cart/steps/get-variant-price-sets.ts | 25 +- .../src/cart/steps/update-cart-promotions.ts | 19 +- .../src/cart/workflows/add-to-cart.ts | 5 + .../get-variants-and-items-with-prices.ts | 5 + ...-shipping-options-for-cart-with-pricing.ts | 43 +- .../list-shipping-options-for-cart.ts | 5 + .../src/cart/workflows/update-cart.ts | 3 + .../src/order/workflows/create-order.ts | 5 + packages/core/framework/package.json | 1 + .../middlewares/ensure-publishable-api-key.ts | 40 +- .../src/http/utils/refetch-entities.ts | 115 ++- .../core/framework/src/types/container.ts | 2 + .../utils/__tests__/load-internal.spec.ts | 2 + .../src/loaders/utils/load-internal.ts | 3 +- packages/core/modules-sdk/src/medusa-app.ts | 1 + .../modules-sdk/src/remote-query/query.ts | 60 ++ .../src/remote-query/to-remote-query.ts | 5 +- packages/core/types/src/bundles.ts | 1 + packages/core/types/src/caching/index.ts | 161 +++++ packages/core/types/src/event-bus/common.ts | 5 + .../types/src/event-bus/event-bus-module.ts | 34 +- packages/core/types/src/index.ts | 1 + packages/core/types/src/joiner/index.ts | 36 + packages/core/types/src/modules-sdk/index.ts | 1 + 
.../types/src/modules-sdk/remote-query.ts | 7 +- .../types/src/pricing/common/price-rule.ts | 1 - packages/core/utils/src/bundles.ts | 1 + packages/core/utils/src/caching/index.ts | 249 +++++++ .../utils/src/common/create-container-like.ts | 17 +- packages/core/utils/src/dml/properties/id.ts | 2 +- .../utils/src/dml/properties/primary-key.ts | 4 + packages/core/utils/src/event-bus/index.ts | 51 +- packages/core/utils/src/index.ts | 1 + .../__tests__/joiner-config-builder.spec.ts | 6 + .../core/utils/src/modules-sdk/definition.ts | 2 + .../src/modules-sdk/joiner-config-builder.ts | 34 + packages/core/utils/src/modules-sdk/module.ts | 5 + .../src/product/get-variant-availability.ts | 51 +- packages/medusa/package.json | 2 + .../[id]/inbound/items/[action_id]/route.ts | 22 +- .../medusa/src/api/admin/claims/[id]/route.ts | 12 +- .../[id]/inbound/items/[action_id]/route.ts | 9 +- .../src/api/admin/exchanges/[id]/route.ts | 12 +- .../src/api/admin/notifications/[id]/route.ts | 13 +- .../src/api/admin/notifications/route.ts | 14 +- .../mark-as-delivered/route.ts | 12 +- .../[id]/mark-as-paid/route.ts | 12 +- .../api/admin/payment-collections/route.ts | 12 +- .../api/admin/price-preferences/[id]/route.ts | 24 +- .../src/api/admin/price-preferences/route.ts | 27 +- .../product-categories/[id]/products/route.ts | 12 +- .../admin/product-categories/[id]/route.ts | 30 +- .../src/api/admin/product-categories/route.ts | 29 +- .../src/api/admin/product-tags/[id]/route.ts | 36 +- .../src/api/admin/product-tags/route.ts | 26 +- .../src/api/admin/product-variants/route.ts | 14 +- .../[id]/options/[option_id]/route.ts | 36 +- .../api/admin/products/[id]/options/route.ts | 26 +- .../src/api/admin/products/[id]/route.ts | 36 +- .../[id]/variants/[variant_id]/route.ts | 36 +- .../api/admin/products/[id]/variants/route.ts | 26 +- .../medusa/src/api/admin/products/route.ts | 28 +- .../api/admin/refund-reasons/[id]/route.ts | 24 +- .../src/api/admin/refund-reasons/route.ts | 26 +- 
.../api/admin/return-reasons/[id]/route.ts | 12 +- .../medusa/src/api/store/orders/helpers.ts | 2 +- .../api/store/payment-collections/helpers.ts | 7 +- .../store/product-categories/[id]/route.ts | 12 +- .../src/api/store/products/[id]/route.ts | 12 +- .../medusa/src/api/store/products/helpers.ts | 2 +- .../medusa/src/api/store/products/route.ts | 62 +- .../products/normalize-data-for-context.ts | 25 +- .../products/set-pricing-context.ts | 29 +- .../middlewares/products/set-tax-context.ts | 12 +- packages/medusa/src/feature-flags/caching.ts | 10 + packages/medusa/src/instrumentation/index.ts | 120 ++- packages/medusa/src/modules/caching-redis.ts | 6 + packages/medusa/src/modules/caching.ts | 6 + packages/modules/caching/.gitignore | 6 + packages/modules/caching/CHANGELOG.md | 1 + .../__fixtures__/event-bus-mock.ts | 51 ++ .../integration-tests/__tests__/index.spec.ts | 336 +++++++++ .../__tests__/invalidation.spec.ts | 430 +++++++++++ .../__tests__/redis/invalidation.spec.ts | 540 ++++++++++++++ packages/modules/caching/jest.config.js | 8 + packages/modules/caching/package.json | 49 ++ packages/modules/caching/src/index.ts | 12 + packages/modules/caching/src/loaders/hash.ts | 8 + .../modules/caching/src/loaders/providers.ts | 94 +++ .../caching/src/providers/memory-cache.ts | 228 ++++++ .../caching/src/services/cache-module.ts | 406 +++++++++++ .../caching/src/services/cache-provider.ts | 60 ++ .../modules/caching/src/services/index.ts | 2 + packages/modules/caching/src/types/index.ts | 56 ++ .../src/utils/__tests__/parser.test.ts | 487 +++++++++++++ packages/modules/caching/src/utils/parser.ts | 242 +++++++ .../modules/caching/src/utils/strategy.ts | 133 ++++ packages/modules/caching/tsconfig.json | 10 + .../src/services/event-bus-local.ts | 25 +- .../src/services/event-bus-redis.ts | 19 + .../modules/locking/mikro-orm.config.dev.ts | 6 - .../providers/caching-redis/package.json | 49 ++ .../providers/caching-redis/src/index.ts | 12 + 
.../caching-redis/src/loaders/connection.ts | 66 ++ .../caching-redis/src/loaders/hash.ts | 8 + .../caching-redis/src/services/redis-cache.ts | 683 ++++++++++++++++++ .../caching-redis/src/types/index.ts | 26 + .../providers/caching-redis/tsconfig.json | 12 + yarn.lock | 64 ++ 117 files changed, 5741 insertions(+), 530 deletions(-) create mode 100644 .changeset/light-lions-invent.md create mode 100644 packages/core/types/src/caching/index.ts create mode 100644 packages/core/utils/src/caching/index.ts create mode 100644 packages/medusa/src/feature-flags/caching.ts create mode 100644 packages/medusa/src/modules/caching-redis.ts create mode 100644 packages/medusa/src/modules/caching.ts create mode 100644 packages/modules/caching/.gitignore create mode 100644 packages/modules/caching/CHANGELOG.md create mode 100644 packages/modules/caching/integration-tests/__fixtures__/event-bus-mock.ts create mode 100644 packages/modules/caching/integration-tests/__tests__/index.spec.ts create mode 100644 packages/modules/caching/integration-tests/__tests__/invalidation.spec.ts create mode 100644 packages/modules/caching/integration-tests/__tests__/redis/invalidation.spec.ts create mode 100644 packages/modules/caching/jest.config.js create mode 100644 packages/modules/caching/package.json create mode 100644 packages/modules/caching/src/index.ts create mode 100644 packages/modules/caching/src/loaders/hash.ts create mode 100644 packages/modules/caching/src/loaders/providers.ts create mode 100644 packages/modules/caching/src/providers/memory-cache.ts create mode 100644 packages/modules/caching/src/services/cache-module.ts create mode 100644 packages/modules/caching/src/services/cache-provider.ts create mode 100644 packages/modules/caching/src/services/index.ts create mode 100644 packages/modules/caching/src/types/index.ts create mode 100644 packages/modules/caching/src/utils/__tests__/parser.test.ts create mode 100644 packages/modules/caching/src/utils/parser.ts create mode 100644 
packages/modules/caching/src/utils/strategy.ts create mode 100644 packages/modules/caching/tsconfig.json delete mode 100644 packages/modules/locking/mikro-orm.config.dev.ts create mode 100644 packages/modules/providers/caching-redis/package.json create mode 100644 packages/modules/providers/caching-redis/src/index.ts create mode 100644 packages/modules/providers/caching-redis/src/loaders/connection.ts create mode 100644 packages/modules/providers/caching-redis/src/loaders/hash.ts create mode 100644 packages/modules/providers/caching-redis/src/services/redis-cache.ts create mode 100644 packages/modules/providers/caching-redis/src/types/index.ts create mode 100644 packages/modules/providers/caching-redis/tsconfig.json diff --git a/.changeset/light-lions-invent.md b/.changeset/light-lions-invent.md new file mode 100644 index 0000000000..df38fd450e --- /dev/null +++ b/.changeset/light-lions-invent.md @@ -0,0 +1,16 @@ +--- +"@medusajs/medusa": patch +"@medusajs/framework": patch +"@medusajs/modules-sdk": patch +"@medusajs/types": patch +"@medusajs/utils": patch +"@medusajs/caching": patch +"@medusajs/event-bus-local": patch +"@medusajs/event-bus-redis": patch +"@medusajs/caching-redis": patch +"@medusajs/core-flows": patch +"@medusajs/pricing": patch +"@medusajs/draft-order": patch +--- + +Feat(): distributed caching diff --git a/.eslintignore b/.eslintignore index 43e9a8fa72..27fb3ea52b 100644 --- a/.eslintignore +++ b/.eslintignore @@ -13,31 +13,32 @@ packages/* !packages/admin-next/dashboard !packages/medusa-payment-stripe !packages/medusa-payment-paypal -!packages/event-bus-redis -!packages/event-bus-local +!packages/modules/event-bus-redis +!packages/modules/event-bus-local !packages/medusa-plugin-meilisearch !packages/medusa-plugin-algolia -!packages/inventory -!packages/stock-location -!packages/cache-redis -!packages/cache-inmemory +!packages/modules/inventory +!packages/modules/stock-location +!packages/modules/cache-redis +!packages/modules/cache-inmemory 
+!packages/modules/caching +!packages/modules/providers/caching-redis !packages/create-medusa-app -!packages/product -!packages/locking -!packages/orchestration -!packages/workflows-sdk -!packages/core-flows -!packages/types +!packages/modules/product +!packages/modules/locking +!packages/core/orchestration +!packages/core/workflows-sdk +!packages/core/core-flows +!packages/core/types !packages/medusa-react -!packages/workflow-engine-redis -!packages/workflow-engine-inmemory -!packages/fulfillment -!packages/fulfillment-manual -!packages/locking-postgres -!packages/locking-redis -!packages/index - -!packages/framework +!packages/modules/workflow-engine-redis +!packages/modules/workflow-engine-inmemory +!packages/modules/fulfillment +!packages/modules/providers/fulfillment-manual +!packages/modules/providers/locking-postgres +!packages/modules/providers/locking-redis +!packages/modules/index +!packages/core/framework **/models/* diff --git a/.eslintrc.js b/.eslintrc.js index 7220b54576..308ae874dc 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -109,6 +109,7 @@ module.exports = { "./packages/modules/event-bus-redis/tsconfig.spec.json", "./packages/modules/cache-redis/tsconfig.spec.json", "./packages/modules/cache-inmemory/tsconfig.spec.json", + "./packages/modules/caching/tsconfig.spec.json", "./packages/modules/workflow-engine-redis/tsconfig.spec.json", "./packages/modules/workflow-engine-inmemory/tsconfig.spec.json", "./packages/modules/fulfillment/tsconfig.spec.json", @@ -141,6 +142,7 @@ module.exports = { "./packages/modules/providers/payment-stripe/tsconfig.spec.json", "./packages/modules/providers/locking-postgres/tsconfig.spec.json", "./packages/modules/providers/locking-redis/tsconfig.spec.json", + "./packages/modules/providers/caching-redis/tsconfig.spec.json", "./packages/framework/tsconfig.json", ], diff --git a/integration-tests/modules/package.json b/integration-tests/modules/package.json index 2c7d81feee..94b83ed00e 100644 --- 
a/integration-tests/modules/package.json +++ b/integration-tests/modules/package.json @@ -13,6 +13,7 @@ "@medusajs/api-key": "workspace:^", "@medusajs/auth": "workspace:*", "@medusajs/cache-inmemory": "workspace:*", + "@medusajs/caching": "workspace:*", "@medusajs/core-flows": "workspace:^", "@medusajs/currency": "workspace:^", "@medusajs/customer": "workspace:^", diff --git a/packages/core/core-flows/src/cart/steps/find-one-or-any-region.ts b/packages/core/core-flows/src/cart/steps/find-one-or-any-region.ts index 41321c7b00..d6e64602fa 100644 --- a/packages/core/core-flows/src/cart/steps/find-one-or-any-region.ts +++ b/packages/core/core-flows/src/cart/steps/find-one-or-any-region.ts @@ -1,8 +1,9 @@ import { IRegionModuleService, IStoreModuleService, + MedusaContainer, } from "@medusajs/framework/types" -import { MedusaError, Modules } from "@medusajs/framework/utils" +import { MedusaError, Modules, useCache } from "@medusajs/framework/utils" import { StepResponse, createStep } from "@medusajs/framework/workflows-sdk" /** @@ -15,6 +16,48 @@ export type FindOneOrAnyRegionStepInput = { regionId?: string } +async function fetchRegionById(regionId: string, container: MedusaContainer) { + const service = container.resolve(Modules.REGION) + + const args = [ + regionId, + { + relations: ["countries"], + }, + ] as Parameters + + return await useCache(async () => service.retrieveRegion(...args), { + container, + key: args, + }) +} + +async function fetchDefaultStore(container: MedusaContainer) { + const storeModule = container.resolve(Modules.STORE) + + return await useCache(async () => storeModule.listStores(), { + container, + key: "find-one-or-any-region-default-store", + }) +} + +async function fetchDefaultRegion( + defaultRegionId: string, + container: MedusaContainer +) { + const service = container.resolve(Modules.REGION) + + const args = [ + { id: defaultRegionId }, + { relations: ["countries"] }, + ] as Parameters + + return await useCache(async () => 
service.listRegions(...args), { + container, + key: args, + }) +} + export const findOneOrAnyRegionStepId = "find-one-or-any-region" /** * This step retrieves a region either by the provided ID or the first region in the first store. @@ -22,32 +65,24 @@ export const findOneOrAnyRegionStepId = "find-one-or-any-region" export const findOneOrAnyRegionStep = createStep( findOneOrAnyRegionStepId, async (data: FindOneOrAnyRegionStepInput, { container }) => { - const service = container.resolve(Modules.REGION) - - const storeModule = container.resolve(Modules.STORE) - if (data.regionId) { try { - const region = await service.retrieveRegion(data.regionId, { - relations: ["countries"], - }) + const region = await fetchRegionById(data.regionId, container) return new StepResponse(region) } catch (error) { return new StepResponse(null) } } - const [store] = await storeModule.listStores() + const [store] = await fetchDefaultStore(container) if (!store) { throw new MedusaError(MedusaError.Types.NOT_FOUND, "Store not found") } - const [region] = await service.listRegions( - { - id: store.default_region_id, - }, - { relations: ["countries"] } + const [region] = await fetchDefaultRegion( + store.default_region_id!, + container ) if (!region) { diff --git a/packages/core/core-flows/src/cart/steps/find-or-create-customer.ts b/packages/core/core-flows/src/cart/steps/find-or-create-customer.ts index 81ed2fccdc..22defa4ba2 100644 --- a/packages/core/core-flows/src/cart/steps/find-or-create-customer.ts +++ b/packages/core/core-flows/src/cart/steps/find-or-create-customer.ts @@ -1,8 +1,14 @@ import type { CustomerDTO, ICustomerModuleService, + MedusaContainer, } from "@medusajs/framework/types" -import { isDefined, Modules, validateEmail } from "@medusajs/framework/utils" +import { + isDefined, + Modules, + useCache, + validateEmail, +} from "@medusajs/framework/utils" import { createStep, StepResponse } from "@medusajs/framework/workflows-sdk" /** @@ -39,6 +45,40 @@ interface 
StepCompensateInput { customerWasCreated: boolean } +async function fetchCustomerById( + customerId: string, + container: MedusaContainer +): Promise { + const service = container.resolve(Modules.CUSTOMER) + + return await useCache( + async () => service.retrieveCustomer(customerId), + { + container, + key: ["find-or-create-customer-by-id", customerId], + } + ) +} + +async function fetchCustomersByEmail( + email: string, + container: MedusaContainer, + hasAccount?: boolean +): Promise { + const service = container.resolve(Modules.CUSTOMER) + + const filters = + hasAccount !== undefined ? { email, has_account: hasAccount } : { email } + + return await useCache( + async () => service.listCustomers(filters), + { + container, + key: ["find-or-create-customer-by-email", filters], + } + ) +} + export const findOrCreateCustomerStepId = "find-or-create-customer" /** * This step finds or creates a customer based on the provided ID or email. It prioritizes finding the customer by ID, then by email. @@ -75,7 +115,7 @@ export const findOrCreateCustomerStep = createStep( let customerWasCreated = false if (data.customerId) { - originalCustomer = await service.retrieveCustomer(data.customerId) + originalCustomer = await fetchCustomerById(data.customerId, container) customerData.customer = originalCustomer customerData.email = originalCustomer.email } @@ -85,9 +125,7 @@ export const findOrCreateCustomerStep = createStep( let [customer] = originalCustomer ? 
[originalCustomer] - : await service.listCustomers({ - email: validatedEmail, - }) + : await fetchCustomersByEmail(validatedEmail, container) // if NOT a guest customer, return it if (customer?.has_account) { @@ -100,10 +138,11 @@ export const findOrCreateCustomerStep = createStep( } if (customer && customer.email !== validatedEmail) { - ;[customer] = await service.listCustomers({ - email: validatedEmail, - has_account: false, - }) + ;[customer] = await fetchCustomersByEmail( + validatedEmail, + container, + false + ) } if (!customer) { diff --git a/packages/core/core-flows/src/cart/steps/find-sales-channel.ts b/packages/core/core-flows/src/cart/steps/find-sales-channel.ts index 6f80d2da1f..b400281fa2 100644 --- a/packages/core/core-flows/src/cart/steps/find-sales-channel.ts +++ b/packages/core/core-flows/src/cart/steps/find-sales-channel.ts @@ -1,9 +1,15 @@ import { ISalesChannelModuleService, IStoreModuleService, + MedusaContainer, SalesChannelDTO, } from "@medusajs/framework/types" -import { MedusaError, Modules, isDefined } from "@medusajs/framework/utils" +import { + MedusaError, + Modules, + isDefined, + useCache, +} from "@medusajs/framework/utils" import { StepResponse, createStep } from "@medusajs/framework/workflows-sdk" /** @@ -16,6 +22,34 @@ export interface FindSalesChannelStepInput { salesChannelId?: string | null } +async function fetchSalesChannel( + salesChannelId: string, + container: MedusaContainer +) { + const salesChannelService = container.resolve( + Modules.SALES_CHANNEL + ) + + return await useCache< + Awaited> + >(async () => salesChannelService.retrieveSalesChannel(salesChannelId), { + container, + key: ["find-sales-channel", salesChannelId], + }) +} + +async function fetchStore(container: MedusaContainer) { + const storeModule = container.resolve(Modules.STORE) + return await useCache>>( + async () => + storeModule.listStores( + {}, + { select: ["id", "default_sales_channel_id"] } + ), + { key: "find-sales-channel-default-store", 
container } + ) +} + export const findSalesChannelStepId = "find-sales-channel" /** * This step either retrieves a sales channel either using the ID provided as an input, or, if no ID @@ -24,26 +58,17 @@ export const findSalesChannelStepId = "find-sales-channel" export const findSalesChannelStep = createStep( findSalesChannelStepId, async (data: FindSalesChannelStepInput, { container }) => { - const salesChannelService = container.resolve( - Modules.SALES_CHANNEL - ) - const storeModule = container.resolve(Modules.STORE) - let salesChannel: SalesChannelDTO | undefined if (data.salesChannelId) { - salesChannel = await salesChannelService.retrieveSalesChannel( - data.salesChannelId - ) + salesChannel = await fetchSalesChannel(data.salesChannelId, container) } else if (!isDefined(data.salesChannelId)) { - const [store] = await storeModule.listStores( - {}, - { select: ["default_sales_channel_id"] } - ) + const [store] = await fetchStore(container) if (store?.default_sales_channel_id) { - salesChannel = await salesChannelService.retrieveSalesChannel( - store.default_sales_channel_id + salesChannel = await fetchSalesChannel( + store.default_sales_channel_id, + container ) } } diff --git a/packages/core/core-flows/src/cart/steps/get-promotion-codes-to-apply.ts b/packages/core/core-flows/src/cart/steps/get-promotion-codes-to-apply.ts index 882de39e65..3086f6e788 100644 --- a/packages/core/core-flows/src/cart/steps/get-promotion-codes-to-apply.ts +++ b/packages/core/core-flows/src/cart/steps/get-promotion-codes-to-apply.ts @@ -1,7 +1,6 @@ -import type { IPromotionModuleService } from "@medusajs/framework/types" import { + ContainerRegistrationKeys, MedusaError, - Modules, PromotionActions, } from "@medusajs/framework/utils" import { createStep, StepResponse } from "@medusajs/framework/workflows-sdk" @@ -72,9 +71,6 @@ export const getPromotionCodesToApply = createStep( async (data: GetPromotionCodesToApplyStepInput, { container }) => { const { promo_codes = [], cart, action 
= PromotionActions.ADD } = data const { items = [], shipping_methods = [] } = cart - const promotionService = container.resolve( - Modules.PROMOTION - ) const adjustmentCodes: string[] = [] items.concat(shipping_methods).forEach((object) => { @@ -99,14 +95,23 @@ export const getPromotionCodesToApply = createStep( action === PromotionActions.ADD || action === PromotionActions.REPLACE ) { + const query = container.resolve(ContainerRegistrationKeys.QUERY) const validPromoCodes: Set = new Set( promo_codes.length ? ( - await promotionService.listPromotions( - { code: promo_codes }, - { select: ["code"] } + await query.graph( + { + entity: "promotion", + fields: ["id", "code"], + filters: { code: promo_codes }, + }, + { + cache: { + enable: true, + }, + } ) - ).map((p) => p.code!) + ).data.map((p) => p.code!) : [] ) diff --git a/packages/core/core-flows/src/cart/steps/get-variant-price-sets.ts b/packages/core/core-flows/src/cart/steps/get-variant-price-sets.ts index edcc68f780..d3641f1654 100644 --- a/packages/core/core-flows/src/cart/steps/get-variant-price-sets.ts +++ b/packages/core/core-flows/src/cart/steps/get-variant-price-sets.ts @@ -1,4 +1,4 @@ -import { Query } from "@medusajs/framework" +import { MedusaContainer, Query } from "@medusajs/framework" import { CalculatedPriceSet, IPricingModuleService, @@ -75,11 +75,18 @@ async function fetchVariantPriceSets( variantIds: string[] ): Promise { return ( - await query.graph({ - entity: "variant", - fields: ["id", "price_set.id"], - filters: { id: variantIds }, - }) + await query.graph( + { + entity: "variant", + fields: ["id", "price_set.id"], + filters: { id: variantIds }, + }, + { + cache: { + enable: true, + }, + } + ) ).data } @@ -108,7 +115,8 @@ function validateVariantPriceSets( */ async function processVariantPriceSets( pricingService: IPricingModuleService, - items: PriceCalculationItem[] + items: PriceCalculationItem[], + container: MedusaContainer ): Promise { const result: GetVariantPriceSetsStepOutput = {} 
@@ -298,7 +306,8 @@ export const getVariantPriceSetsStep = createStep( // Use unified processing logic for both input types const result = await processVariantPriceSets( pricingModuleService, - calculationItems + calculationItems, + container ) return new StepResponse(result) diff --git a/packages/core/core-flows/src/cart/steps/update-cart-promotions.ts b/packages/core/core-flows/src/cart/steps/update-cart-promotions.ts index 461f830f09..114a8e4956 100644 --- a/packages/core/core-flows/src/cart/steps/update-cart-promotions.ts +++ b/packages/core/core-flows/src/cart/steps/update-cart-promotions.ts @@ -1,4 +1,3 @@ -import type { IPromotionModuleService } from "@medusajs/framework/types" import { ContainerRegistrationKeys, Modules, @@ -40,9 +39,6 @@ export const updateCartPromotionsStep = createStep( const remoteQuery = container.resolve( ContainerRegistrationKeys.REMOTE_QUERY ) - const promotionService = container.resolve( - Modules.PROMOTION - ) const existingCartPromotionLinks = await remoteQuery({ entryPoint: "cart_promotion", @@ -60,9 +56,18 @@ export const updateCartPromotionsStep = createStep( const linksToDismiss: any[] = [] if (promo_codes?.length) { - const promotions = await promotionService.listPromotions( - { code: promo_codes }, - { select: ["id"] } + const query = container.resolve(ContainerRegistrationKeys.QUERY) + const { data: promotions } = await query.graph( + { + entity: "promotion", + fields: ["id", "code"], + filters: { code: promo_codes }, + }, + { + cache: { + enable: true, + }, + } ) for (const promotion of promotions) { diff --git a/packages/core/core-flows/src/cart/workflows/add-to-cart.ts b/packages/core/core-flows/src/cart/workflows/add-to-cart.ts index 8c31844951..ecf61d74cf 100644 --- a/packages/core/core-flows/src/cart/workflows/add-to-cart.ts +++ b/packages/core/core-flows/src/cart/workflows/add-to-cart.ts @@ -205,6 +205,11 @@ export const addToCartWorkflow = createWorkflow( filters: { id: variantIds, }, + options: { + cache: { + 
enable: true, + }, + }, }).config({ name: "fetch-variants" }) }) diff --git a/packages/core/core-flows/src/cart/workflows/get-variants-and-items-with-prices.ts b/packages/core/core-flows/src/cart/workflows/get-variants-and-items-with-prices.ts index 4ab79168b4..df6ecfdef0 100644 --- a/packages/core/core-flows/src/cart/workflows/get-variants-and-items-with-prices.ts +++ b/packages/core/core-flows/src/cart/workflows/get-variants-and-items-with-prices.ts @@ -142,6 +142,11 @@ export const getVariantsAndItemsWithPrices = createWorkflow( filters: { id: variantIds, }, + options: { + cache: { + enable: true, + }, + }, }).config({ name: "fetch-variants" }) const calculatedPriceSets = getVariantPriceSetsStep({ diff --git a/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart-with-pricing.ts b/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart-with-pricing.ts index 86ec154daf..d1a56af580 100644 --- a/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart-with-pricing.ts +++ b/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart-with-pricing.ts @@ -75,26 +75,26 @@ export const listShippingOptionsForCartWithPricingWorkflowId = * @summary * * List a cart's shipping options with prices. - * + * * @property hooks.setShippingOptionsContext - This hook is executed after the cart is retrieved and before the shipping options are queried. You can consume this hook to return any custom context useful for the shipping options retrieval. 
* * For example, you can consume the hook to add the customer Id to the context: - * + * * ```ts * import { listShippingOptionsForCartWithPricingWorkflow } from "@medusajs/medusa/core-flows" * import { StepResponse } from "@medusajs/workflows-sdk" - * + * * listShippingOptionsForCartWithPricingWorkflow.hooks.setShippingOptionsContext( * async ({ cart }, { container }) => { - * + * * if (cart.customer_id) { * return new StepResponse({ * customer_id: cart.customer_id, * }) * } - * + * * const query = container.resolve("query") - * + * * const { data: carts } = await query.graph({ * entity: "cart", * filters: { @@ -102,16 +102,16 @@ export const listShippingOptionsForCartWithPricingWorkflowId = * }, * fields: ["customer_id"], * }) - * + * * return new StepResponse({ * customer_id: carts[0].customer_id, * }) * } * ) * ``` - * + * * The `customer_id` property will be added to the context along with other properties such as `is_return` and `enabled_in_store`. - * + * * :::note * * You should also consume the `setShippingOptionsContext` hook in the {@link listShippingOptionsForCartWorkflow} workflow to ensure that the context is consistent when listing shipping options across workflows. @@ -120,7 +120,11 @@ export const listShippingOptionsForCartWithPricingWorkflowId = */ export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow( listShippingOptionsForCartWithPricingWorkflowId, - (input: WorkflowData) => { + ( + input: WorkflowData< + ListShippingOptionsForCartWithPricingWorkflowInput & AdditionalData + > + ) => { const optionIds = transform({ input }, ({ input }) => (input.options ?? 
[]).map(({ id }) => id) ) @@ -155,6 +159,11 @@ export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow( "stock_locations.address.*", "stock_locations.fulfillment_sets.id", ], + options: { + cache: { + enable: true, + }, + }, }).config({ name: "sales_channels-fulfillment-query" }) const scFulfillmentSets = transform( @@ -193,13 +202,21 @@ export const listShippingOptionsForCartWithPricingWorkflow = createWorkflow( resultValidator: shippingOptionsContextResult, } ) - const setShippingOptionsContextResult = setShippingOptionsContext.getResult() + const setShippingOptionsContextResult = + setShippingOptionsContext.getResult() const commonOptions = transform( { input, cart, fulfillmentSetIds, setShippingOptionsContextResult }, - ({ input, cart, fulfillmentSetIds, setShippingOptionsContextResult }) => ({ + ({ + input, + cart, + fulfillmentSetIds, + setShippingOptionsContextResult, + }) => ({ context: { - ...(setShippingOptionsContextResult ? setShippingOptionsContextResult : {}), + ...(setShippingOptionsContextResult + ? setShippingOptionsContextResult + : {}), is_return: input.is_return ? "true" : "false", enabled_in_store: !isDefined(input.enabled_in_store) ? 
"true" diff --git a/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart.ts b/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart.ts index 8cb1570f57..d752128f07 100644 --- a/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart.ts +++ b/packages/core/core-flows/src/cart/workflows/list-shipping-options-for-cart.ts @@ -168,6 +168,11 @@ export const listShippingOptionsForCartWorkflow = createWorkflow( "stock_locations.name", "stock_locations.address.*", ], + options: { + cache: { + enable: true, + }, + }, }).config({ name: "sales_channels-fulfillment-query" }) const scFulfillmentSets = transform( diff --git a/packages/core/core-flows/src/cart/workflows/update-cart.ts b/packages/core/core-flows/src/cart/workflows/update-cart.ts index 3d8e044dd0..647b06cc53 100644 --- a/packages/core/core-flows/src/cart/workflows/update-cart.ts +++ b/packages/core/core-flows/src/cart/workflows/update-cart.ts @@ -148,6 +148,9 @@ export const updateCartWorkflow = createWorkflow( options: { throwIfKeyNotFound: true, isList: false, + cache: { + enable: true, + }, }, }).config({ name: "get-region" }) diff --git a/packages/core/core-flows/src/order/workflows/create-order.ts b/packages/core/core-flows/src/order/workflows/create-order.ts index eeeb8a3432..f5296bb11c 100644 --- a/packages/core/core-flows/src/order/workflows/create-order.ts +++ b/packages/core/core-flows/src/order/workflows/create-order.ts @@ -247,6 +247,11 @@ export const createOrderWorkflow = createWorkflow( filters: { id: variantIdsWithoutCalculatedPrice, }, + options: { + cache: { + enable: true, + }, + }, }).config({ name: "query-variants-without-calculated-price" }) /** diff --git a/packages/core/framework/package.json b/packages/core/framework/package.json index 1c1782eadd..37a7ca3656 100644 --- a/packages/core/framework/package.json +++ b/packages/core/framework/package.json @@ -16,6 +16,7 @@ "exports": { ".": "./dist/index.js", "./config": 
"./dist/config/index.js", + "./caching": "./dist/caching/index.js", "./logger": "./dist/logger/index.js", "./database": "./dist/database/index.js", "./subscribers": "./dist/subscribers/index.js", diff --git a/packages/core/framework/src/http/middlewares/ensure-publishable-api-key.ts b/packages/core/framework/src/http/middlewares/ensure-publishable-api-key.ts index 3ce2918f9a..7d8af09d1a 100644 --- a/packages/core/framework/src/http/middlewares/ensure-publishable-api-key.ts +++ b/packages/core/framework/src/http/middlewares/ensure-publishable-api-key.ts @@ -30,20 +30,38 @@ export async function ensurePublishableApiKeyMiddleware( const query = req.scope.resolve(ContainerRegistrationKeys.QUERY) try { - const { data } = await query.graph({ - entity: "api_key", - fields: ["id", "token", "sales_channels_link.sales_channel_id"], - filters: { - token: publishableApiKey, - type: ApiKeyType.PUBLISHABLE, - $or: [ - { revoked_at: { $eq: null } }, - { revoked_at: { $gt: new Date() } }, + // Cache API key data and check revocation in memory + const { data } = await query.graph( + { + entity: "api_key", + fields: [ + "id", + "token", + "revoked_at", + "sales_channels_link.sales_channel_id", ], + filters: { + token: publishableApiKey, + type: ApiKeyType.PUBLISHABLE, + }, }, - }) + { + cache: { + enable: true, + }, + } + ) - apiKey = data[0] + if (data.length) { + const now = new Date() + const cachedApiKey = data[0] + const isRevoked = + !!cachedApiKey.revoked_at && new Date(cachedApiKey.revoked_at) <= now + + if (!isRevoked) { + apiKey = cachedApiKey + } + } } catch (e) { return next(e) } diff --git a/packages/core/framework/src/http/utils/refetch-entities.ts b/packages/core/framework/src/http/utils/refetch-entities.ts index 2e92cf9aa4..d18c7415e8 100644 --- a/packages/core/framework/src/http/utils/refetch-entities.ts +++ b/packages/core/framework/src/http/utils/refetch-entities.ts @@ -1,49 +1,88 @@ -import { MedusaContainer } from "@medusajs/types" -import { - 
ContainerRegistrationKeys, - isString, - remoteQueryObjectFromString, -} from "@medusajs/utils" -import { MedusaRequest } from "../types" +import type { + GraphResultSet, + MedusaContainer, + RemoteJoinerOptions, + RemoteQueryEntryPoints, + RemoteQueryFunctionReturnPagination, +} from "../../types" +import { ContainerRegistrationKeys, isString } from "../../utils" +import type { MedusaRequest } from "../types" -export const refetchEntities = async ( - entryPoint: string, - idOrFilter: string | object, - scope: MedusaContainer, - fields: string[], - pagination?: MedusaRequest["queryConfig"]["pagination"], +export const refetchEntities = async ({ + entity, + idOrFilter, + scope, + fields, + pagination, + withDeleted, + options, +}: { + entity: TEntry + idOrFilter?: string | object + scope: MedusaContainer + fields?: string[] + pagination?: MedusaRequest["queryConfig"]["pagination"] withDeleted?: boolean -) => { - const remoteQuery = scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY) - const filters = isString(idOrFilter) ? { id: idOrFilter } : idOrFilter - let context: object = {} + options?: RemoteJoinerOptions +}): Promise< + Omit, "metadata"> & { + metadata: RemoteQueryFunctionReturnPagination + } +> => { + const query = scope.resolve(ContainerRegistrationKeys.QUERY) + let filters = isString(idOrFilter) ? { id: idOrFilter } : idOrFilter + let context!: Record - if ("context" in filters) { - if (filters.context) { - context = filters.context! + if (filters && "context" in filters) { + const { context: context_, ...rest } = filters + if (context_) { + context = context_! as Record } - - delete filters.context + filters = rest } - const variables = { filters, ...context, ...pagination, withDeleted } + const graphOptions: Parameters[0] = { + entity, + fields: fields ?? 
[], + filters, + pagination, + withDeleted, + context: context, + } - const queryObject = remoteQueryObjectFromString({ - entryPoint, - variables, + const result = await query.graph(graphOptions, options) + return { + data: result.data as TEntry extends keyof RemoteQueryEntryPoints + ? RemoteQueryEntryPoints[TEntry][] + : any[], + metadata: result.metadata ?? ({} as RemoteQueryFunctionReturnPagination), + } +} + +export const refetchEntity = async ({ + entity, + idOrFilter, + scope, + fields, + options, +}: { + entity: TEntry & string + idOrFilter: string | object + scope: MedusaContainer + fields: string[] + options?: RemoteJoinerOptions +}): Promise< + TEntry extends keyof RemoteQueryEntryPoints + ? RemoteQueryEntryPoints[TEntry] + : any +> => { + const { data } = await refetchEntities({ + entity, + idOrFilter, + scope, fields, + options, }) - return await remoteQuery(queryObject) -} - -export const refetchEntity = async ( - entryPoint: string, - idOrFilter: string | object, - scope: MedusaContainer, - fields: string[] -) => { - const [entity] = await refetchEntities(entryPoint, idOrFilter, scope, fields) - - return entity + return Array.isArray(data) ? 
data[0] : data } diff --git a/packages/core/framework/src/types/container.ts b/packages/core/framework/src/types/container.ts index fbe1ea8df9..3ff653324b 100644 --- a/packages/core/framework/src/types/container.ts +++ b/packages/core/framework/src/types/container.ts @@ -5,6 +5,7 @@ import { IApiKeyModuleService, IAuthModuleService, ICacheService, + ICachingModuleService, ICartModuleService, ICurrencyModuleService, ICustomerModuleService, @@ -76,6 +77,7 @@ declare module "@medusajs/types" { [Modules.NOTIFICATION]: INotificationModuleService [Modules.LOCKING]: ILockingModule [Modules.SETTINGS]: ISettingsModuleService + [Modules.CACHING]: ICachingModuleService } } diff --git a/packages/core/modules-sdk/src/loaders/utils/__tests__/load-internal.spec.ts b/packages/core/modules-sdk/src/loaders/utils/__tests__/load-internal.spec.ts index e6474149f9..a3d81a657c 100644 --- a/packages/core/modules-sdk/src/loaders/utils/__tests__/load-internal.spec.ts +++ b/packages/core/modules-sdk/src/loaders/utils/__tests__/load-internal.spec.ts @@ -244,6 +244,7 @@ describe("load internal", () => { expect(generatedJoinerConfig).toEqual({ serviceName: "module-without-joiner-config", primaryKeys: ["id"], + idPrefixToEntityName: {}, linkableKeys: { entity2_id: "Entity2", entity_model_id: "EntityModel", @@ -322,6 +323,7 @@ describe("load internal", () => { expect(generatedJoinerConfig).toEqual({ serviceName: "module-service", primaryKeys: ["id"], + idPrefixToEntityName: {}, linkableKeys: {}, schema: "", alias: [ diff --git a/packages/core/modules-sdk/src/loaders/utils/load-internal.ts b/packages/core/modules-sdk/src/loaders/utils/load-internal.ts index 5f2f0f8d77..f37d2c7f8a 100644 --- a/packages/core/modules-sdk/src/loaders/utils/load-internal.ts +++ b/packages/core/modules-sdk/src/loaders/utils/load-internal.ts @@ -231,7 +231,8 @@ export async function loadInternalModule(args: { ContainerRegistrationKeys.CONFIG_MODULE, ContainerRegistrationKeys.LOGGER, 
// Builds an accessor for a single cache option carried in the second argument
// of `query.graph(...)` / `query.index(...)` (i.e. `options.cache.<option>`).
// Returns undefined when the caller passed no options / no cache block.
function extractCacheOptions(option: string) {
  return function extractKey(args: any[]) {
    return args[1]?.cache?.[option]
  }
}

// Decides whether caching applies to a given call. An explicit
// `cache.enable === false` always wins; otherwise caching is considered
// opted-in either explicitly (`enable === true`) or implicitly by the
// presence of ANY other cache option (key, ttl, tags, autoInvalidate,
// providers). Note the return value is truthy/falsy rather than strictly
// boolean — callers (the `Cached` decorator) only test it with `!`.
function isCacheEnabled(args: any[]) {
  const isEnabled = extractCacheOptions("enable")(args)
  if (isEnabled === false) {
    return false
  }

  return (
    isEnabled === true ||
    extractCacheOptions("key")(args) ||
    extractCacheOptions("ttl")(args) ||
    extractCacheOptions("tags")(args) ||
    extractCacheOptions("autoInvalidate")(args) ||
    extractCacheOptions("providers")(args)
  )
}

// Option resolvers handed to the `@Cached(...)` decorator on Query#graph and
// Query#index. Each resolver receives the intercepted call's arguments
// (`args[0]` = query options, `args[1]` = RemoteJoinerOptions).
const cacheDecoratorOptions = {
  enable: isCacheEnabled,
  key: async (args, cachingModule) => {
    // Caller-provided key takes precedence.
    const key = extractCacheOptions("key")(args)
    if (key) {
      return key
    }

    const queryOptions = args[0]
    const remoteJoinerOptions = args[1] ?? {}
    // `initialData` and `cache` are deliberately destructured out so they do
    // not participate in the computed key: neither changes what the query
    // fetches, so they must not fragment the cache.
    const { initialData, cache, ...restOptions } = remoteJoinerOptions

    const keyInput = {
      queryOptions,
      options: restOptions,
    }
    return await cachingModule.computeKey(keyInput)
  },
  ttl: extractCacheOptions("ttl"),
  tags: extractCacheOptions("tags"),
  autoInvalidate: extractCacheOptions("autoInvalidate"),
  providers: extractCacheOptions("providers"),
  // Bound at call time by the decorator (`this` is the Query instance), so
  // the container does not need to exist when this module is loaded.
  container: function (this: Query) {
    return this.container
  },
}
import { IModuleService, ModuleJoinerConfig } from "../modules-sdk"

// Providers may be referenced by id alone, or with a per-provider TTL override.
type Providers = string[] | { id: string; ttl?: number }[]

// NOTE(review): several return-type arguments in this file appear to have
// been lost in transit (bare `Promise` / `Record`) — restore them from the
// authored source before relying on these signatures.
export interface ICachingModuleService extends IModuleService {
  // Static trace methods
  // traceGet: (
  //   cacheGetFn: () => Promise,
  //   key: string,
  //   tags: string[]
  // ) => Promise

  // traceSet?: (
  //   cacheSetFn: () => Promise,
  //   key: string,
  //   tags: string[],
  //   options: { autoInvalidate?: boolean }
  // ) => Promise

  // traceClear?: (
  //   cacheClearFn: () => Promise,
  //   key: string,
  //   tags: string[],
  //   options: { autoInvalidate?: boolean }
  // ) => Promise

  /**
   * This method retrieves data from the cache.
   *
   * @param key - The key of the item to retrieve.
   * @param tags - The tags of the items to retrieve.
   * @param providers - Array of providers to check in order of priority. If not provided,
   * only the default provider will be used.
   *
   * @returns The item(s) that was stored in the cache. If the item(s) was not found, null will
   * be returned.
   *
   */
  get({
    key,
    tags,
    providers,
  }: {
    key?: string
    tags?: string[]
    providers?: string[]
  }): Promise

  /**
   * This method stores data in the cache.
   *
   * @param key - The key of the item to store.
   * @param data - The data to store in the cache.
   * @param ttl - The time-to-live (TTL) value in seconds. If not provided, the default TTL value
   * is used. The default value is based on the used Cache Module.
   * @param tags - The tags of the items to store. Can be used for cross invalidation.
   * @param options - If specified, will be stored with the item(s).
   * @param providers - The providers to which to store the item(s).
   *
   */
  set({
    key,
    data,
    ttl,
    tags,
    options,
    providers,
  }: {
    key: string
    data: object
    ttl?: number
    tags?: string[]
    options?: {
      autoInvalidate?: boolean
    }
    providers?: Providers
  }): Promise

  /**
   * This method clears data from the cache.
   *
   * @param key - The key of the item to clear.
   * @param tags - The tags of the items to clear.
   * @param options - If specified, invalidate the item(s) that have the value of the given
   * options stored. E.g. you can invalidate the tags X if their options.autoInvalidate is false or not present.
   * @param providers - The providers from which to clear the item(s).
   *
   */
  clear({
    key,
    tags,
    options,
  }: {
    key?: string
    tags?: string[]
    options?: {
      autoInvalidate?: boolean
    }
    providers?: string[]
  }): Promise

  // Deterministically derives a cache key from an arbitrary input object.
  computeKey(input: object): Promise

  // Derives invalidation tags from an input object.
  computeTags(input: object, options?: Record): Promise
}

// Contract implemented by individual cache providers (e.g. in-memory, Redis).
// Mirrors the module-level API minus provider routing.
export interface ICachingProviderService {
  get({ key, tags }: { key?: string; tags?: string[] }): Promise
  set({
    key,
    data,
    ttl,
    tags,
    options,
  }: {
    key: string
    data: object
    ttl?: number
    tags?: string[]
    options?: { autoInvalidate?: boolean }
  }): Promise
  clear({
    key,
    tags,
    options,
  }: {
    key?: string
    tags?: string[]
    options?: { autoInvalidate?: boolean }
  }): Promise
}

// A (type, id[, field]) triple identifying an entity touched by a cached
// result — used by strategies for targeted invalidation.
export interface EntityReference {
  type: string
  id: string | number
  field?: string
}

export interface ICachingStrategy {
  /**
   * This method is called when the application starts. It can be useful to set up some auto
   * invalidation logic that reacts to something.
   *
   * @param schema - The application's GraphQL schema.
   * @param joinerConfigs - The joiner configs of all loaded modules.
   */
  onApplicationStart?(
    schema: any,
    joinerConfigs: ModuleJoinerConfig[]
  ): Promise

  // Called before the application shuts down, while services are still available.
  onApplicationPrepareShutdown?(): Promise

  // Called when the application shuts down.
  onApplicationShutdown?(): Promise

  computeKey(input: object): Promise

  computeTags(input: object, options?: Record): Promise
}
--- a/packages/core/types/src/event-bus/event-bus-module.ts +++ b/packages/core/types/src/event-bus/event-bus-module.ts @@ -1,4 +1,9 @@ -import { Message, Subscriber, SubscriberContext } from "./common" +import { + InterceptorSubscriber, + Message, + Subscriber, + SubscriberContext, +} from "./common" export interface IEventBusModuleService { /** @@ -86,4 +91,31 @@ export interface IEventBusModuleService { eventNames?: string[] } ): Promise + + /** + * This method adds an interceptor to the event bus. This means that the interceptor will be + * called before the event is emitted. + * + * @param interceptor - The interceptor to add. + * @returns The instance of the Event Module + * + * @example + * eventModuleService.addInterceptor((message, context) => { + * console.log("Interceptor", message, context) + * }) + */ + addInterceptor?(interceptor: InterceptorSubscriber): this + + /** + * This method removes an interceptor from the event bus. + * + * @param interceptor - The interceptor to remove. 
+ * @returns The instance of the Event Module + * + * @example + * eventModuleService.removeInterceptor((message, context) => { + * console.log("Interceptor", message, context) + * }) + */ + removeInterceptor?(interceptor: InterceptorSubscriber): this } diff --git a/packages/core/types/src/index.ts b/packages/core/types/src/index.ts index ec4b483c91..0bb67735f5 100644 --- a/packages/core/types/src/index.ts +++ b/packages/core/types/src/index.ts @@ -5,6 +5,7 @@ export * from "./api-key" export * from "./auth" export * from "./bundles" export * from "./cache" +export * from "./caching" export * from "./cart" export * from "./common" export * from "./currency" diff --git a/packages/core/types/src/joiner/index.ts b/packages/core/types/src/joiner/index.ts index 641e3a979e..c4f04b5183 100644 --- a/packages/core/types/src/joiner/index.ts +++ b/packages/core/types/src/joiner/index.ts @@ -1,3 +1,5 @@ +import { ICachingModuleService } from "../caching" + export type JoinerRelationship = { alias: string foreignKey: string @@ -92,6 +94,40 @@ export interface RemoteJoinerOptions { throwIfRelationNotFound?: boolean | string[] initialData?: object | object[] initialDataOnly?: boolean + cache?: { + /** + * Whether to enable the cache. This is only useful if you want to enable without providing any + * other options or if you want to enable/disable the cache based on the arguments. + */ + enable?: boolean | ((args: any[]) => boolean | undefined) + /** + * The key to use for the cache. + * If a function is provided, it will be called with the arguments as the first argument and the + * container as the second argument. + */ + key?: + | string + | (( + args: any[], + cachingModule: ICachingModuleService + ) => string | Promise) + /** + * The tags to use for the cache. + */ + tags?: string[] | ((args: any[]) => string[] | undefined) + /** + * The time-to-live (TTL) value in seconds. 
+ */ + ttl?: number | ((args: any[]) => number | undefined) + /** + * Whether to auto invalidate the cache whenever it is possible. + */ + autoInvalidate?: boolean | ((args: any[]) => boolean | undefined) + /** + * The providers to use for the cache. + */ + providers?: string[] | ((args: any[]) => string[] | undefined) + } } export interface RemoteNestedExpands { diff --git a/packages/core/types/src/modules-sdk/index.ts b/packages/core/types/src/modules-sdk/index.ts index 5a8c442695..3166a552d3 100644 --- a/packages/core/types/src/modules-sdk/index.ts +++ b/packages/core/types/src/modules-sdk/index.ts @@ -196,6 +196,7 @@ export type ModuleJoinerConfig = Omit< * GraphQL schema for the all module's available entities and fields */ schema?: string + idPrefixToEntityName?: Record relationships?: ModuleJoinerRelationship[] extends?: { serviceName: string diff --git a/packages/core/types/src/modules-sdk/remote-query.ts b/packages/core/types/src/modules-sdk/remote-query.ts index a368681749..34aa76f593 100644 --- a/packages/core/types/src/modules-sdk/remote-query.ts +++ b/packages/core/types/src/modules-sdk/remote-query.ts @@ -48,9 +48,10 @@ export type QueryGraphFunction = { * a normalized/consistent output. */ export type QueryIndexFunction = { - (queryOptions: IndexQueryInput): Promise< - Prettify> - > + ( + queryOptions: IndexQueryInput, + options?: RemoteJoinerOptions + ): Promise>> } /*export type RemoteQueryReturnedData = diff --git a/packages/core/types/src/pricing/common/price-rule.ts b/packages/core/types/src/pricing/common/price-rule.ts index 1b216f7a2b..22c3a38dae 100644 --- a/packages/core/types/src/pricing/common/price-rule.ts +++ b/packages/core/types/src/pricing/common/price-rule.ts @@ -132,7 +132,6 @@ export interface FilterablePriceRuleProps * The IDs to filter the price rule's associated price set. */ price_set_id?: string | string[] | OperatorMap - /** * The IDs to filter the price rule's associated price. 
*/ diff --git a/packages/core/utils/src/bundles.ts b/packages/core/utils/src/bundles.ts index cee955c1de..6e7d41ad29 100644 --- a/packages/core/utils/src/bundles.ts +++ b/packages/core/utils/src/bundles.ts @@ -18,3 +18,4 @@ export * as PromotionUtils from "./promotion" export * as SearchUtils from "./search" export * as ShippingProfileUtils from "./shipping" export * as UserUtils from "./user" +export * as CachingUtils from "./caching" diff --git a/packages/core/utils/src/caching/index.ts b/packages/core/utils/src/caching/index.ts new file mode 100644 index 0000000000..c9fc045ecc --- /dev/null +++ b/packages/core/utils/src/caching/index.ts @@ -0,0 +1,249 @@ +import { ICachingModuleService, Logger, MedusaContainer } from "@medusajs/types" +import { MedusaContextType, Modules } from "../modules-sdk" +import { FeatureFlag } from "../feature-flags" +import { ContainerRegistrationKeys, isObject } from "../common" + +/** + * This function is used to cache the result of a function call. + * + * @param cb - The callback to execute. + * @param options - The options for the cache. + * @returns The result of the callback. + */ +export async function useCache( + cb: (...args: any[]) => Promise, + options: { + enable?: boolean + key: string | any[] + tags?: string[] + ttl?: number + /** + * Whethere the default strategy should auto invalidate the cache whenever it is possible. 
+ */ + autoInvalidate?: boolean + providers?: string[] + container: MedusaContainer + } +): Promise { + const cachingModule = options.container.resolve( + Modules.CACHING, + { + allowUnregistered: true, + } + ) + + if ( + !options.enable || + !FeatureFlag.isFeatureEnabled("caching") || + !cachingModule + ) { + return await cb() + } + + let key: string + if (typeof options.key === "string") { + key = options.key + } else { + key = await cachingModule.computeKey(options.key) + } + + const data = await cachingModule.get({ + key, + tags: options.tags, + providers: options.providers, + }) + + if (data) { + return data as T + } + + const result = await cb() + + void cachingModule + .set({ + key, + tags: options.tags, + ttl: options.ttl, + data: result as object, + options: { autoInvalidate: options.autoInvalidate }, + providers: options.providers, + }) + .catch((e) => { + const logger = + options.container.resolve(ContainerRegistrationKeys.LOGGER, { + allowUnregistered: true, + }) ?? (console as unknown as Logger) + logger.error( + `An error occured while setting cache for key: ${key}\n${e.message}\n${e.stack}` + ) + }) + + return result +} + +type TargetMethodArgs = Target[PropertyKey & + keyof Target] extends (...args: any[]) => any + ? Parameters + : never + +/** + * This function is used to cache the result of a method call. + * + * @param options - The options for the cache. + * @returns The original method with the cache applied. + */ +export function Cached< + const Target extends object, + const PropertyKey extends keyof Target +>(options: { + /** + * The key to use for the cache. + * If a function is provided, it will be called with the arguments as the first argument and the + * container as the second argument. + */ + key?: + | string + | (( + args: TargetMethodArgs, + cachingModule: ICachingModuleService + ) => string | Promise | Promise | any[]) + /** + * Whether to enable the cache. 
This is only useful if you want to enable without providing any + * other options. + */ + enable?: + | boolean + | ((args: TargetMethodArgs) => boolean | undefined) + /** + * The tags to use for the cache. + */ + tags?: + | string[] + | ((args: TargetMethodArgs) => string[] | undefined) + /** + * The time-to-live (TTL) value in seconds. + */ + ttl?: + | number + | ((args: TargetMethodArgs) => number | undefined) + /** + * Whether to auto invalidate the cache whenever it is possible. + */ + autoInvalidate?: + | boolean + | ((args: TargetMethodArgs) => boolean | undefined) + /** + * The providers to use for the cache. + */ + providers?: + | string[] + | ((args: TargetMethodArgs) => string[] | undefined) + + container: MedusaContainer | ((this: Target) => MedusaContainer) +}) { + return function ( + target: Target, + propertyKey: PropertyKey, + descriptor: PropertyDescriptor + ) { + const originalMethod = descriptor.value + + if (typeof originalMethod !== "function") { + throw new Error("@cached can only be applied to methods") + } + + descriptor.value = async function ( + ...args: Target[PropertyKey & keyof Target] extends ( + ...args: any[] + ) => any + ? Parameters + : never + ) { + const container: MedusaContainer = + typeof options.container === "function" + ? 
options.container.call(this) + : options.container + + const cachingModule = container.resolve( + Modules.CACHING, + { + allowUnregistered: true, + } + ) + + if (!FeatureFlag.isFeatureEnabled("caching") || !cachingModule) { + return await originalMethod.apply(this, args) + } + + if (!options.key) { + options.key = await cachingModule.computeKey( + args + .map((arg) => { + if (isObject(arg)) { + // Prevent any container, manager, transactionManager, etc from being included in the key + const { + container, + manager, + transactionManager, + __type, + ...rest + } = arg as any + if (__type === MedusaContextType) { + return + } + return rest + } + return arg + }) + .filter(Boolean) + ) + } + + const resolvableKeys = [ + "enable", + "key", + "tags", + "ttl", + "autoInvalidate", + "providers", + ] + + const cacheOptions = {} as Parameters[1] + + const promises: Promise[] = [] + for (const key of resolvableKeys) { + if (typeof options[key] === "function") { + const res = options[key](args, cachingModule) + if (res instanceof Promise) { + promises.push( + res.then((value) => { + cacheOptions[key] = value + }) + ) + } else { + cacheOptions[key] = res + } + } else { + cacheOptions[key] = options[key] + } + } + + await Promise.all(promises) + + if (!cacheOptions.enable) { + return await originalMethod.apply(this, args) + } + + Object.assign(cacheOptions, { + container, + }) + + return await useCache( + () => originalMethod.apply(this, args), + cacheOptions as Parameters[1] + ) + } + + return descriptor + } +} diff --git a/packages/core/utils/src/common/create-container-like.ts b/packages/core/utils/src/common/create-container-like.ts index 445c6adbfe..480191a587 100644 --- a/packages/core/utils/src/common/create-container-like.ts +++ b/packages/core/utils/src/common/create-container-like.ts @@ -2,7 +2,22 @@ import { ContainerLike } from "@medusajs/types" export function createContainerLike(obj): ContainerLike { return { - resolve(key: string) { + resolve( + key: string, + { + 
allowUnregistered = false, + }: { + allowUnregistered?: boolean + } = {} + ) { + if (allowUnregistered) { + try { + return obj[key] + } catch (error) { + return undefined + } + } + return obj[key] }, } diff --git a/packages/core/utils/src/dml/properties/id.ts b/packages/core/utils/src/dml/properties/id.ts index a2d7bbf30d..73f710e6a5 100644 --- a/packages/core/utils/src/dml/properties/id.ts +++ b/packages/core/utils/src/dml/properties/id.ts @@ -14,7 +14,7 @@ export class IdProperty extends BaseProperty { return !!value?.[IsIdProperty] || value?.dataType?.name === "id" } - protected dataType: { + dataType: { name: "id" options: { prefix?: string diff --git a/packages/core/utils/src/dml/properties/primary-key.ts b/packages/core/utils/src/dml/properties/primary-key.ts index 77ed360fee..51b931ee46 100644 --- a/packages/core/utils/src/dml/properties/primary-key.ts +++ b/packages/core/utils/src/dml/properties/primary-key.ts @@ -24,6 +24,10 @@ export class PrimaryKeyModifier> */ #schema: Schema + get schema() { + return this.#schema + } + constructor(schema: Schema) { this.#schema = schema } diff --git a/packages/core/utils/src/event-bus/index.ts b/packages/core/utils/src/event-bus/index.ts index cb9b10ec4f..5cdb889bd4 100644 --- a/packages/core/utils/src/event-bus/index.ts +++ b/packages/core/utils/src/event-bus/index.ts @@ -1,4 +1,8 @@ -import { EventBusTypes, InternalModuleDeclaration } from "@medusajs/types" +import { + EventBusTypes, + InterceptorSubscriber, + InternalModuleDeclaration, +} from "@medusajs/types" import { ulid } from "ulid" export abstract class AbstractEventBusModuleService @@ -11,6 +15,8 @@ export abstract class AbstractEventBusModuleService EventBusTypes.SubscriberDescriptor[] > = new Map() + protected interceptorSubscribers_: Set = new Set() + public get eventToSubscribersMap(): Map< string | symbol, EventBusTypes.SubscriberDescriptor[] @@ -134,6 +140,49 @@ export abstract class AbstractEventBusModuleService return this } + + /** + * Add an 
interceptor subscriber that receives all messages before they are emitted + * + * @param interceptor - Function that receives messages before emission + * @returns this for chaining + */ + public addInterceptor(interceptor: InterceptorSubscriber): this { + this.interceptorSubscribers_.add(interceptor) + return this + } + + /** + * Remove an interceptor subscriber + * + * @param interceptor - Function to remove from interceptors + * @returns this for chaining + */ + public removeInterceptor(interceptor: InterceptorSubscriber): this { + this.interceptorSubscribers_.delete(interceptor) + return this + } + + /** + * Call all interceptor subscribers with the message before emission + * This should be called by implementations before emitting events + * + * @param message - The message to be intercepted + * @param context - Optional context about the emission + */ + protected async callInterceptors( + message: EventBusTypes.Message, + context?: { isGrouped?: boolean; eventGroupId?: string } + ): Promise { + Array.from(this.interceptorSubscribers_).map(async (interceptor) => { + try { + await interceptor(message, context) + } catch (error) { + // Log error but don't stop other interceptors or the emission + console.error("Error in event bus interceptor:", error) + } + }) + } } export * from "./build-event-messages" diff --git a/packages/core/utils/src/index.ts b/packages/core/utils/src/index.ts index 44f1308c83..e53e8b53be 100644 --- a/packages/core/utils/src/index.ts +++ b/packages/core/utils/src/index.ts @@ -29,6 +29,7 @@ export * from "./shipping" export * from "./totals" export * from "./totals/big-number" export * from "./user" +export * from "./caching" export const MedusaModuleType = Symbol.for("MedusaModule") export const MedusaModuleProviderType = Symbol.for("MedusaModuleProvider") diff --git a/packages/core/utils/src/modules-sdk/__tests__/joiner-config-builder.spec.ts b/packages/core/utils/src/modules-sdk/__tests__/joiner-config-builder.spec.ts index 
ed0f50c1f3..c9334f28dc 100644 --- a/packages/core/utils/src/modules-sdk/__tests__/joiner-config-builder.spec.ts +++ b/packages/core/utils/src/modules-sdk/__tests__/joiner-config-builder.spec.ts @@ -48,6 +48,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: "", + idPrefixToEntityName: {}, linkableKeys: { fulfillment_set_id: FulfillmentSet.name, shipping_option_id: ShippingOption.name, @@ -136,6 +137,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: "", + idPrefixToEntityName: {}, linkableKeys: {}, alias: [ { @@ -176,6 +178,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: "", + idPrefixToEntityName: {}, linkableKeys: { fulfillment_set_id: FulfillmentSet.name, shipping_option_id: ShippingOption.name, @@ -269,6 +272,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: "", + idPrefixToEntityName: {}, linkableKeys: {}, alias: [ { @@ -300,6 +304,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: "", + idPrefixToEntityName: {}, linkableKeys: { fulfillment_set_id: FulfillmentSet.name, }, @@ -335,6 +340,7 @@ describe("joiner-config-builder", () => { serviceName: Modules.FULFILLMENT, primaryKeys: ["id"], schema: expect.any(String), + idPrefixToEntityName: {}, linkableKeys: { fulfillment_set_id: FulfillmentSet.name, shipping_option_id: ShippingOption.name, diff --git a/packages/core/utils/src/modules-sdk/definition.ts b/packages/core/utils/src/modules-sdk/definition.ts index 44f79d6b70..a642f3917a 100644 --- a/packages/core/utils/src/modules-sdk/definition.ts +++ b/packages/core/utils/src/modules-sdk/definition.ts @@ -27,6 +27,7 @@ export const Modules = { INDEX: "index", LOCKING: "locking", SETTINGS: "settings", + CACHING: "caching", } as const export const MODULE_PACKAGE_NAMES = { @@ 
-58,6 +59,7 @@ export const MODULE_PACKAGE_NAMES = { [Modules.INDEX]: "@medusajs/medusa/index-module", [Modules.LOCKING]: "@medusajs/medusa/locking", [Modules.SETTINGS]: "@medusajs/medusa/settings", + [Modules.CACHING]: "@medusajs/caching", } export const REVERSED_MODULE_PACKAGE_NAMES = Object.entries( diff --git a/packages/core/utils/src/modules-sdk/joiner-config-builder.ts b/packages/core/utils/src/modules-sdk/joiner-config-builder.ts index c71a28403d..75d74cbb98 100644 --- a/packages/core/utils/src/modules-sdk/joiner-config-builder.ts +++ b/packages/core/utils/src/modules-sdk/joiner-config-builder.ts @@ -46,8 +46,10 @@ export function defineJoinerConfig( models, linkableKeys, primaryKeys, + idPrefixToEntityName, }: { alias?: JoinerServiceConfigAlias[] + idPrefixToEntityName?: Record schema?: string models?: DmlEntity[] | { name: string }[] linkableKeys?: ModuleJoinerConfig["linkableKeys"] @@ -150,6 +152,12 @@ export function defineJoinerConfig( schema = toGraphQLSchema([...modelDefinitions.values()]) } + if (!idPrefixToEntityName) { + idPrefixToEntityName = buildIdPrefixToEntityNameFromDmlObjects([ + ...modelDefinitions.values(), + ]) + } + const linkableKeysFromDml = buildLinkableKeysFromDmlObjects([ ...modelDefinitions.values(), ]) @@ -199,6 +207,7 @@ export function defineJoinerConfig( serviceName, primaryKeys, schema, + idPrefixToEntityName, linkableKeys: linkableKeys, alias: [ ...[...(alias ?? 
([] as any))].map((alias) => ({ @@ -230,6 +239,31 @@ export function defineJoinerConfig( } } +/** + * Build the id prefix to entity name map from the DML objects + * @param models + */ +export function buildIdPrefixToEntityNameFromDmlObjects( + models: DmlEntity[] +): Record { + return models.reduce((acc, model) => { + const id = model.parse().schema.id as + | IdProperty + | PrimaryKeyModifier + + if ( + PrimaryKeyModifier.isPrimaryKeyModifier(id) && + id.schema.dataType.options.prefix + ) { + acc[id.schema.dataType.options.prefix] = model.name + } else if (IdProperty.isIdProperty(id) && id.dataType.options.prefix) { + acc[id.dataType.options.prefix] = model.name + } + + return acc + }, {}) +} + /** * From a set of DML objects, build the linkable keys * diff --git a/packages/core/utils/src/modules-sdk/module.ts b/packages/core/utils/src/modules-sdk/module.ts index d2347a00d1..c154d245fe 100644 --- a/packages/core/utils/src/modules-sdk/module.ts +++ b/packages/core/utils/src/modules-sdk/module.ts @@ -1,6 +1,7 @@ import { Constructor, IDmlEntity, ModuleExports } from "@medusajs/types" import { DmlEntity } from "../dml" import { + buildIdPrefixToEntityNameFromDmlObjects, buildLinkConfigFromLinkableKeys, buildLinkConfigFromModelObjects, defineJoinerConfig, @@ -53,6 +54,10 @@ export function Module< // TODO: Add support for non linkable modifier DML object to be skipped from the linkable generation const linkableKeys = service.prototype.__joinerConfig().linkableKeys + service.prototype.__joinerConfig().idPrefixToEntityName = + buildIdPrefixToEntityNameFromDmlObjects( + dmlObjects.map(([, model]) => model) as DmlEntity[] + ) if (dmlObjects.length) { linkable = buildLinkConfigFromModelObjects( diff --git a/packages/core/utils/src/product/get-variant-availability.ts b/packages/core/utils/src/product/get-variant-availability.ts index aba2b170f8..f17091ffc5 100644 --- a/packages/core/utils/src/product/get-variant-availability.ts +++ 
b/packages/core/utils/src/product/get-variant-availability.ts @@ -155,18 +155,25 @@ const getDataForComputation = async ( query: Omit, data: { variant_ids: string[]; sales_channel_id?: string } ) => { - const { data: variantInventoryItems } = await query.graph({ - entity: "product_variant_inventory_items", - fields: [ - "variant_id", - "required_quantity", - "variant.manage_inventory", - "variant.allow_backorder", - "inventory.*", - "inventory.location_levels.*", - ], - filters: { variant_id: data.variant_ids }, - }) + const { data: variantInventoryItems } = await query.graph( + { + entity: "product_variant_inventory_items", + fields: [ + "variant_id", + "required_quantity", + "variant.manage_inventory", + "variant.allow_backorder", + "inventory.*", + "inventory.location_levels.*", + ], + filters: { variant_id: data.variant_ids }, + }, + { + cache: { + enable: true, + }, + } + ) const variantInventoriesMap = new Map() variantInventoryItems.forEach((link) => { @@ -177,11 +184,21 @@ const getDataForComputation = async ( const locationIds = new Set() if (data.sales_channel_id) { - const { data: channelLocations } = await query.graph({ - entity: "sales_channel_locations", - fields: ["stock_location_id"], - filters: { sales_channel_id: data.sales_channel_id }, - }) + const { data: channelLocations } = await query.graph( + { + entity: "sales_channel_locations", + fields: ["stock_location_id"], + filters: { sales_channel_id: data.sales_channel_id }, + }, + { + cache: { + tags: [ + `SalesChannel:${data.sales_channel_id}`, + "StockLocation:list:*", + ], + }, + } + ) channelLocations.forEach((loc) => locationIds.add(loc.stock_location_id)) } diff --git a/packages/medusa/package.json b/packages/medusa/package.json index 58171be6fc..8757e8b6d4 100644 --- a/packages/medusa/package.json +++ b/packages/medusa/package.json @@ -80,6 +80,8 @@ "@medusajs/auth-google": "2.10.3", "@medusajs/cache-inmemory": "2.10.3", "@medusajs/cache-redis": "2.10.3", + "@medusajs/caching": "2.10.3", + 
"@medusajs/caching-redis": "2.10.3", "@medusajs/cart": "2.10.3", "@medusajs/core-flows": "2.10.3", "@medusajs/currency": "2.10.3", diff --git a/packages/medusa/src/api/admin/claims/[id]/inbound/items/[action_id]/route.ts b/packages/medusa/src/api/admin/claims/[id]/inbound/items/[action_id]/route.ts index b5e07bfd97..bce69d57e1 100644 --- a/packages/medusa/src/api/admin/claims/[id]/inbound/items/[action_id]/route.ts +++ b/packages/medusa/src/api/admin/claims/[id]/inbound/items/[action_id]/route.ts @@ -67,10 +67,12 @@ export const DELETE = async ( ) => { const { id, action_id } = req.params - const claim = await refetchEntity("order_claim", id, req.scope, [ - "id", - "return_id", - ]) + const claim = await refetchEntity({ + entity: "order_claim", + idOrFilter: id, + scope: req.scope, + fields: ["id", "return_id"], + }) const { result: orderPreview } = await removeItemReturnActionWorkflow( req.scope @@ -81,15 +83,15 @@ export const DELETE = async ( }, }) - const orderReturn = await refetchEntity( - "return", - { + const orderReturn = await refetchEntity({ + entity: "return", + idOrFilter: { ...req.filterableFields, id, }, - req.scope, - defaultAdminDetailsReturnFields - ) + scope: req.scope, + fields: defaultAdminDetailsReturnFields, + }) res.json({ order_preview: orderPreview as unknown as HttpTypes.AdminOrderPreview, diff --git a/packages/medusa/src/api/admin/claims/[id]/route.ts b/packages/medusa/src/api/admin/claims/[id]/route.ts index 9fd5c72282..ee18e84e57 100644 --- a/packages/medusa/src/api/admin/claims/[id]/route.ts +++ b/packages/medusa/src/api/admin/claims/[id]/route.ts @@ -10,12 +10,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const claim = await refetchEntity( - "order_claim", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const claim = await refetchEntity({ + entity: "order_claim", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) if (!claim) { throw new 
MedusaError( diff --git a/packages/medusa/src/api/admin/exchanges/[id]/inbound/items/[action_id]/route.ts b/packages/medusa/src/api/admin/exchanges/[id]/inbound/items/[action_id]/route.ts index b0d7405a32..211ae75b88 100644 --- a/packages/medusa/src/api/admin/exchanges/[id]/inbound/items/[action_id]/route.ts +++ b/packages/medusa/src/api/admin/exchanges/[id]/inbound/items/[action_id]/route.ts @@ -69,9 +69,12 @@ export const DELETE = async ( const { id, action_id } = req.params - const exchange = await refetchEntity("order_exchange", id, req.scope, [ - "return_id", - ]) + const exchange = await refetchEntity({ + entity: "order_exchange", + idOrFilter: id, + scope: req.scope, + fields: ["return_id"], + }) const { result: orderPreview } = await removeItemReturnActionWorkflow( req.scope diff --git a/packages/medusa/src/api/admin/exchanges/[id]/route.ts b/packages/medusa/src/api/admin/exchanges/[id]/route.ts index 95b7d5c118..5ee3f8a3ca 100644 --- a/packages/medusa/src/api/admin/exchanges/[id]/route.ts +++ b/packages/medusa/src/api/admin/exchanges/[id]/route.ts @@ -10,12 +10,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const exchange = await refetchEntity( - "order_exchange", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const exchange = await refetchEntity({ + entity: "order_exchange", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) if (!exchange) { throw new MedusaError( diff --git a/packages/medusa/src/api/admin/notifications/[id]/route.ts b/packages/medusa/src/api/admin/notifications/[id]/route.ts index 42909a810a..19f8f66c0f 100644 --- a/packages/medusa/src/api/admin/notifications/[id]/route.ts +++ b/packages/medusa/src/api/admin/notifications/[id]/route.ts @@ -10,11 +10,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const notification = await refetchEntity( - "notification", - req.params.id, - req.scope, - 
req.queryConfig.fields - ) + const notification = await refetchEntity({ + entity: "notification", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) + res.status(200).json({ notification }) } diff --git a/packages/medusa/src/api/admin/notifications/route.ts b/packages/medusa/src/api/admin/notifications/route.ts index 76f51801c9..15cd2e607c 100644 --- a/packages/medusa/src/api/admin/notifications/route.ts +++ b/packages/medusa/src/api/admin/notifications/route.ts @@ -9,13 +9,13 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const { rows: notifications, metadata } = await refetchEntities( - "notification", - req.filterableFields, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: notifications, metadata } = await refetchEntities({ + entity: "notification", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.json({ notifications, diff --git a/packages/medusa/src/api/admin/orders/[id]/fulfillments/[fulfillment_id]/mark-as-delivered/route.ts b/packages/medusa/src/api/admin/orders/[id]/fulfillments/[fulfillment_id]/mark-as-delivered/route.ts index a4e7667561..0d808f72b9 100644 --- a/packages/medusa/src/api/admin/orders/[id]/fulfillments/[fulfillment_id]/mark-as-delivered/route.ts +++ b/packages/medusa/src/api/admin/orders/[id]/fulfillments/[fulfillment_id]/mark-as-delivered/route.ts @@ -16,12 +16,12 @@ export const POST = async ( input: { orderId, fulfillmentId }, }) - const order = await refetchEntity( - "order", - orderId, - req.scope, - req.queryConfig.fields - ) + const order = await refetchEntity({ + entity: "order", + idOrFilter: orderId, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ order }) } diff --git a/packages/medusa/src/api/admin/payment-collections/[id]/mark-as-paid/route.ts 
b/packages/medusa/src/api/admin/payment-collections/[id]/mark-as-paid/route.ts index 1d7725dbc4..bf117ed2b3 100644 --- a/packages/medusa/src/api/admin/payment-collections/[id]/mark-as-paid/route.ts +++ b/packages/medusa/src/api/admin/payment-collections/[id]/mark-as-paid/route.ts @@ -21,12 +21,12 @@ export const POST = async ( }, }) - const paymentCollection = await refetchEntity( - "payment_collection", - id, - req.scope, - req.queryConfig.fields - ) + const paymentCollection = await refetchEntity({ + entity: "payment_collection", + idOrFilter: id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ payment_collection: paymentCollection }) } diff --git a/packages/medusa/src/api/admin/payment-collections/route.ts b/packages/medusa/src/api/admin/payment-collections/route.ts index 3968c1b68c..68c789a31f 100644 --- a/packages/medusa/src/api/admin/payment-collections/route.ts +++ b/packages/medusa/src/api/admin/payment-collections/route.ts @@ -15,12 +15,12 @@ export const POST = async ( input: req.body, }) - const paymentCollection = await refetchEntity( - "payment_collection", - result[0].id, - req.scope, - req.queryConfig.fields - ) + const paymentCollection = await refetchEntity({ + entity: "payment_collection", + idOrFilter: result[0].id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ payment_collection: paymentCollection }) } diff --git a/packages/medusa/src/api/admin/price-preferences/[id]/route.ts b/packages/medusa/src/api/admin/price-preferences/[id]/route.ts index 5333787227..75e6527ade 100644 --- a/packages/medusa/src/api/admin/price-preferences/[id]/route.ts +++ b/packages/medusa/src/api/admin/price-preferences/[id]/route.ts @@ -14,12 +14,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const price_preference = await refetchEntity( - "price_preference", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const price_preference = await 
refetchEntity({ + entity: "price_preference", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ price_preference }) } @@ -35,12 +35,12 @@ export const POST = async ( input: { selector: { id: [id] }, update: req.body }, }) - const price_preference = await refetchEntity( - "price_preference", - id, - req.scope, - req.queryConfig.fields - ) + const price_preference = await refetchEntity({ + entity: "price_preference", + idOrFilter: id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ price_preference }) } diff --git a/packages/medusa/src/api/admin/price-preferences/route.ts b/packages/medusa/src/api/admin/price-preferences/route.ts index da521120f9..4e444fb137 100644 --- a/packages/medusa/src/api/admin/price-preferences/route.ts +++ b/packages/medusa/src/api/admin/price-preferences/route.ts @@ -11,13 +11,14 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const { rows: price_preferences, metadata } = await refetchEntities( - "price_preference", - req.filterableFields, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: price_preferences, metadata } = await refetchEntities({ + entity: "price_preference", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) + res.json({ price_preferences: price_preferences, count: metadata.count, @@ -35,12 +36,12 @@ export const POST = async ( input: [req.validatedBody], }) - const price_preference = await refetchEntity( - "price_preference", - result[0].id, - req.scope, - req.queryConfig.fields - ) + const price_preference = await refetchEntity({ + entity: "price_preference", + idOrFilter: result[0].id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ price_preference }) } diff --git 
a/packages/medusa/src/api/admin/product-categories/[id]/products/route.ts b/packages/medusa/src/api/admin/product-categories/[id]/products/route.ts index f21bb7997f..774eff6243 100644 --- a/packages/medusa/src/api/admin/product-categories/[id]/products/route.ts +++ b/packages/medusa/src/api/admin/product-categories/[id]/products/route.ts @@ -19,12 +19,12 @@ export const POST = async ( input: { id, ...req.validatedBody }, }) - const category = await refetchEntity( - "product_category", - id, - req.scope, - req.queryConfig.fields - ) + const category = await refetchEntity({ + entity: "product_category", + idOrFilter: id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ product_category: category }) } diff --git a/packages/medusa/src/api/admin/product-categories/[id]/route.ts b/packages/medusa/src/api/admin/product-categories/[id]/route.ts index afacf03574..4c7f8122be 100644 --- a/packages/medusa/src/api/admin/product-categories/[id]/route.ts +++ b/packages/medusa/src/api/admin/product-categories/[id]/route.ts @@ -21,12 +21,15 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const [category] = await refetchEntities( - "product_category", - { id: req.params.id, ...req.filterableFields }, - req.scope, - req.queryConfig.fields - ) + const { + data: [category], + } = await refetchEntities({ + entity: "product_category", + idOrFilter: { id: req.params.id, ...req.filterableFields }, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) if (!category) { throw new MedusaError( @@ -48,12 +51,15 @@ export const POST = async ( input: { selector: { id }, update: req.validatedBody }, }) - const [category] = await refetchEntities( - "product_category", - { id, ...req.filterableFields }, - req.scope, - req.queryConfig.fields - ) + const { + data: [category], + } = await refetchEntities({ + entity: "product_category", + idOrFilter: { id, ...req.filterableFields 
}, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.status(200).json({ product_category: category }) } diff --git a/packages/medusa/src/api/admin/product-categories/route.ts b/packages/medusa/src/api/admin/product-categories/route.ts index 91e9d7e2f9..1f90757f4a 100644 --- a/packages/medusa/src/api/admin/product-categories/route.ts +++ b/packages/medusa/src/api/admin/product-categories/route.ts @@ -10,13 +10,13 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const { rows: product_categories, metadata } = await refetchEntities( - "product_category", - req.filterableFields, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: product_categories, metadata } = await refetchEntities({ + entity: "product_category", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.json({ product_categories, @@ -34,12 +34,15 @@ export const POST = async ( input: { product_categories: [req.validatedBody] }, }) - const [category] = await refetchEntities( - "product_category", - { id: result[0].id, ...req.filterableFields }, - req.scope, - req.queryConfig.fields - ) + const { + data: [category], + } = await refetchEntities({ + entity: "product_category", + idOrFilter: { id: result[0].id, ...req.filterableFields }, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.status(200).json({ product_category: category }) } diff --git a/packages/medusa/src/api/admin/product-tags/[id]/route.ts b/packages/medusa/src/api/admin/product-tags/[id]/route.ts index f657ba1c50..d264c08010 100644 --- a/packages/medusa/src/api/admin/product-tags/[id]/route.ts +++ b/packages/medusa/src/api/admin/product-tags/[id]/route.ts @@ -19,12 +19,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const 
productTag = await refetchEntity( - "product_tag", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const productTag = await refetchEntity({ + entity: "product_tag", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ product_tag: productTag }) } @@ -33,12 +33,12 @@ export const POST = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const existingProductTag = await refetchEntity( - "product_tag", - req.params.id, - req.scope, - ["id"] - ) + const existingProductTag = await refetchEntity({ + entity: "product_tag", + idOrFilter: req.params.id, + scope: req.scope, + fields: ["id"], + }) if (!existingProductTag) { throw new MedusaError( @@ -54,12 +54,12 @@ export const POST = async ( }, }) - const productTag = await refetchEntity( - "product_tag", - result[0].id, - req.scope, - req.queryConfig.fields - ) + const productTag = await refetchEntity({ + entity: "product_tag", + idOrFilter: result[0].id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ product_tag: productTag }) } diff --git a/packages/medusa/src/api/admin/product-tags/route.ts b/packages/medusa/src/api/admin/product-tags/route.ts index 5e8d671062..1caaafae60 100644 --- a/packages/medusa/src/api/admin/product-tags/route.ts +++ b/packages/medusa/src/api/admin/product-tags/route.ts @@ -12,13 +12,13 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const { rows: product_tags, metadata } = await refetchEntities( - "product_tag", - req.filterableFields, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: product_tags, metadata } = await refetchEntities({ + entity: "product_tag", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.json({ product_tags: product_tags, @@ -38,12 +38,12 @@ export const POST = async ( input: { 
product_tags: input }, }) - const productTag = await refetchEntity( - "product_tag", - result[0].id, - req.scope, - req.queryConfig.fields - ) + const productTag = await refetchEntity({ + entity: "product_tag", + idOrFilter: result[0].id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ product_tag: productTag }) } diff --git a/packages/medusa/src/api/admin/product-variants/route.ts b/packages/medusa/src/api/admin/product-variants/route.ts index 1b70c26f16..635dddf3e2 100644 --- a/packages/medusa/src/api/admin/product-variants/route.ts +++ b/packages/medusa/src/api/admin/product-variants/route.ts @@ -21,13 +21,13 @@ export const GET = async ( ) } - const { rows: variants, metadata } = await refetchEntities( - "variant", - { ...req.filterableFields }, - req.scope, - remapKeysForVariant(req.queryConfig.fields ?? []), - req.queryConfig.pagination - ) + const { data: variants, metadata } = await refetchEntities({ + entity: "variant", + idOrFilter: { ...req.filterableFields }, + scope: req.scope, + fields: remapKeysForVariant(req.queryConfig.fields ?? 
[]), + pagination: req.queryConfig.pagination, + }) if (withInventoryQuantity) { await wrapVariantsWithTotalInventoryQuantity(req, variants || []) diff --git a/packages/medusa/src/api/admin/products/[id]/options/[option_id]/route.ts b/packages/medusa/src/api/admin/products/[id]/options/[option_id]/route.ts index 2d9aa05776..acfb3cbc24 100644 --- a/packages/medusa/src/api/admin/products/[id]/options/[option_id]/route.ts +++ b/packages/medusa/src/api/admin/products/[id]/options/[option_id]/route.ts @@ -17,12 +17,12 @@ export const GET = async ( ) => { const productId = req.params.id const optionId = req.params.option_id - const productOption = await refetchEntity( - "product_option", - { id: optionId, product_id: productId }, - req.scope, - req.queryConfig.fields - ) + const productOption = await refetchEntity({ + entity: "product_option", + idOrFilter: { id: optionId, product_id: productId }, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ product_option: productOption }) } @@ -45,12 +45,12 @@ export const POST = async ( }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ product: remapProductResponse(product) }) } @@ -67,12 +67,12 @@ export const DELETE = async ( input: { ids: [optionId] /* product_id: productId */ }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? 
[]), + }) res.status(200).json({ id: optionId, diff --git a/packages/medusa/src/api/admin/products/[id]/options/route.ts b/packages/medusa/src/api/admin/products/[id]/options/route.ts index b552fe5322..f0ff19de0b 100644 --- a/packages/medusa/src/api/admin/products/[id]/options/route.ts +++ b/packages/medusa/src/api/admin/products/[id]/options/route.ts @@ -14,13 +14,13 @@ export const GET = async ( res: MedusaResponse ) => { const productId = req.params.id - const { rows: product_options, metadata } = await refetchEntities( - "product_option", - { ...req.filterableFields, product_id: productId }, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: product_options, metadata } = await refetchEntities({ + entity: "product_option", + idOrFilter: { ...req.filterableFields, product_id: productId }, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.json({ product_options, @@ -51,11 +51,11 @@ export const POST = async ( }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ product: remapProductResponse(product) }) } diff --git a/packages/medusa/src/api/admin/products/[id]/route.ts b/packages/medusa/src/api/admin/products/[id]/route.ts index c39c0e63ed..c7afa11410 100644 --- a/packages/medusa/src/api/admin/products/[id]/route.ts +++ b/packages/medusa/src/api/admin/products/[id]/route.ts @@ -16,12 +16,12 @@ export const GET = async ( res: MedusaResponse ) => { const selectFields = remapKeysForProduct(req.queryConfig.fields ?? 
[]) - const product = await refetchEntity( - "product", - req.params.id, - req.scope, - selectFields - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: req.params.id, + scope: req.scope, + fields: selectFields, + }) if (!product) { throw new MedusaError(MedusaError.Types.NOT_FOUND, "Product not found") @@ -38,12 +38,12 @@ export const POST = async ( ) => { const { additional_data, ...update } = req.validatedBody - const existingProduct = await refetchEntity( - "product", - req.params.id, - req.scope, - ["id"] - ) + const existingProduct = await refetchEntity({ + entity: "product", + idOrFilter: req.params.id, + scope: req.scope, + fields: ["id"], + }) /** * Check if the product exists with the id or not before calling the workflow. */ @@ -62,12 +62,12 @@ export const POST = async ( }, }) - const product = await refetchEntity( - "product", - result[0].id, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: result[0].id, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ product: remapProductResponse(product) }) } diff --git a/packages/medusa/src/api/admin/products/[id]/variants/[variant_id]/route.ts b/packages/medusa/src/api/admin/products/[id]/variants/[variant_id]/route.ts index 91a819fde5..a4f9ac399b 100644 --- a/packages/medusa/src/api/admin/products/[id]/variants/[variant_id]/route.ts +++ b/packages/medusa/src/api/admin/products/[id]/variants/[variant_id]/route.ts @@ -24,12 +24,12 @@ export const GET = async ( const variantId = req.params.variant_id const variables = { id: variantId, product_id: productId } - const variant = await refetchEntity( - "variant", - variables, - req.scope, - remapKeysForVariant(req.queryConfig.fields ?? 
[]) - ) + const variant = await refetchEntity({ + entity: "variant", + idOrFilter: variables, + scope: req.scope, + fields: remapKeysForVariant(req.queryConfig.fields ?? []), + }) res.status(200).json({ variant: remapVariantResponse(variant) }) } @@ -52,12 +52,12 @@ export const POST = async ( }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ product: remapProductResponse(product) }) } @@ -74,12 +74,12 @@ export const DELETE = async ( input: { ids: [variantId] /* product_id: productId */ }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ id: variantId, diff --git a/packages/medusa/src/api/admin/products/[id]/variants/route.ts b/packages/medusa/src/api/admin/products/[id]/variants/route.ts index 830d0a651a..e2411f74a4 100644 --- a/packages/medusa/src/api/admin/products/[id]/variants/route.ts +++ b/packages/medusa/src/api/admin/products/[id]/variants/route.ts @@ -29,13 +29,13 @@ export const GET = async ( ) } - const { rows: variants, metadata } = await refetchEntities( - "variant", - { ...req.filterableFields, product_id: productId }, - req.scope, - remapKeysForVariant(req.queryConfig.fields ?? []), - req.queryConfig.pagination - ) + const { data: variants, metadata } = await refetchEntities({ + entity: "variant", + idOrFilter: { ...req.filterableFields, product_id: productId }, + scope: req.scope, + fields: remapKeysForVariant(req.queryConfig.fields ?? 
[]), + pagination: req.queryConfig.pagination, + }) if (withInventoryQuantity) { await wrapVariantsWithTotalInventoryQuantity(req, variants || []) @@ -69,12 +69,12 @@ export const POST = async ( input: { product_variants: input, additional_data }, }) - const product = await refetchEntity( - "product", - productId, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: productId, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? []), + }) res.status(200).json({ product: remapProductResponse(product) }) } diff --git a/packages/medusa/src/api/admin/products/route.ts b/packages/medusa/src/api/admin/products/route.ts index 7ff8a4954e..127f79d299 100644 --- a/packages/medusa/src/api/admin/products/route.ts +++ b/packages/medusa/src/api/admin/products/route.ts @@ -41,14 +41,14 @@ async function getProducts( ) { const selectFields = remapKeysForProduct(req.queryConfig.fields ?? []) - const { rows: products, metadata } = await refetchEntities( - "product", - req.filterableFields, - req.scope, - selectFields, - req.queryConfig.pagination, - req.queryConfig.withDeleted - ) + const { data: products, metadata } = await refetchEntities({ + entity: "product", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: selectFields, + pagination: req.queryConfig.pagination, + withDeleted: req.queryConfig.withDeleted, + }) res.json({ products: products.map(remapProductResponse), @@ -103,12 +103,12 @@ export const POST = async ( input: { products: [products], additional_data }, }) - const product = await refetchEntity( - "product", - result[0].id, - req.scope, - remapKeysForProduct(req.queryConfig.fields ?? []) - ) + const product = await refetchEntity({ + entity: "product", + idOrFilter: result[0].id, + scope: req.scope, + fields: remapKeysForProduct(req.queryConfig.fields ?? 
[]), + }) res.status(200).json({ product: remapProductResponse(product) }) } diff --git a/packages/medusa/src/api/admin/refund-reasons/[id]/route.ts b/packages/medusa/src/api/admin/refund-reasons/[id]/route.ts index 4f32700992..6629cbd4df 100644 --- a/packages/medusa/src/api/admin/refund-reasons/[id]/route.ts +++ b/packages/medusa/src/api/admin/refund-reasons/[id]/route.ts @@ -14,12 +14,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const refund_reason = await refetchEntity( - "refund_reason", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const refund_reason = await refetchEntity({ + entity: "refund_reason", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.json({ refund_reason }) } @@ -39,12 +39,12 @@ export const POST = async ( ], }) - const refund_reason = await refetchEntity( - "refund_reason", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const refund_reason = await refetchEntity({ + entity: "refund_reason", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.json({ refund_reason }) } diff --git a/packages/medusa/src/api/admin/refund-reasons/route.ts b/packages/medusa/src/api/admin/refund-reasons/route.ts index 2b296b918a..68e988841b 100644 --- a/packages/medusa/src/api/admin/refund-reasons/route.ts +++ b/packages/medusa/src/api/admin/refund-reasons/route.ts @@ -17,13 +17,13 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse> ) => { - const { rows: refund_reasons, metadata } = await refetchEntities( - "refund_reasons", - req.filterableFields, - req.scope, - req.queryConfig.fields, - req.queryConfig.pagination - ) + const { data: refund_reasons, metadata } = await refetchEntities({ + entity: "refund_reasons", + idOrFilter: req.filterableFields, + scope: req.scope, + fields: req.queryConfig.fields, + pagination: req.queryConfig.pagination, + }) res.json({ refund_reasons, @@ 
-43,12 +43,12 @@ export const POST = async ( input: { data: [req.validatedBody] }, }) - const refund_reason = await refetchEntity( - "refund_reason", - refundReason.id, - req.scope, - req.queryConfig.fields - ) + const refund_reason = await refetchEntity({ + entity: "refund_reason", + idOrFilter: refundReason.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) res.status(200).json({ refund_reason }) } diff --git a/packages/medusa/src/api/admin/return-reasons/[id]/route.ts b/packages/medusa/src/api/admin/return-reasons/[id]/route.ts index 893e55a310..816be19f3d 100644 --- a/packages/medusa/src/api/admin/return-reasons/[id]/route.ts +++ b/packages/medusa/src/api/admin/return-reasons/[id]/route.ts @@ -18,12 +18,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const return_reason = await refetchEntity( - "return_reason", - req.params.id, - req.scope, - req.queryConfig.fields - ) + const return_reason = await refetchEntity({ + entity: "return_reason", + idOrFilter: req.params.id, + scope: req.scope, + fields: req.queryConfig.fields, + }) if (!return_reason) { throw new MedusaError( diff --git a/packages/medusa/src/api/store/orders/helpers.ts b/packages/medusa/src/api/store/orders/helpers.ts index 0224cdce90..32740b2f36 100644 --- a/packages/medusa/src/api/store/orders/helpers.ts +++ b/packages/medusa/src/api/store/orders/helpers.ts @@ -6,5 +6,5 @@ export const refetchOrder = async ( scope: MedusaContainer, fields: string[] ) => { - return await refetchEntity("order", idOrFilter, scope, fields) + return await refetchEntity({ entity: "order", idOrFilter, scope, fields }) } diff --git a/packages/medusa/src/api/store/payment-collections/helpers.ts b/packages/medusa/src/api/store/payment-collections/helpers.ts index 3f5e7ce664..a5d6bf6387 100644 --- a/packages/medusa/src/api/store/payment-collections/helpers.ts +++ b/packages/medusa/src/api/store/payment-collections/helpers.ts @@ -9,5 +9,10 @@ export const 
refetchPaymentCollection = async ( scope: MedusaContainer, fields: string[] ): Promise => { - return refetchEntity("payment_collection", id, scope, fields) + return refetchEntity({ + entity: "payment_collection", + idOrFilter: id, + scope, + fields, + }) } diff --git a/packages/medusa/src/api/store/product-categories/[id]/route.ts b/packages/medusa/src/api/store/product-categories/[id]/route.ts index c08cbd9e67..c8ca9f674a 100644 --- a/packages/medusa/src/api/store/product-categories/[id]/route.ts +++ b/packages/medusa/src/api/store/product-categories/[id]/route.ts @@ -11,12 +11,12 @@ export const GET = async ( req: AuthenticatedMedusaRequest, res: MedusaResponse ) => { - const category = await refetchEntity( - "product_category", - { id: req.params.id, ...req.filterableFields }, - req.scope, - req.queryConfig.fields - ) + const category = await refetchEntity({ + entity: "product_category", + idOrFilter: { id: req.params.id, ...req.filterableFields }, + scope: req.scope, + fields: req.queryConfig.fields, + }) if (!category) { throw new MedusaError( diff --git a/packages/medusa/src/api/store/products/[id]/route.ts b/packages/medusa/src/api/store/products/[id]/route.ts index 4cf22bc7b2..38c9e9caee 100644 --- a/packages/medusa/src/api/store/products/[id]/route.ts +++ b/packages/medusa/src/api/store/products/[id]/route.ts @@ -1,5 +1,6 @@ -import { isPresent, MedusaError } from "@medusajs/framework/utils" import { MedusaResponse } from "@medusajs/framework/http" +import { HttpTypes } from "@medusajs/framework/types" +import { isPresent, MedusaError, QueryContext } from "@medusajs/framework/utils" import { wrapVariantsWithInventoryQuantityForSalesChannel } from "../../../utils/middlewares" import { filterOutInternalProductCategories, @@ -7,7 +8,6 @@ import { RequestWithContext, wrapProductsWithTaxPrices, } from "../helpers" -import { HttpTypes } from "@medusajs/framework/types" export const GET = async ( req: RequestWithContext, @@ -29,9 +29,11 @@ export const GET = 
async ( } if (isPresent(req.pricingContext)) { - filters["context"] = { - "variants.calculated_price": { context: req.pricingContext }, - } + filters["context"] ??= {} + filters["context"]["variants"] ??= {} + filters["context"]["variants"]["calculated_price"] ??= QueryContext( + req.pricingContext! + ) } const includesCategoriesField = req.queryConfig.fields.some((field) => diff --git a/packages/medusa/src/api/store/products/helpers.ts b/packages/medusa/src/api/store/products/helpers.ts index c2f2dbd9f8..f66f529ae2 100644 --- a/packages/medusa/src/api/store/products/helpers.ts +++ b/packages/medusa/src/api/store/products/helpers.ts @@ -25,7 +25,7 @@ export const refetchProduct = async ( scope: MedusaContainer, fields: string[] ) => { - return await refetchEntity("product", idOrFilter, scope, fields) + return await refetchEntity({ entity: "product", idOrFilter, scope, fields }) } export const filterOutInternalProductCategories = ( diff --git a/packages/medusa/src/api/store/products/route.ts b/packages/medusa/src/api/store/products/route.ts index d99d18f4cb..edf17588de 100644 --- a/packages/medusa/src/api/store/products/route.ts +++ b/packages/medusa/src/api/store/products/route.ts @@ -5,7 +5,6 @@ import { FeatureFlag, isPresent, QueryContext, - remoteQueryObjectFromString, } from "@medusajs/framework/utils" import IndexEngineFeatureFlag from "../../../feature-flags/index-engine" import { wrapVariantsWithInventoryQuantityForSalesChannel } from "../../utils/middlewares" @@ -49,7 +48,9 @@ async function getProductsWithIndexEngine( if (isPresent(req.pricingContext)) { context["variants"] ??= {} - context["variants"]["calculated_price"] = QueryContext(req.pricingContext!) + context["variants"]["calculated_price"] ??= QueryContext( + req.pricingContext! 
+ ) } const filters: Record = req.filterableFields @@ -62,13 +63,20 @@ async function getProductsWithIndexEngine( delete filters.sales_channel_id } - const { data: products = [], metadata } = await query.index({ - entity: "product", - fields: req.queryConfig.fields, - filters, - pagination: req.queryConfig.pagination, - context, - }) + const { data: products = [], metadata } = await query.index( + { + entity: "product", + fields: req.queryConfig.fields, + filters, + pagination: req.queryConfig.pagination, + context, + }, + { + cache: { + enable: true, + }, + } + ) if (withInventoryQuantity) { await wrapVariantsWithInventoryQuantityForSalesChannel( @@ -91,7 +99,7 @@ async function getProducts( req: RequestWithContext, res: MedusaResponse ) { - const remoteQuery = req.scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY) + const query = req.scope.resolve(ContainerRegistrationKeys.QUERY) const context: object = {} const withInventoryQuantity = req.queryConfig.fields.some((field) => field.includes("variants.inventory_quantity") @@ -104,22 +112,26 @@ async function getProducts( } if (isPresent(req.pricingContext)) { - context["variants.calculated_price"] = { - context: req.pricingContext, - } + context["variants"] ??= {} + context["variants"]["calculated_price"] ??= QueryContext( + req.pricingContext! 
+ ) } - const queryObject = remoteQueryObjectFromString({ - entryPoint: "product", - variables: { + const { data: products = [], metadata } = await query.graph( + { + entity: "product", + fields: req.queryConfig.fields, filters: req.filterableFields, - ...req.queryConfig.pagination, - ...context, + pagination: req.queryConfig.pagination, + context, }, - fields: req.queryConfig.fields, - }) - - const { rows: products, metadata } = await remoteQuery(queryObject) + { + cache: { + enable: true, + }, + } + ) if (withInventoryQuantity) { await wrapVariantsWithInventoryQuantityForSalesChannel( @@ -131,8 +143,8 @@ async function getProducts( await wrapProductsWithTaxPrices(req, products) res.json({ products, - count: metadata.count, - offset: metadata.skip, - limit: metadata.take, + count: metadata!.count, + offset: metadata!.skip, + limit: metadata!.take, }) } diff --git a/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts b/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts index 9d583d1745..18e1f2839e 100644 --- a/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts +++ b/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts @@ -39,12 +39,12 @@ export function normalizeDataForContext() { // If the cart is passed, get the information from it if (req.filterableFields.cart_id) { - const cart = await refetchEntity( - "cart", - req.filterableFields.cart_id, - req.scope, - ["region_id", "shipping_address.*"] - ) + const cart = await refetchEntity({ + entity: "cart", + idOrFilter: req.filterableFields.cart_id, + scope: req.scope, + fields: ["region_id", "shipping_address.*"], + }) if (cart?.region_id) { regionId = cart.region_id @@ -58,9 +58,16 @@ export function normalizeDataForContext() { // Finally, try to get it from the store defaults if not available if (!regionId) { - const stores = await refetchEntities("store", {}, req.scope, [ - "default_region_id", - ]) 
+ const stores = await refetchEntities({ + entity: "store", + scope: req.scope, + fields: ["id", "default_region_id"], + options: { + cache: { + enable: true, + }, + }, + }) regionId = stores[0]?.default_region_id } diff --git a/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts b/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts index 98b1a12bb5..1ee14c3c7d 100644 --- a/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts +++ b/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts @@ -17,12 +17,17 @@ export function setPricingContext() { } // We validate the region ID in the previous middleware - const region = await refetchEntity( - "region", - req.filterableFields.region_id!, - req.scope, - ["id", "currency_code"] - ) + const region = await refetchEntity({ + entity: "region", + idOrFilter: req.filterableFields.region_id!, + scope: req.scope, + fields: ["id", "currency_code"], + options: { + cache: { + enable: true, + }, + }, + }) if (!region) { try { @@ -42,12 +47,12 @@ export function setPricingContext() { // Find all the customer groups the customer is a part of and set if (req.auth_context?.actor_id) { - const customerGroups = await refetchEntities( - "customer_group", - { customers: { id: req.auth_context.actor_id } }, - req.scope, - ["id"] - ) + const { data: customerGroups } = await refetchEntities({ + entity: "customer_group", + idOrFilter: { customers: { id: req.auth_context.actor_id } }, + scope: req.scope, + fields: ["id"], + }) pricingContext.customer = { groups: [] } customerGroups.map((cg) => diff --git a/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts b/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts index 3c76c8f4c3..693ee54339 100644 --- a/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts +++ b/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts @@ -38,12 +38,12 @@ export 
function setTaxContext() { } const getTaxInclusivityInfo = async (req: MedusaRequest) => { - const region = await refetchEntity( - "region", - req.filterableFields.region_id as string, - req.scope, - ["automatic_taxes"] - ) + const region = await refetchEntity({ + entity: "region", + idOrFilter: req.filterableFields.region_id as string, + scope: req.scope, + fields: ["automatic_taxes"], + }) if (!region) { throw new MedusaError( diff --git a/packages/medusa/src/feature-flags/caching.ts b/packages/medusa/src/feature-flags/caching.ts new file mode 100644 index 0000000000..04eb21182a --- /dev/null +++ b/packages/medusa/src/feature-flags/caching.ts @@ -0,0 +1,10 @@ +import { FlagSettings } from "@medusajs/framework/feature-flags" + +const CachingFeatureFlag: FlagSettings = { + key: "caching", + default_val: false, + env_key: "MEDUSA_FF_CACHING", + description: "[WIP] Enable core caching where applicable", +} + +export default CachingFeatureFlag diff --git a/packages/medusa/src/instrumentation/index.ts b/packages/medusa/src/instrumentation/index.ts index 04e2bebd01..92c6e955f4 100644 --- a/packages/medusa/src/instrumentation/index.ts +++ b/packages/medusa/src/instrumentation/index.ts @@ -1,4 +1,3 @@ -import { snakeCase } from "lodash" import { MedusaNextFunction, MedusaRequest, @@ -6,10 +5,15 @@ import { Query, } from "@medusajs/framework" import { ApiLoader } from "@medusajs/framework/http" -import { Tracer } from "@medusajs/framework/telemetry" -import type { SpanExporter } from "@opentelemetry/sdk-trace-node" -import type { NodeSDKConfiguration } from "@opentelemetry/sdk-node" import { TransactionOrchestrator } from "@medusajs/framework/orchestration" +import { Tracer } from "@medusajs/framework/telemetry" +import { FeatureFlag } from "@medusajs/framework/utils" +import { SpanStatusCode } from "@opentelemetry/api" +import type { NodeSDKConfiguration } from "@opentelemetry/sdk-node" +import type { SpanExporter } from "@opentelemetry/sdk-trace-node" +import { snakeCase 
} from "lodash" +import CacheModule from "../modules/caching" +import { ICachingModuleService } from "@medusajs/framework/types" const EXCLUDED_RESOURCES = [".vite", "virtual:"] @@ -261,6 +265,110 @@ export function instrumentWorkflows() { } } +export function instrumentCache() { + if (!FeatureFlag.isFeatureEnabled("caching")) { + return + } + + const CacheTracer = new Tracer("@medusajs/caching", "2.0.0") + const cacheModule_ = CacheModule as unknown as { + service: ICachingModuleService & { + traceGet: ( + cacheGetFn: () => Promise, + key: string, + tags: string[] + ) => Promise + traceSet: ( + cacheSetFn: () => Promise, + key: string, + tags: string[], + options: { autoInvalidate?: boolean } + ) => Promise + traceClear: ( + cacheClearFn: () => Promise, + key: string, + tags: string[], + options: { autoInvalidate?: boolean } + ) => Promise + } + } + + cacheModule_.service.traceGet = async function (cacheGetFn, key, tags) { + return await CacheTracer.trace(`cache.get`, async (span) => { + span.setAttributes({ + "cache.key": key, + "cache.tags": tags, + }) + + try { + return await cacheGetFn() + } catch (error) { + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }) + throw error + } finally { + span.end() + } + }) + } + + cacheModule_.service.traceSet = async function ( + cacheSetFn, + key, + tags, + options = {} + ) { + return await CacheTracer.trace(`cache.set`, async (span) => { + span.setAttributes({ + "cache.key": key, + "cache.tags": tags, + "cache.options": JSON.stringify(options), + }) + + try { + return await cacheSetFn() + } catch (error) { + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }) + throw error + } finally { + span.end() + } + }) + } + + cacheModule_.service.traceClear = async function ( + cacheClearFn, + key, + tags, + options = {} + ) { + return await CacheTracer.trace(`cache.clear`, async (span) => { + span.setAttributes({ + "cache.key": key, + "cache.tags": tags, + "cache.options": 
JSON.stringify(options), + }) + + try { + return await cacheClearFn() + } catch (error) { + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error.message, + }) + throw error + } finally { + span.end() + } + }) + } +} + /** * A helper function to configure the OpenTelemetry SDK with some defaults. * For better/more control, please configure the SDK manually. @@ -283,6 +391,7 @@ export function registerOtel( query: boolean workflows: boolean db: boolean + cache: boolean }> } ) { @@ -318,6 +427,9 @@ export function registerOtel( if (instrument.workflows) { instrumentWorkflows() } + if (instrument.cache) { + instrumentCache() + } const sdk = new NodeSDK({ serviceName, diff --git a/packages/medusa/src/modules/caching-redis.ts b/packages/medusa/src/modules/caching-redis.ts new file mode 100644 index 0000000000..8b6f9334a8 --- /dev/null +++ b/packages/medusa/src/modules/caching-redis.ts @@ -0,0 +1,6 @@ +import RedisCachingProvider from "@medusajs/caching-redis" + +export * from "@medusajs/caching-redis" + +export default RedisCachingProvider +export const discoveryPath = require.resolve("@medusajs/caching-redis") diff --git a/packages/medusa/src/modules/caching.ts b/packages/medusa/src/modules/caching.ts new file mode 100644 index 0000000000..06c7ec0084 --- /dev/null +++ b/packages/medusa/src/modules/caching.ts @@ -0,0 +1,6 @@ +import CacheModule from "@medusajs/caching" + +export * from "@medusajs/caching" + +export default CacheModule +export const discoveryPath = require.resolve("@medusajs/caching") diff --git a/packages/modules/caching/.gitignore b/packages/modules/caching/.gitignore new file mode 100644 index 0000000000..874c6c69d3 --- /dev/null +++ b/packages/modules/caching/.gitignore @@ -0,0 +1,6 @@ +/dist +node_modules +.DS_store +.env* +.env +*.sql diff --git a/packages/modules/caching/CHANGELOG.md b/packages/modules/caching/CHANGELOG.md new file mode 100644 index 0000000000..1396da4e39 --- /dev/null +++ b/packages/modules/caching/CHANGELOG.md @@ -0,0 
+1 @@ +# @medusajs/caching diff --git a/packages/modules/caching/integration-tests/__fixtures__/event-bus-mock.ts b/packages/modules/caching/integration-tests/__fixtures__/event-bus-mock.ts new file mode 100644 index 0000000000..da364786a4 --- /dev/null +++ b/packages/modules/caching/integration-tests/__fixtures__/event-bus-mock.ts @@ -0,0 +1,51 @@ +import { + EventBusTypes, + IEventBusModuleService, + Message, + Subscriber, +} from "@medusajs/types" + +export class EventBusServiceMock implements IEventBusModuleService { + protected readonly subscribers_: Map> = + new Map() + + async emit( + messages: Message | Message[], + options?: Record + ): Promise { + const messages_ = Array.isArray(messages) ? messages : [messages] + + for (const message of messages_) { + const subscribers = this.subscribers_.get(message.name) + const starSubscribers = this.subscribers_.get("*") + + for (const subscriber of [ + ...(subscribers ?? []), + ...(starSubscribers ?? []), + ]) { + const { options, ...payload } = message + await subscriber(payload) + } + } + } + + subscribe(event: string | symbol, subscriber: Subscriber): this { + this.subscribers_.set(event, new Set([subscriber])) + return this + } + + unsubscribe( + event: string | symbol, + subscriber: Subscriber, + context?: EventBusTypes.SubscriberContext + ): this { + return this + } + + releaseGroupedEvents(eventGroupId: string): Promise { + throw new Error("Method not implemented.") + } + clearGroupedEvents(eventGroupId: string): Promise { + throw new Error("Method not implemented.") + } +} diff --git a/packages/modules/caching/integration-tests/__tests__/index.spec.ts b/packages/modules/caching/integration-tests/__tests__/index.spec.ts new file mode 100644 index 0000000000..898c055255 --- /dev/null +++ b/packages/modules/caching/integration-tests/__tests__/index.spec.ts @@ -0,0 +1,336 @@ +import { Modules } from "@medusajs/framework/utils" +import { moduleIntegrationTestRunner } from "@medusajs/test-utils" +import { 
ICachingModuleService } from "@medusajs/framework/types" +import { MedusaModule } from "@medusajs/framework/modules-sdk" + +jest.setTimeout(10000) + +jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([ + { + schema: ` + type Product { + id: ID + title: String + handle: String + status: String + type_id: String + collection_id: String + is_giftcard: Boolean + external_id: String + created_at: DateTime + updated_at: DateTime + + variants: [ProductVariant] + sales_channels: [SalesChannel] + } + + type ProductVariant { + id: ID + product_id: String + sku: String + + prices: [Price] + } + + type Price { + id: ID + amount: Float + currency_code: String + } + + type SalesChannel { + id: ID + is_disabled: Boolean + } +`, + }, +]) + +moduleIntegrationTestRunner({ + moduleName: Modules.CACHING, + testSuite: ({ service }) => { + describe("Caching Module Service", () => { + beforeEach(async () => { + await service.clear({ tags: ["*"] }).catch(() => {}) + }) + + describe("Basic Cache Operations", () => { + it("should set and get cache data with default memory provider", async () => { + const testData = { id: "test-id", name: "Test Item" } + + await service.set({ + key: "test-key", + data: testData, + ttl: 3600, + }) + + const result = await service.get({ key: "test-key" }) + expect(result).toEqual(testData) + }) + + it("should return null for non-existent keys", async () => { + const result = await service.get({ key: "non-existent" }) + expect(result).toBeNull() + }) + + it("should handle tags-based storage and retrieval", async () => { + const testData1 = { id: "1", name: "Item 1" } + const testData2 = { id: "2", name: "Item 2" } + + await service.set({ + key: "item-1", + data: testData1, + tags: ["product", "active"], + }) + + await service.set({ + key: "item-2", + data: testData2, + tags: ["product", "inactive"], + }) + + const productResults = await service.get({ tags: ["product"] }) + expect(productResults).toHaveLength(2) + 
expect(productResults).toContainEqual(testData1) + expect(productResults).toContainEqual(testData2) + + const activeResults = await service.get({ tags: ["active"] }) + expect(activeResults).toHaveLength(1) + expect(activeResults?.[0]).toEqual(testData1) + }) + + it("should clear cache by key", async () => { + await service.set({ + key: "test-key", + data: { value: "test" }, + }) + + await service.clear({ key: "test-key" }) + + const result = await service.get({ key: "test-key" }) + expect(result).toBeNull() + }) + + it("should clear cache by tags", async () => { + await service.set({ + key: "item-1", + data: { id: "1" }, + tags: ["category-a"], + }) + + await service.set({ + key: "item-2", + data: { id: "2" }, + tags: ["category-b"], + }) + + await service.clear({ tags: ["category-a"] }) + + const result1 = await service.get({ key: "item-1" }) + const result2 = await service.get({ key: "item-2" }) + + expect(result1).toBeNull() + expect(result2).toEqual({ id: "2" }) + }) + }) + + describe("Provider Priority", () => { + it("should check providers in order of priority when specified", async () => { + const testData = { id: "priority-test", name: "Priority Test" } + + await service.set({ + key: "priority-key", + data: testData, + providers: ["cache-memory"], + }) + + const result = await service.get({ + key: "priority-key", + providers: ["cache-memory"], + }) + + expect(result).toEqual(testData) + }) + + it("should return null when providers array is empty or invalid", async () => { + const result = await service.get({ + key: "test-key", + providers: [], + }) + + expect(result).toBeNull() + }) + }) + + describe("Promise Deduplication", () => { + it("should deduplicate concurrent get requests with same parameters", async () => { + const testData = { id: "concurrent-test", name: "Concurrent Test" } + + await service.set({ + key: "concurrent-key", + data: testData, + }) + + const promises = Array.from({ length: 5 }, () => + service.get({ key: "concurrent-key" }) + ) + + 
const results = await Promise.all(promises) + + results.forEach((result) => { + expect(result).toEqual(testData) + }) + }) + + it("should deduplicate concurrent get requests with same tags", async () => { + const testData = { id: "tag-test", name: "Tag Test" } + + await service.set({ + key: "tag-key", + data: testData, + tags: ["concurrent-tag"], + }) + + const promises = Array.from({ length: 5 }, () => + service.get({ tags: ["concurrent-tag"] }) + ) + + const results = await Promise.all(promises) + + results.forEach((result) => { + expect(result).toHaveLength(1) + expect(result?.[0]).toEqual(testData) + }) + }) + + it("should deduplicate concurrent clear requests", async () => { + await service.set({ + key: "clear-test-1", + data: { id: "1" }, + tags: ["clear-tag"], + }) + + await service.set({ + key: "clear-test-2", + data: { id: "2" }, + tags: ["clear-tag"], + }) + + const promises = Array.from({ length: 3 }, () => + service.clear({ tags: ["clear-tag"] }) + ) + + await Promise.all(promises) + + const result1 = await service.get({ key: "clear-test-1" }) + const result2 = await service.get({ key: "clear-test-2" }) + + expect(result1).toBeNull() + expect(result2).toBeNull() + }) + + it("should handle concurrent requests with different parameters separately", async () => { + const testData1 = { id: "1", name: "Item 1" } + const testData2 = { id: "2", name: "Item 2" } + + await service.set({ key: "key-1", data: testData1 }) + await service.set({ key: "key-2", data: testData2 }) + + const promises = [ + service.get({ key: "key-1" }), + service.get({ key: "key-1" }), + service.get({ key: "key-2" }), + service.get({ key: "key-2" }), + ] + + const results = await Promise.all(promises) + + expect(results[0]).toEqual(testData1) + expect(results[1]).toEqual(testData1) + expect(results[2]).toEqual(testData2) + expect(results[3]).toEqual(testData2) + }) + }) + + describe("Memory Cache Provider Integration", () => { + it("should respect TTL settings", async () => { + const 
testData = { id: "ttl-test", name: "TTL Test" } + + await service.set({ + key: "ttl-key", + data: testData, + ttl: 1, + }) + + let result = await service.get({ key: "ttl-key" }) + expect(result).toEqual(testData) + + await new Promise((resolve) => setTimeout(resolve, 1100)) + + result = await service.get({ key: "ttl-key" }) + expect(result).toBeNull() + }) + + it("should handle autoInvalidate option", async () => { + const testData = { id: "no-auto-test", name: "No Auto Test" } + + await service.set({ + key: "no-auto-key", + data: testData, + tags: ["no-auto-tag"], + options: { autoInvalidate: false }, + }) + + await service.clear({ + tags: ["no-auto-tag"], + options: { autoInvalidate: true }, + }) + + const result = await service.get({ key: "no-auto-key" }) + expect(result).toEqual(testData) + + await service.clear({ + tags: ["no-auto-tag"], + }) + + const result2 = await service.get({ key: "no-auto-key" }) + expect(result2).toBeNull() + }) + + it("should generate consistent cache keys", async () => { + const testInput = { userId: "123", action: "view" } + + const key1 = await service.computeKey(testInput) + const key2 = await service.computeKey(testInput) + + expect(key1).toBe(key2) + expect(typeof key1).toBe("string") + expect(key1.length).toBeGreaterThan(0) + }) + + it("should generate cache tags", async () => { + const testInput = { id: "prod_1", title: "123", description: "456" } + + const tags = await service.computeTags(testInput) + + expect(Array.isArray(tags)).toBe(true) + expect(tags.length).toBeGreaterThan(0) + }) + }) + + describe("Error Handling", () => { + it("should throw error when neither key nor tags provided to get", async () => { + await expect(service.get({})).rejects.toThrow( + "Either key or tags must be provided" + ) + }) + + it("should throw error when neither key nor tags provided to clear", async () => { + await expect(service.clear({})).rejects.toThrow( + "Either key or tags must be provided" + ) + }) + }) + }) + }, +}) diff --git 
a/packages/modules/caching/integration-tests/__tests__/invalidation.spec.ts b/packages/modules/caching/integration-tests/__tests__/invalidation.spec.ts new file mode 100644 index 0000000000..1354dc98d1 --- /dev/null +++ b/packages/modules/caching/integration-tests/__tests__/invalidation.spec.ts @@ -0,0 +1,430 @@ +import { Modules } from "@medusajs/framework/utils" +import { moduleIntegrationTestRunner } from "@medusajs/test-utils" +import { ICachingModuleService } from "@medusajs/framework/types" +import { MedusaModule } from "@medusajs/framework/modules-sdk" +import { EventBusServiceMock } from "../__fixtures__/event-bus-mock" + +jest.setTimeout(30000) + +jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([ + { + schema: ` + type Product { + id: ID + title: String + handle: String + status: String + type_id: String + collection_id: String + is_giftcard: Boolean + external_id: String + created_at: DateTime + updated_at: DateTime + + variants: [ProductVariant] + sales_channels: [SalesChannel] + } + + type ProductVariant { + id: ID + product_id: String + sku: String + + prices: [Price] + } + + type Price { + id: ID + amount: Float + currency_code: String + variant_id: String + } + + type SalesChannel { + id: ID + is_disabled: Boolean + } + + type ProductCollection { + id: ID + title: String + handle: String + } +`, + }, +]) + +const mockEventBus = new EventBusServiceMock() + +moduleIntegrationTestRunner({ + moduleName: Modules.CACHING, + injectedDependencies: { + [Modules.EVENT_BUS]: mockEventBus, + }, + testSuite: ({ service }) => { + describe("Cache Invalidation with Entity Relationships", () => { + afterEach(async () => { + await service.clear({ tags: ["*"] }).catch(() => {}) + }) + + describe("Single Entity Caching", () => { + it("should cache and retrieve a single product entity using computed keys", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + status: "published", + created_at: new 
Date().toISOString(), + updated_at: new Date().toISOString(), + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + }) + + const cachedProduct = await service.get({ key: productKey }) + expect(cachedProduct).toEqual(product) + }) + + it("should auto-invalidate single entity when strategy clears computed tags", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + }) + + await mockEventBus.emit( + [{ name: "product.updated", data: { id: product.id } }], + {} + ) + + const result = await service.get({ key: productKey }) + expect(result).toBeNull() + }) + + it("should not auto-invalidate single entity with autoInvalidate=false", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + options: { autoInvalidate: false }, + }) + + await mockEventBus.emit( + [{ name: "product.updated", data: { id: product.id } }], + {} + ) + + const result = await service.get({ key: productKey }) + expect(result).toEqual(product) + }) + }) + + describe("Entity List Caching", () => { + it("should cache and retrieve lists of entities using computed keys", async () => { + const publishedProductsQuery = { + entity: "product", + filters: { status: "published" }, + fields: ["id", "title", "status"], + } + + const allProductsQuery = { + entity: "product", + filters: {}, + fields: ["id", "title", "status"], + } + + const publishedProducts = [ + { id: "prod_1", title: "Product 1", status: "published" }, + { id: "prod_2", title: "Product 2", status: "published" }, + ] + + const allProducts = [ + ...publishedProducts, + { id: "prod_3", title: "Product 3", status: "draft" }, + ] + + const 
publishedProductsKey = await service.computeKey( + publishedProductsQuery + ) + const allProductsKey = await service.computeKey(allProductsQuery) + + await service.set({ + key: publishedProductsKey, + data: publishedProducts, + }) + + await service.set({ + key: allProductsKey, + data: allProducts, + }) + + const cachedPublished = await service.get({ + key: publishedProductsKey, + }) + const cachedAll = await service.get({ key: allProductsKey }) + + expect(cachedPublished).toEqual(publishedProducts) + expect(cachedAll).toEqual(allProducts) + }) + + it("should invalidate related lists when individual product is updated", async () => { + const listQuery = { + entity: "product", + filters: { status: "published" }, + includes: ["id", "title"], + } + + const products = [ + { id: "prod_1", title: "Product 1", status: "published" }, + { id: "prod_2", title: "Product 2", status: "published" }, + ] + + const listKey = await service.computeKey(listQuery) + + await service.set({ + key: listKey, + data: products, + }) + + await mockEventBus.emit( + [ + { + name: "product.updated", + data: { id: "prod_1", title: "Updated Product 1" }, + }, + ], + {} + ) + + const cachedList = await service.get({ key: listKey }) + + expect(cachedList).toBeNull() + }) + }) + + describe("Nested Entity Caching", () => { + it("should cache products with nested variants and prices using computed keys", async () => { + const productWithVariants = { + id: "prod_1", + title: "Complex Product", + variants: [ + { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001", + prices: [ + { + id: "price_1", + variant_id: "var_1", + amount: 1000, + currency_code: "USD", + }, + { + id: "price_2", + variant_id: "var_1", + amount: 900, + currency_code: "EUR", + }, + ], + }, + { + id: "var_2", + product_id: "prod_1", + sku: "SKU-002", + prices: [ + { + id: "price_3", + variant_id: "var_2", + amount: 1200, + currency_code: "USD", + }, + ], + }, + ], + } + + const productKey = await 
service.computeKey(productWithVariants) + + await service.set({ + key: productKey, + data: productWithVariants, + }) + + const cached = await service.get({ + key: productKey, + }) + expect(cached).toEqual(productWithVariants) + expect(cached!.variants).toHaveLength(2) + expect(cached!.variants[0].prices).toHaveLength(2) + }) + + it("should invalidate nested product when related variant is updated", async () => { + const productWithVariants = { + id: "prod_1", + title: "Complex Product", + variants: [{ id: "var_1", product_id: "prod_1", sku: "SKU-001" }], + } + + const productKey = await service.computeKey(productWithVariants) + + await service.set({ + key: productKey, + data: productWithVariants, + }) + + await mockEventBus.emit( + [ + { + name: "product_variant.updated", + data: { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001-UPDATED", + }, + }, + ], + {} + ) + + const result = await service.get({ key: productKey }) + expect(result).toBeNull() + }) + + it("should handle price updates affecting variant and product caches", async () => { + const price = { + id: "price_1", + variant_id: "var_1", + amount: 1000, + currency_code: "USD", + } + + const variant = { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001", + prices: [price], + } + + const product = { + id: "prod_1", + title: "Product", + variants: [variant], + } + + const priceKey = await service.computeKey(price) + const variantKey = await service.computeKey(variant) + const productKey = await service.computeKey(product) + + await service.set({ + key: priceKey, + data: price, + }) + await service.set({ + key: variantKey, + data: variant, + }) + await service.set({ + key: productKey, + data: product, + }) + + await mockEventBus.emit( + [ + { + name: "price.updated", + data: { id: "price_1", variant_id: "var_1", amount: 1100 }, + }, + ], + {} + ) + + const cachedPrice = await service.get({ key: priceKey }) + const cachedVariant = await service.get({ key: variantKey }) + const cachedProduct = await 
service.get({ key: productKey }) + + expect(cachedPrice).toBeNull() + expect(cachedVariant).toBeNull() + expect(cachedProduct).toBeNull() + }) + }) + + describe("Complex Query Caching", () => { + it("should cache complex queries and invalidate based on entity relationships", async () => { + const complexQuery = { + entity: "product", + filters: { status: "published", collection_id: "col_1" }, + includes: { + variants: { + include: { + prices: true, + }, + }, + collection: true, + }, + pagination: { limit: 10, offset: 0 }, + } + + const queryResult = { + data: [ + { + id: "prod_1", + title: "Product 1", + collection_id: "col_1", + variants: [ + { + id: "var_1", + product_id: "prod_1", + prices: [{ id: "price_1", amount: 1000 }], + }, + ], + collection: { id: "col_1", title: "Collection 1" }, + }, + ], + pagination: { total: 1, limit: 10, offset: 0 }, + } + + const queryKey = await service.computeKey(complexQuery) + + await service.set({ + key: queryKey, + data: queryResult, + }) + + const cached = await service.get({ key: queryKey }) + expect(cached).toEqual(queryResult) + + await mockEventBus.emit( + [ + { + name: "price.updated", + data: { id: "price_1", amount: 1100 }, + }, + ], + {} + ) + + const cachedAfterUpdate = await service.get({ key: queryKey }) + expect(cachedAfterUpdate).toBeNull() + }) + }) + }) + }, +}) diff --git a/packages/modules/caching/integration-tests/__tests__/redis/invalidation.spec.ts b/packages/modules/caching/integration-tests/__tests__/redis/invalidation.spec.ts new file mode 100644 index 0000000000..1147ca5b27 --- /dev/null +++ b/packages/modules/caching/integration-tests/__tests__/redis/invalidation.spec.ts @@ -0,0 +1,540 @@ +import { MedusaModule } from "@medusajs/framework/modules-sdk" +import { ICachingModuleService } from "@medusajs/framework/types" +import { Modules } from "@medusajs/framework/utils" +import { moduleIntegrationTestRunner } from "@medusajs/test-utils" +import { setTimeout } from "timers/promises" +import { 
EventBusServiceMock } from "../../__fixtures__/event-bus-mock" + +jest.setTimeout(300000) + +jest.spyOn(MedusaModule, "getAllJoinerConfigs").mockReturnValue([ + { + schema: ` + type Product { + id: ID + title: String + handle: String + status: String + type_id: String + collection_id: String + is_giftcard: Boolean + external_id: String + created_at: DateTime + updated_at: DateTime + + variants: [ProductVariant] + sales_channels: [SalesChannel] + } + + type ProductVariant { + id: ID + product_id: String + sku: String + + prices: [Price] + } + + type Price { + id: ID + amount: Float + currency_code: String + variant_id: String + } + + type SalesChannel { + id: ID + is_disabled: Boolean + } + + type ProductCollection { + id: ID + title: String + handle: String + } +`, + }, +]) + +const DEFAULT_WAIT_INTERVAL = 50 +const mockEventBus = new EventBusServiceMock() + +moduleIntegrationTestRunner({ + moduleName: Modules.CACHING, + injectedDependencies: { + [Modules.EVENT_BUS]: mockEventBus, + }, + moduleOptions: { + providers: [ + { + id: "cache-redis", + resolve: require.resolve("../../../../providers/caching-redis/src"), + is_default: true, + options: { + redisUrl: "localhost:6379", + }, + }, + ], + }, + testSuite: ({ service }) => { + describe("Cache Invalidation with Entity Relationships", () => { + afterEach(async () => { + await service.clear({ tags: ["*"] }).catch(() => {}) + }) + + describe("Single Entity Caching", () => { + it("should cache and retrieve a single product entity using computed keys", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + status: "published", + created_at: new Date().toISOString(), + updated_at: new Date().toISOString(), + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cachedProduct = await service.get({ key: productKey }) + 
expect(cachedProduct).toEqual(product) + }) + + it("should auto-invalidate single entity when strategy clears computed tags", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + await mockEventBus.emit( + [{ name: "product.updated", data: { id: product.id } }], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const result = await service.get({ key: productKey }) + expect(result).toBeNull() + }) + + it("should not auto-invalidate single entity with autoInvalidate=false", async () => { + const product = { + id: "prod_1", + title: "Test Product", + handle: "test-product", + } + + const productKey = await service.computeKey(product) + + await service.set({ + key: productKey, + data: product, + options: { autoInvalidate: false }, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + await mockEventBus.emit( + [{ name: "product.updated", data: { id: product.id } }], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const result = await service.get({ key: productKey }) + expect(result).toEqual(product) + }) + }) + + describe("Entity List Caching", () => { + it("should cache and retrieve lists of entities using computed keys", async () => { + const publishedProductsQuery = { + entity: "product", + filters: { status: "published" }, + fields: ["id", "title", "status"], + } + + const allProductsQuery = { + entity: "product", + filters: {}, + fields: ["id", "title", "status"], + } + + const publishedProducts = [ + { id: "prod_1", title: "Product 1", status: "published" }, + { id: "prod_2", title: "Product 2", status: "published" }, + ] + + const allProducts = [ + ...publishedProducts, + { id: "prod_3", title: "Product 3", status: "draft" }, + ] + + const publishedProductsKey = await service.computeKey( + publishedProductsQuery + ) + const 
allProductsKey = await service.computeKey(allProductsQuery) + + await service.set({ + key: publishedProductsKey, + data: publishedProducts, + }) + + await service.set({ + key: allProductsKey, + data: allProducts, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cachedPublished = await service.get({ + key: publishedProductsKey, + }) + const cachedAll = await service.get({ key: allProductsKey }) + + expect(cachedPublished).toEqual(publishedProducts) + expect(cachedAll).toEqual(allProducts) + }) + + it("should invalidate related lists when individual product is updated", async () => { + const listQuery = { + entity: "product", + filters: { status: "published" }, + includes: ["id", "title"], + } + + const products = [ + { id: "prod_1", title: "Product 1", status: "published" }, + { id: "prod_2", title: "Product 2", status: "published" }, + ] + + const listKey = await service.computeKey(listQuery) + + await service.set({ + key: listKey, + data: products, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + await mockEventBus.emit( + [ + { + name: "product.updated", + data: { id: "prod_1", title: "Updated Product 1" }, + }, + ], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cachedList = await service.get({ key: listKey }) + + expect(cachedList).toBeNull() + }) + }) + + describe("Nested Entity Caching", () => { + it("should cache products with nested variants and prices using computed keys", async () => { + const productWithVariants = { + id: "prod_1", + title: "Complex Product", + variants: [ + { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001", + prices: [ + { + id: "price_1", + variant_id: "var_1", + amount: 1000, + currency_code: "USD", + }, + { + id: "price_2", + variant_id: "var_1", + amount: 900, + currency_code: "EUR", + }, + ], + }, + { + id: "var_2", + product_id: "prod_1", + sku: "SKU-002", + prices: [ + { + id: "price_3", + variant_id: "var_2", + amount: 1500, + currency_code: "USD", + }, + ], + }, + ], + } + + const 
productKey = await service.computeKey(productWithVariants) + + await service.set({ + key: productKey, + data: productWithVariants, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cached = await service.get({ + key: productKey, + }) + expect(cached).toEqual(productWithVariants) + expect(cached!.variants).toHaveLength(2) + expect(cached!.variants[0].prices).toHaveLength(2) + }) + + it("should invalidate nested product when related variant is updated", async () => { + const productWithVariants = { + id: "prod_1", + title: "Complex Product", + variants: [{ id: "var_1", product_id: "prod_1", sku: "SKU-001" }], + } + + const productKey = await service.computeKey(productWithVariants) + + await service.set({ + key: productKey, + data: productWithVariants, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + await mockEventBus.emit( + [ + { + name: "product_variant.updated", + data: { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001-UPDATED", + }, + }, + ], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const result = await service.get({ key: productKey }) + expect(result).toBeNull() + }) + + it("should handle price updates affecting variant and product caches", async () => { + const price = { + id: "price_1", + variant_id: "var_1", + amount: 1000, + currency_code: "USD", + } + + const variant = { + id: "var_1", + product_id: "prod_1", + sku: "SKU-001", + prices: [price], + } + + const product = { + id: "prod_1", + title: "Product", + variants: [variant], + } + + const priceKey = await service.computeKey(price) + const variantKey = await service.computeKey(variant) + const productKey = await service.computeKey(product) + + await service.set({ + key: priceKey, + data: price, + }) + await service.set({ + key: variantKey, + data: variant, + }) + await service.set({ + key: productKey, + data: product, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + await mockEventBus.emit( + [ + { + name: "price.updated", + data: { id: "price_1", variant_id: 
"var_1", amount: 1100 }, + }, + ], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cachedPrice = await service.get({ key: priceKey }) + const cachedVariant = await service.get({ key: variantKey }) + const cachedProduct = await service.get({ key: productKey }) + + expect(cachedPrice).toBeNull() + expect(cachedVariant).toBeNull() + expect(cachedProduct).toBeNull() + }) + }) + + describe("Complex Query Caching", () => { + it("should cache complex queries and invalidate based on entity relationships", async () => { + const complexQuery = { + entity: "product", + filters: { status: "published", collection_id: "col_1" }, + includes: { + variants: { + include: { + prices: true, + }, + }, + collection: true, + }, + pagination: { limit: 10, offset: 0 }, + } + + const queryResult = { + data: [ + { + id: "prod_1", + title: "Product 1", + collection_id: "col_1", + variants: [ + { + id: "var_1", + product_id: "prod_1", + prices: [{ id: "price_1", amount: 1000 }], + }, + ], + collection: { id: "col_1", title: "Collection 1" }, + }, + ], + pagination: { total: 1, limit: 10, offset: 0 }, + } + + const queryKey = await service.computeKey(complexQuery) + + await service.set({ + key: queryKey, + data: queryResult, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cached = await service.get({ key: queryKey }) + expect(cached).toEqual(queryResult) + + await mockEventBus.emit( + [ + { + name: "price.updated", + data: { id: "price_1", amount: 1100 }, + }, + ], + {} + ) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cachedAfterUpdate = await service.get({ key: queryKey }) + expect(cachedAfterUpdate).toBeNull() + }) + }) + + it("should cache complex queries and return the correct cached data", async () => { + const complexQuery = { + entity: "product", + filters: { status: "published", collection_id: "col_1" }, + includes: { + variants: { + include: { + prices: true, + }, + }, + collection: true, + }, + pagination: { limit: 10, offset: 0 }, + } + + const 
queryResult = { + data: [ + { + id: "prod_1", + title: "Product 1", + collection_id: "col_1", + variants: [ + { + id: "var_1", + product_id: "prod_1", + prices: [{ id: "price_1", amount: 1000 }], + }, + ], + collection: { id: "col_1", title: "Collection 1" }, + }, + ], + pagination: { total: 1, limit: 10, offset: 0 }, + } + + const queryKey = await service.computeKey(complexQuery) + + await service.set({ + key: queryKey, + data: queryResult, + }) + + await setTimeout(DEFAULT_WAIT_INTERVAL) + + const cached = await service.get({ key: queryKey }) + expect(cached).toEqual(queryResult) + + const complexQueryOffset = { + entity: "product", + filters: { status: "published", collection_id: "col_1" }, + includes: { + variants: { + include: { + prices: true, + }, + }, + collection: true, + }, + pagination: { limit: 10, offset: 10 }, + } + + const queryKeyOffset = await service.computeKey(complexQueryOffset) + + const cachedOffset = await service.get({ key: queryKeyOffset }) + expect(cachedOffset).toEqual(null) + }) + }) + }, +}) diff --git a/packages/modules/caching/jest.config.js b/packages/modules/caching/jest.config.js new file mode 100644 index 0000000000..0ce72cf61f --- /dev/null +++ b/packages/modules/caching/jest.config.js @@ -0,0 +1,8 @@ +const defineJestConfig = require("../../../define_jest_config") +module.exports = defineJestConfig({ + moduleNameMapper: { + "^@services": "/src/services", + "^@types": "/src/types", + "^@utils": "/src/utils", + }, +}) diff --git a/packages/modules/caching/package.json b/packages/modules/caching/package.json new file mode 100644 index 0000000000..8beee6bed3 --- /dev/null +++ b/packages/modules/caching/package.json @@ -0,0 +1,49 @@ +{ + "name": "@medusajs/caching", + "version": "2.10.3", + "description": "Caching Module for Medusa", + "main": "dist/index.js", + "repository": { + "type": "git", + "url": "https://github.com/medusajs/medusa", + "directory": "packages/modules/caching" + }, + "files": [ + "dist", + "!dist/**/__tests__", 
+ "!dist/**/__mocks__", + "!dist/**/__fixtures__" + ], + "publishConfig": { + "access": "public" + }, + "author": "Medusa", + "license": "MIT", + "scripts": { + "watch": "tsc --build --watch", + "watch:test": "tsc --build tsconfig.spec.json --watch", + "resolve:aliases": "tsc --showConfig -p tsconfig.json > tsconfig.resolved.json && tsc-alias -p tsconfig.resolved.json && rimraf tsconfig.resolved.json", + "build": "rimraf dist && tsc --build && npm run resolve:aliases", + "test": "jest --passWithNoTests --runInBand --bail --forceExit -- src/", + "test:integration": "jest --runInBand --forceExit -- integration-tests/__tests__/**/*.ts" + }, + "devDependencies": { + "@medusajs/framework": "2.10.3", + "@medusajs/test-utils": "2.10.3", + "@swc/core": "^1.7.28", + "@swc/jest": "^0.2.36", + "jest": "^29.7.0", + "rimraf": "^3.0.2", + "tsc-alias": "^1.8.6", + "typescript": "^5.6.2" + }, + "peerDependencies": { + "@medusajs/framework": "2.10.3", + "awilix": "^8.0.1" + }, + "dependencies": { + "fast-json-stable-stringify": "^2.1.0", + "node-cache": "^5.1.2", + "xxhash-wasm": "^1.1.0" + } +} diff --git a/packages/modules/caching/src/index.ts b/packages/modules/caching/src/index.ts new file mode 100644 index 0000000000..95d8045a12 --- /dev/null +++ b/packages/modules/caching/src/index.ts @@ -0,0 +1,12 @@ +import { Module, Modules } from "@medusajs/framework/utils" +import { default as loadHash } from "./loaders/hash" +import { default as loadProviders } from "./loaders/providers" +import CachingModuleService from "./services/cache-module" + +export default Module(Modules.CACHING, { + service: CachingModuleService, + loaders: [loadHash, loadProviders], +}) + +// Module options types +export { CachingModuleOptions } from "./types" diff --git a/packages/modules/caching/src/loaders/hash.ts b/packages/modules/caching/src/loaders/hash.ts new file mode 100644 index 0000000000..194f142c97 --- /dev/null +++ b/packages/modules/caching/src/loaders/hash.ts @@ -0,0 +1,8 @@ +import { asValue 
} from "awilix" + +export default async ({ container }) => { + const xxhashhWasm = await import("xxhash-wasm") + const { h32ToString } = await xxhashhWasm.default() + + container.register("hasher", asValue(h32ToString)) +} diff --git a/packages/modules/caching/src/loaders/providers.ts b/packages/modules/caching/src/loaders/providers.ts new file mode 100644 index 0000000000..e91b9a175d --- /dev/null +++ b/packages/modules/caching/src/loaders/providers.ts @@ -0,0 +1,94 @@ +import { moduleProviderLoader } from "@medusajs/framework/modules-sdk" +import { LoaderOptions, ModulesSdkTypes } from "@medusajs/framework/types" +import { + ContainerRegistrationKeys, + getProviderRegistrationKey, +} from "@medusajs/framework/utils" +import { CachingProviderService } from "@services" +import { + CachingDefaultProvider, + CachingIdentifiersRegistrationName, + CachingModuleOptions, + CachingProviderRegistrationPrefix, +} from "@types" +import { aliasTo, asFunction, asValue, Lifetime } from "awilix" +import { MemoryCachingProvider } from "../providers/memory-cache" +import { DefaultCacheStrategy } from "../utils/strategy" + +const registrationFn = async (klass, container, { id }) => { + const key = CachingProviderService.getRegistrationIdentifier(klass) + + if (!id) { + throw new Error(`No "id" provided for provider ${key}`) + } + + const regKey = getProviderRegistrationKey({ + providerId: id, + providerIdentifier: key, + }) + + container.register({ + [CachingProviderRegistrationPrefix + id]: aliasTo(regKey), + }) + + container.registerAdd(CachingIdentifiersRegistrationName, asValue(key)) +} + +export default async ({ + container, + options, +}: LoaderOptions< + ( + | ModulesSdkTypes.ModuleServiceInitializeOptions + | ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions + ) & + CachingModuleOptions +>): Promise => { + container.registerAdd(CachingIdentifiersRegistrationName, asValue(undefined)) + + const strategy = DefaultCacheStrategy // Re enable custom strategy another 
time + container.register("strategy", asValue(strategy)) + + // MemoryCachingProvider - default provider + container.register({ + [CachingProviderRegistrationPrefix + MemoryCachingProvider.identifier]: + asFunction(() => new MemoryCachingProvider(), { + lifetime: Lifetime.SINGLETON, + }), + }) + container.registerAdd( + CachingIdentifiersRegistrationName, + asValue(MemoryCachingProvider.identifier) + ) + container.register( + CachingDefaultProvider, + asValue(MemoryCachingProvider.identifier) + ) + + // Load other providers + await moduleProviderLoader({ + container, + providers: options?.providers || [], + registerServiceFn: registrationFn, + }) + + const isSingleProvider = options?.providers?.length === 1 + let hasDefaultProvider = false + for (const provider of options?.providers || []) { + if (provider.is_default || isSingleProvider) { + if (provider.is_default) { + hasDefaultProvider = true + } + container.register(CachingDefaultProvider, asValue(provider.id)) + } + } + + const logger = container.resolve(ContainerRegistrationKeys.LOGGER) + if (!hasDefaultProvider) { + logger.warn( + `[caching-module]: No default caching provider defined. 
Using "${container.resolve( + CachingDefaultProvider + )}" as default.` + ) + } +} diff --git a/packages/modules/caching/src/providers/memory-cache.ts b/packages/modules/caching/src/providers/memory-cache.ts new file mode 100644 index 0000000000..ef93483e04 --- /dev/null +++ b/packages/modules/caching/src/providers/memory-cache.ts @@ -0,0 +1,228 @@ +import NodeCache from "node-cache" +import type { ICachingProviderService } from "@medusajs/framework/types" + +export interface MemoryCacheModuleOptions { + /** + * TTL in seconds + */ + ttl?: number + /** + * Maximum number of keys to store (see node-cache documentation) + */ + maxKeys?: number + /** + * Check period for expired keys in seconds (see node-cache documentation) + */ + checkPeriod?: number + /** + * Use clones for cached data (see node-cache documentation) + */ + useClones?: boolean +} + +export class MemoryCachingProvider implements ICachingProviderService { + static identifier = "cache-memory" + + protected cacheClient: NodeCache + protected tagIndex: Map> = new Map() // tag -> keys + protected keyTags: Map> = new Map() // key -> tags + protected entryOptions: Map = new Map() // key -> options + protected options: MemoryCacheModuleOptions + + constructor() { + this.options = { + ttl: 3600, + maxKeys: 25000, + checkPeriod: 60, // 10 minutes + useClones: false, // Default to false for speed, true would be slower but safer. 
we can discuss + } + + const cacheClient = new NodeCache({ + stdTTL: this.options.ttl, + maxKeys: this.options.maxKeys, + checkperiod: this.options.checkPeriod, + useClones: this.options.useClones, + }) + + this.cacheClient = cacheClient + + // Clean up tag indices when keys expire + this.cacheClient.on("expired", (key: string, value: any) => { + this.cleanupTagReferences(key) + }) + + this.cacheClient.on("del", (key: string, value: any) => { + this.cleanupTagReferences(key) + }) + } + + private cleanupTagReferences(key: string): void { + const tags = this.keyTags.get(key) + if (tags) { + tags.forEach((tag) => { + const keysForTag = this.tagIndex.get(tag) + if (keysForTag) { + keysForTag.delete(key) + if (keysForTag.size === 0) { + this.tagIndex.delete(tag) + } + } + }) + this.keyTags.delete(key) + } + // Also clean up entry options + this.entryOptions.delete(key) + } + + async get({ key, tags }: { key?: string; tags?: string[] }): Promise { + if (key) { + return this.cacheClient.get(key) ?? null + } + + if (tags && tags.length) { + const allKeys = new Set() + + tags.forEach((tag) => { + const keysForTag = this.tagIndex.get(tag) + if (keysForTag) { + keysForTag.forEach((key) => allKeys.add(key)) + } + }) + + if (allKeys.size === 0) { + return [] + } + + const results: any[] = [] + allKeys.forEach((key) => { + const value = this.cacheClient.get(key) + if (value !== undefined) { + results.push(value) + } + }) + + return results + } + + return null + } + + async set({ + key, + data, + ttl, + tags, + options, + }: { + key: string + data: object + ttl?: number + tags?: string[] + options?: { + autoInvalidate?: boolean + } + }): Promise { + // Set the cache entry + const effectiveTTL = ttl ?? this.options.ttl ?? 
3600 + this.cacheClient.set(key, data, effectiveTTL) + + // Handle tags if provided + if (tags && tags.length) { + // Clean up any existing tag references for this key + this.cleanupTagReferences(key) + + const tagSet = new Set(tags) + this.keyTags.set(key, tagSet) + + // Add this key to each tag's index + tags.forEach((tag) => { + if (!this.tagIndex.has(tag)) { + this.tagIndex.set(tag, new Set()) + } + this.tagIndex.get(tag)!.add(key) + }) + } + + // Store entry options if provided + if ( + Object.keys(options ?? {}).length && + !Object.values(options ?? {}).every((value) => value === undefined) + ) { + this.entryOptions.set(key, options!) + } + } + + async clear({ + key, + tags, + options, + }: { + key?: string + tags?: string[] + options?: { + autoInvalidate?: boolean + } + }): Promise { + if (key) { + this.cacheClient.del(key) + return + } + + if (tags && tags.length) { + // Handle wildcard tag to clear all cache data + if (tags.includes("*")) { + this.cacheClient.flushAll() + this.tagIndex.clear() + this.keyTags.clear() + this.entryOptions.clear() + return + } + + const allKeys = new Set() + + tags.forEach((tag) => { + const keysForTag = this.tagIndex.get(tag) + if (keysForTag) { + keysForTag.forEach((key) => allKeys.add(key)) + } + }) + + if (allKeys.size) { + // If no options provided (user explicit call), clear everything + if (!options) { + const keysToDelete = Array.from(allKeys) + this.cacheClient.del(keysToDelete) + + // Clean up ALL tag references for deleted keys + keysToDelete.forEach((key) => { + this.cleanupTagReferences(key) + }) + return + } + + // If autoInvalidate is true (strategy call), only clear entries with autoInvalidate=true (default) + if (options.autoInvalidate === true) { + const keysToDelete: string[] = [] + + allKeys.forEach((key) => { + const entryOptions = this.entryOptions.get(key) + // Delete if entry has autoInvalidate=true or no setting (default true) + const shouldAutoInvalidate = entryOptions?.autoInvalidate ?? 
true + if (shouldAutoInvalidate) { + keysToDelete.push(key) + } + }) + + if (keysToDelete.length) { + this.cacheClient.del(keysToDelete) + + // Clean up ALL tag references for deleted keys + keysToDelete.forEach((key) => { + this.cleanupTagReferences(key) + }) + } + } + } + } + } +} diff --git a/packages/modules/caching/src/services/cache-module.ts b/packages/modules/caching/src/services/cache-module.ts new file mode 100644 index 0000000000..73f04ccb1c --- /dev/null +++ b/packages/modules/caching/src/services/cache-module.ts @@ -0,0 +1,406 @@ +import { MedusaModule } from "@medusajs/framework/modules-sdk" +import type { + ICachingModuleService, + ICachingStrategy, + Logger, +} from "@medusajs/framework/types" +import { GraphQLUtils, MedusaError } from "@medusajs/framework/utils" +import { CachingDefaultProvider, InjectedDependencies } from "@types" +import CacheProviderService from "./cache-provider" + +const ONE_HOUR_IN_SECOND = 60 * 60 + +export default class CachingModuleService implements ICachingModuleService { + protected container: InjectedDependencies + protected providerService: CacheProviderService + protected strategyCtr: new (...args: any[]) => ICachingStrategy + protected strategy: ICachingStrategy + protected defaultProviderId: string + + protected logger: Logger + protected ongoingRequests: Map> = new Map() + + protected ttl: number + + static traceGet?: ( + cacheGetFn: () => Promise, + key: string, + tags: string[] + ) => Promise + + static traceSet?: ( + cacheSetFn: () => Promise, + key: string, + tags: string[], + options: { autoInvalidate?: boolean } + ) => Promise + + static traceClear?: ( + cacheClearFn: () => Promise, + key: string, + tags: string[], + options: { autoInvalidate?: boolean } + ) => Promise + + constructor( + container: InjectedDependencies, + protected readonly moduleDeclaration: + | { options: { ttl?: number } } + | { ttl?: number } + ) { + this.container = container + this.providerService = container.cacheProviderService + 
this.defaultProviderId = container[CachingDefaultProvider] + this.strategyCtr = container.strategy as new ( + ...args: any[] + ) => ICachingStrategy + this.strategy = new this.strategyCtr(this.container, this) + + const moduleOptions = + "options" in moduleDeclaration + ? moduleDeclaration.options + : moduleDeclaration + + this.ttl = moduleOptions.ttl ?? ONE_HOUR_IN_SECOND + + this.logger = container.logger ?? (console as unknown as Logger) + } + + __hooks = { + onApplicationStart: async () => { + this.onApplicationStart() + }, + onApplicationShutdown: async () => { + this.onApplicationShutdown() + }, + onApplicationPrepareShutdown: async () => { + this.onApplicationPrepareShutdown() + }, + } + + protected onApplicationStart() { + const loadedSchema = MedusaModule.getAllJoinerConfigs() + .map((joinerConfig) => joinerConfig?.schema ?? "") + .join("\n") + + const defaultMedusaSchema = ` + scalar DateTime + scalar JSON + directive @enumValue(value: String) on ENUM_VALUE + ` + + const { schema: cleanedSchema } = GraphQLUtils.cleanGraphQLSchema( + defaultMedusaSchema + loadedSchema + ) + const mergedSchema = GraphQLUtils.mergeTypeDefs(cleanedSchema) + const schema = GraphQLUtils.makeExecutableSchema({ + typeDefs: mergedSchema, + }) + + this.strategy.onApplicationStart?.( + schema, + MedusaModule.getAllJoinerConfigs() + ) + } + + protected onApplicationShutdown() { + this.strategy.onApplicationShutdown?.() + } + + protected onApplicationPrepareShutdown() { + this.strategy.onApplicationPrepareShutdown?.() + } + + protected static normalizeProviders( + providers: string[] | { id: string; ttl?: number }[] + ): { id: string; ttl?: number }[] { + const providers_ = Array.isArray(providers) ? providers : [providers] + return providers_.map((provider) => { + return typeof provider === "string" ? 
{ id: provider } : provider + }) + } + + protected getRequestKey( + key?: string, + tags?: string[], + providers?: string[] + ): string { + const keyPart = key || "" + const tagsPart = tags?.sort().join(",") || "" + const providersPart = providers?.join(",") || this.defaultProviderId + return `${keyPart}|${tagsPart}|${providersPart}` + } + + protected getClearRequestKey( + key?: string, + tags?: string[], + providers?: string[] + ): string { + const keyPart = key || "" + const tagsPart = tags?.sort().join(",") || "" + const providersPart = providers?.join(",") || this.defaultProviderId + return `clear:${keyPart}|${tagsPart}|${providersPart}` + } + + async get(options: { key?: string; tags?: string[]; providers?: string[] }) { + if (CachingModuleService.traceGet) { + return await CachingModuleService.traceGet( + () => this.get_(options), + options.key ?? "", + options.tags ?? [] + ) + } + + return await this.get_(options) + } + + private async get_({ + key, + tags, + providers, + }: { + key?: string + tags?: string[] + providers?: string[] + }) { + if (!key && !tags) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + "Either key or tags must be provided" + ) + } + + const requestKey = this.getRequestKey(key, tags, providers) + + const existingRequest = this.ongoingRequests.get(requestKey) + if (existingRequest) { + return await existingRequest + } + + const requestPromise = this.performCacheGet(key, tags, providers) + this.ongoingRequests.set(requestKey, requestPromise) + + try { + const result = await requestPromise + return result + } finally { + // Clean up the completed request + this.ongoingRequests.delete(requestKey) + } + } + + protected async performCacheGet( + key?: string, + tags?: string[], + providers?: string[] + ): Promise { + const providersToCheck = providers ?? 
[this.defaultProviderId] + + for (const providerId of providersToCheck) { + try { + const provider_ = this.providerService.retrieveProvider(providerId) + const result = await provider_.get({ key, tags }) + + if (result != null) { + return result + } + } catch (error) { + this.logger.warn( + `Cache provider ${providerId} failed: ${error.message}\n${error.stack}` + ) + continue + } + } + + return null + } + + async set(options: { + key: string + data: object + ttl?: number + tags?: string[] + providers?: string[] + options?: { autoInvalidate?: boolean } + }) { + if (CachingModuleService.traceSet) { + return await CachingModuleService.traceSet( + () => this.set_(options), + options.key, + options.tags ?? [], + options.options ?? {} + ) + } + + return await this.set_(options) + } + + private async set_({ + key, + data, + ttl, + tags, + providers, + options, + }: { + key: string + data: object + tags?: string[] + ttl?: number + providers?: string[] | { id: string; ttl?: number }[] + options?: { + autoInvalidate?: boolean + } + }) { + if (!key) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + "[CachingModuleService] Key must be provided" + ) + } + + const key_ = key + const tags_ = tags ?? (await this.strategy.computeTags(data)) + + let providers_: string[] | { id: string; ttl?: number }[] = [ + this.defaultProviderId, + ] + providers_ = CachingModuleService.normalizeProviders( + providers ?? 
providers_ + ) + + const providerIds = providers_.map((p) => p.id) + const requestKey = this.getRequestKey(key_, tags_, providerIds) + + const existingRequest = this.ongoingRequests.get(requestKey) + if (existingRequest) { + return await existingRequest + } + + const requestPromise = this.performCacheSet( + key_, + tags_, + data, + ttl, + providers_, + options + ) + this.ongoingRequests.set(requestKey, requestPromise) + + try { + await requestPromise + } finally { + // Clean up the completed request + this.ongoingRequests.delete(requestKey) + } + } + + protected async performCacheSet( + key: string, + tags: string[], + data: object, + ttl?: number, + providers?: { id: string; ttl?: number }[], + options?: { + autoInvalidate?: boolean + } + ): Promise { + for (const providerOptions of providers || []) { + const ttl_ = providerOptions.ttl ?? ttl ?? this.ttl + const provider = this.providerService.retrieveProvider(providerOptions.id) + void provider.set({ + key, + tags, + data, + ttl: ttl_, + options, + }) + } + } + + async clear(options: { + key?: string + tags?: string[] + options?: { autoInvalidate?: boolean } + providers?: string[] + }) { + if (CachingModuleService.traceClear) { + return await CachingModuleService.traceClear( + () => this.clear_(options), + options.key ?? "", + options.tags ?? [], + options.options ?? 
{} + ) + } + + return await this.clear_(options) + } + + private async clear_({ + key, + tags, + options, + providers, + }: { + key?: string + tags?: string[] + options?: { + autoInvalidate?: boolean + } + providers?: string[] + }) { + if (!key && !tags) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + "Either key or tags must be provided" + ) + } + + const requestKey = this.getClearRequestKey(key, tags, providers) + + const existingRequest = this.ongoingRequests.get(requestKey) + if (existingRequest) { + return await existingRequest + } + + const requestPromise = this.performCacheClear(key, tags, options, providers) + this.ongoingRequests.set(requestKey, requestPromise) + + try { + await requestPromise + } finally { + // Clean up the completed request + this.ongoingRequests.delete(requestKey) + } + } + + protected async performCacheClear( + key?: string, + tags?: string[], + options?: { + autoInvalidate?: boolean + }, + providers?: string[] + ): Promise { + let providerIds_: string[] = [this.defaultProviderId] + if (providers) { + providerIds_ = Array.isArray(providers) ? 
providers : [providers] + } + + for (const providerId of providerIds_) { + const provider = this.providerService.retrieveProvider(providerId) + void provider.clear({ key, tags, options }) + } + } + + async computeKey(input: object): Promise { + return await this.strategy.computeKey(input) + } + + async computeTags( + input: object, + options?: Record + ): Promise { + return await this.strategy.computeTags(input, options) + } +} diff --git a/packages/modules/caching/src/services/cache-provider.ts b/packages/modules/caching/src/services/cache-provider.ts new file mode 100644 index 0000000000..5b473b162d --- /dev/null +++ b/packages/modules/caching/src/services/cache-provider.ts @@ -0,0 +1,60 @@ +import { + Constructor, + ICachingProviderService, + Logger, +} from "@medusajs/framework/types" +import { MedusaError } from "@medusajs/framework/utils" +import { CachingProviderRegistrationPrefix } from "../types" + +type InjectedDependencies = { + [key: `cp_${string}`]: ICachingProviderService + logger?: Logger +} + +export default class CacheProviderService { + #container: InjectedDependencies + #logger: Logger + + constructor(container: InjectedDependencies) { + this.#container = container + this.#logger = container["logger"] + ? 
container.logger + : (console as unknown as Logger) + } + + static getRegistrationIdentifier( + providerClass: Constructor + ) { + if (!(providerClass as any).identifier) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + `Trying to register a caching provider without an identifier.` + ) + } + return `${(providerClass as any).identifier}` + } + + public retrieveProvider(providerId: string): ICachingProviderService { + try { + return this.#container[ + `${CachingProviderRegistrationPrefix}${providerId}` + ] + } catch (err) { + if (err.name === "AwilixResolutionError") { + const errMessage = ` + Unable to retrieve the caching provider with id: ${providerId} +Please make sure that the provider is registered in the container and it is configured correctly in your project configuration file.` + + // Log full error for debugging + this.#logger.error(`AwilixResolutionError: ${err.message}`, err) + + throw new Error(errMessage) + } + + const errMessage = `Unable to retrieve the caching provider with id: ${providerId}, the following error occurred: ${err.message}` + this.#logger.error(errMessage) + + throw new Error(errMessage) + } + } +} diff --git a/packages/modules/caching/src/services/index.ts b/packages/modules/caching/src/services/index.ts new file mode 100644 index 0000000000..a319f114d6 --- /dev/null +++ b/packages/modules/caching/src/services/index.ts @@ -0,0 +1,2 @@ +export { default as CachingModuleService } from "./cache-module" +export { default as CachingProviderService } from "./cache-provider" diff --git a/packages/modules/caching/src/types/index.ts b/packages/modules/caching/src/types/index.ts new file mode 100644 index 0000000000..5773f18f50 --- /dev/null +++ b/packages/modules/caching/src/types/index.ts @@ -0,0 +1,56 @@ +import type { + Constructor, + ICachingStrategy, + IEventBusModuleService, + Logger, + ModuleProviderExports, + ModuleServiceInitializeOptions, +} from "@medusajs/framework/types" +import { Modules } from 
"@medusajs/framework/utils" +import { default as CacheProviderService } from "../services/cache-provider" + +export const CachingDefaultProvider = "default_provider" +export const CachingIdentifiersRegistrationName = "caching_providers_identifier" + +export const CachingProviderRegistrationPrefix = "lp_" + +export type InjectedDependencies = { + cacheProviderService: CacheProviderService + hasher: (data: string) => string + logger?: Logger + strategy: Constructor + [CachingDefaultProvider]: string + [Modules.EVENT_BUS]: IEventBusModuleService +} + +export type CachingModuleOptions = Partial & { + /** + * The strategy to be used. Default to the inbuilt default strategy. + */ + // strategy?: ICachingStrategy + /** + * Time to keep data in cache (in seconds) + */ + ttl?: number + /** + * Providers to be registered + */ + providers?: { + /** + * The module provider to be registered + */ + resolve: string | ModuleProviderExports + /** + * If the provider is the default + */ + is_default?: boolean + /** + * The id of the provider + */ + id: string + /** + * key value pair of the configuration to be passed to the provider constructor + */ + options?: Record + }[] +} diff --git a/packages/modules/caching/src/utils/__tests__/parser.test.ts b/packages/modules/caching/src/utils/__tests__/parser.test.ts new file mode 100644 index 0000000000..384010530c --- /dev/null +++ b/packages/modules/caching/src/utils/__tests__/parser.test.ts @@ -0,0 +1,487 @@ +import { GraphQLSchema, buildSchema } from "graphql" +import { CacheInvalidationParser, EntityReference } from "../parser" + +describe("CacheInvalidationParser", () => { + let parser: CacheInvalidationParser + let schema: GraphQLSchema + + beforeEach(() => { + const schemaDefinition = ` + type Product { + id: ID! + title: String + description: String + collection: ProductCollection + categories: [ProductCategory!] + variants: [ProductVariant!] + created_at: String + updated_at: String + } + + type ProductCollection { + id: ID! 
+ title: String + products: [Product!] + created_at: String + updated_at: String + } + + type ProductCategory { + id: ID! + name: String + products: [Product!] + parent: ProductCategory + children: [ProductCategory!] + created_at: String + updated_at: String + } + + type ProductVariant { + id: ID! + title: String + sku: String + product: Product! + prices: [Price!] + created_at: String + updated_at: String + } + + type Price { + id: ID! + amount: Int + currency_code: String + variant: ProductVariant! + created_at: String + updated_at: String + } + + type Order { + id: ID! + status: String + items: [OrderItem!] + customer: Customer + created_at: String + updated_at: String + } + + type OrderItem { + id: ID! + quantity: Int + order: Order! + variant: ProductVariant! + created_at: String + updated_at: String + } + + type Customer { + id: ID! + first_name: String + last_name: String + email: String + orders: [Order!] + created_at: String + updated_at: String + } + ` + + schema = buildSchema(schemaDefinition) + parser = new CacheInvalidationParser(schema, [ + // Partially populate this record ro force the test to match from both id prefix or type + // detection + { + idPrefixToEntityName: { + prod: "Product", + col: "ProductCollection", + cat: "ProductCategory", + }, + }, + ]) + }) + + describe("parseObjectForEntities", () => { + it("should identify a simple product entity", () => { + const product = { + id: "prod_123", + title: "Test Product", + description: "A test product", + } + + const entities = parser.parseObjectForEntities(product) + + expect(entities).toHaveLength(1) + expect(entities[0]).toEqual({ + type: "Product", + id: "prod_123", + isInArray: false, + }) + }) + + it("should identify nested entities in a product with collection", () => { + const product = { + id: "prod_123", + title: "Test Product", + collection: { + id: "col_456", + title: "Test Collection", + }, + } + + const entities = parser.parseObjectForEntities(product) + + 
expect(entities).toHaveLength(2) + expect(entities).toContainEqual({ + type: "Product", + id: "prod_123", + isInArray: false, + }) + expect(entities).toContainEqual({ + type: "ProductCollection", + id: "col_456", + isInArray: false, + }) + }) + + it("should identify entities in arrays", () => { + const product = { + id: "prod_123", + title: "Test Product", + variants: [ + { + id: "var_789", + title: "Variant 1", + sku: "SKU-001", + }, + { + id: "var_790", + title: "Variant 2", + sku: "SKU-002", + }, + ], + } + + const entities = parser.parseObjectForEntities(product) + + expect(entities).toHaveLength(3) + expect(entities).toContainEqual({ + type: "Product", + id: "prod_123", + isInArray: false, + }) + expect(entities).toContainEqual({ + type: "ProductVariant", + id: "var_789", + isInArray: true, + }) + expect(entities).toContainEqual({ + type: "ProductVariant", + id: "var_790", + isInArray: true, + }) + }) + + it("should handle deeply nested entities", () => { + const order = { + id: "order_123", + status: "completed", + items: [ + { + id: "item_456", + quantity: 2, + variant: { + id: "var_789", + title: "Variant 1", + product: { + id: "prod_123", + title: "Test Product", + collection: { + id: "col_456", + title: "Test Collection", + }, + }, + }, + }, + ], + customer: { + id: "cus_789", + email: "test@example.com", + first_name: "John", + }, + } + + const entities = parser.parseObjectForEntities(order) + + expect(entities).toHaveLength(6) + expect(entities).toContainEqual({ + type: "Order", + id: "order_123", + isInArray: false, + }) + expect(entities).toContainEqual({ + type: "OrderItem", + id: "item_456", + isInArray: true, + }) + expect(entities).toContainEqual({ + type: "ProductVariant", + id: "var_789", + isInArray: false, + }) + expect(entities).toContainEqual({ + type: "Product", + id: "prod_123", + isInArray: false, + }) + expect(entities).toContainEqual({ + type: "ProductCollection", + id: "col_456", + isInArray: false, + }) + 
expect(entities).toContainEqual({ + type: "Customer", + id: "cus_789", + isInArray: false, + }) + }) + + it("should return empty array for null or primitive values", () => { + expect(parser.parseObjectForEntities(null)).toEqual([]) + expect(parser.parseObjectForEntities(undefined)).toEqual([]) + expect(parser.parseObjectForEntities("string")).toEqual([]) + expect(parser.parseObjectForEntities(123)).toEqual([]) + expect(parser.parseObjectForEntities(true)).toEqual([]) + }) + + it("should ignore objects without id field", () => { + const invalidObject = { + title: "No ID Object", + description: "This object has no ID", + } + + const entities = parser.parseObjectForEntities(invalidObject) + expect(entities).toEqual([]) + }) + + it("should handle objects with partial field matches", () => { + const partialProduct = { + id: "prod_123", + title: "Test Product", + unknown_field: "Should still work", + } + + const entities = parser.parseObjectForEntities(partialProduct) + + expect(entities).toHaveLength(1) + expect(entities[0]).toEqual({ + type: "Product", + id: "prod_123", + isInArray: false, + }) + }) + }) + + describe("buildInvalidationEvents", () => { + it("should build invalidation events for a single entity", () => { + const entities: EntityReference[] = [{ type: "Product", id: "prod_123" }] + + const events = parser.buildInvalidationEvents(entities) + + expect(events).toHaveLength(1) + expect(events[0]).toMatchObject({ + entityType: "Product", + entityId: "prod_123", + relatedEntities: [], + }) + + expect(events[0].cacheKeys).toEqual(["Product:prod_123"]) + }) + + it("should build invalidation events with related entities", () => { + const entities: EntityReference[] = [ + { type: "Product", id: "prod_123" }, + { type: "ProductCollection", id: "col_456" }, + { type: "ProductVariant", id: "var_789" }, + ] + + const events = parser.buildInvalidationEvents(entities) + + expect(events).toHaveLength(3) + + const productEvent = events.find((e) => e.entityType === 
"Product") + expect(productEvent).toBeDefined() + expect(productEvent!.relatedEntities).toHaveLength(2) + expect(productEvent!.cacheKeys).toEqual(["Product:prod_123"]) + }) + + it("should avoid duplicate entities in events", () => { + const entities: EntityReference[] = [ + { type: "Product", id: "prod_123" }, + { type: "Product", id: "prod_123" }, // Duplicate + { type: "ProductCollection", id: "col_456" }, + ] + + const events = parser.buildInvalidationEvents(entities) + + expect(events).toHaveLength(2) // Should only have Product and ProductCollection events + expect(events.map((e) => e.entityType).sort()).toEqual([ + "Product", + "ProductCollection", + ]) + }) + + it("should generate comprehensive cache keys", () => { + const entities: EntityReference[] = [ + { type: "Product", id: "prod_123" }, + { type: "ProductCollection", id: "col_456" }, + ] + + const events = parser.buildInvalidationEvents(entities) + const productEvent = events.find((e) => e.entityType === "Product")! + + expect(productEvent.cacheKeys).toEqual(["Product:prod_123"]) + }) + }) + + describe("integration scenarios", () => { + it("should handle a complete product updated scenario", () => { + const productData = { + id: "prod_123", + title: "Updated Product Title", + collection: { + id: "col_456", + title: "Fashion Collection", + }, + categories: [ + { id: "cat_789", name: "Shirts" }, + { id: "cat_790", name: "Casual" }, + ], + variants: [ + { + id: "var_111", + title: "Size S", + prices: [{ id: "price_222", amount: 2999, currency_code: "USD" }], + }, + ], + } + + const entities = parser.parseObjectForEntities(productData) + const events = parser.buildInvalidationEvents(entities) + + // Should identify all nested entities + expect(entities).toHaveLength(6) // Product, Collection, 2 Categories, Variant, Price + + // Should created events for each entity type + expect(events).toHaveLength(6) + + // Validate cache keys for each entity type + const productEvent = events.find((e) => e.entityType 
=== "Product")! + expect(productEvent.cacheKeys).toEqual(["Product:prod_123"]) + + const collectionEvent = events.find( + (e) => e.entityType === "ProductCollection" + )! + expect(collectionEvent.cacheKeys).toEqual(["ProductCollection:col_456"]) + + const categoryEvents = events.filter( + (e) => e.entityType === "ProductCategory" + ) + expect(categoryEvents).toHaveLength(2) + expect(categoryEvents[0].cacheKeys).toEqual([ + "ProductCategory:cat_789", + "ProductCategory:list:*", + ]) + expect(categoryEvents[1].cacheKeys).toEqual([ + "ProductCategory:cat_790", + "ProductCategory:list:*", + ]) + + const variantEvent = events.find( + (e) => e.entityType === "ProductVariant" + )! + expect(variantEvent.cacheKeys).toEqual([ + "ProductVariant:var_111", + "ProductVariant:list:*", + ]) + + const priceEvent = events.find((e) => e.entityType === "Price")! + expect(priceEvent.cacheKeys).toEqual(["Price:price_222", "Price:list:*"]) + }) + + it("should handle order with customer and items scenario", () => { + const orderData = { + id: "order_123", + status: "completed", + customer: { + id: "cus_456", + email: "customer@example.com", + }, + items: [ + { + id: "item_789", + quantity: 2, + variant: { + id: "var_111", + sku: "SHIRT-S-BLUE", + }, + }, + ], + } + + const entities = parser.parseObjectForEntities(orderData) + const events = parser.buildInvalidationEvents(entities) + + expect(entities).toHaveLength(4) // Order, Customer, OrderItem, ProductVariant + expect(events).toHaveLength(4) + + // Validate cache keys for each entity type + const orderEvent = events.find((e) => e.entityType === "Order")! + expect(orderEvent.cacheKeys).toEqual(["Order:order_123"]) + + const customerEvent = events.find((e) => e.entityType === "Customer")! + expect(customerEvent.cacheKeys).toEqual(["Customer:cus_456"]) + + const itemEvent = events.find((e) => e.entityType === "OrderItem")! 
+ expect(itemEvent.cacheKeys).toEqual([ + "OrderItem:item_789", + "OrderItem:list:*", + ]) + + const variantEvent = events.find( + (e) => e.entityType === "ProductVariant" + )! + expect(variantEvent.cacheKeys).toEqual(["ProductVariant:var_111"]) + }) + + it("should include simplified cache keys for created operation", () => { + const entities: EntityReference[] = [{ type: "Product", id: "prod_123" }] + + const events = parser.buildInvalidationEvents(entities, "created") + + const productEvent = events[0] + expect(productEvent.cacheKeys).toEqual([ + "Product:prod_123", + "Product:list:*", + ]) + }) + + it("should include simplified cache keys for deleted operation", () => { + const entities: EntityReference[] = [{ type: "Product", id: "prod_123" }] + + const events = parser.buildInvalidationEvents(entities, "deleted") + + const productEvent = events[0] + expect(productEvent.cacheKeys).toEqual([ + "Product:prod_123", + "Product:list:*", + ]) + }) + + it("should include simplified cache keys for updated operation", () => { + const entities: EntityReference[] = [{ type: "Product", id: "prod_123" }] + + const events = parser.buildInvalidationEvents(entities, "updated") + + const productEvent = events[0] + expect(productEvent.cacheKeys).toEqual(["Product:prod_123"]) + }) + }) +}) diff --git a/packages/modules/caching/src/utils/parser.ts b/packages/modules/caching/src/utils/parser.ts new file mode 100644 index 0000000000..3c16c7cd8d --- /dev/null +++ b/packages/modules/caching/src/utils/parser.ts @@ -0,0 +1,242 @@ +import { ModuleJoinerConfig } from "@medusajs/framework/types" +import { isObject } from "@medusajs/framework/utils" +import { + GraphQLObjectType, + GraphQLSchema, + isListType, + isNonNullType, + isObjectType, +} from "graphql" + +export interface EntityReference { + type: string + id: string | number + field?: string + isInArray?: boolean +} + +export interface InvalidationEvent { + entityType: string + entityId: string | number + relatedEntities: 
EntityReference[] + cacheKeys: string[] +} + +export class CacheInvalidationParser { + private typeMap: Map + private idPrefixToEntityName: Record + + constructor(schema: GraphQLSchema, joinerConfigs: ModuleJoinerConfig[]) { + this.typeMap = new Map() + + // Build type map for quick lookups + const schemaTypeMap = schema.getTypeMap() + Object.keys(schemaTypeMap).forEach((typeName) => { + const type = schemaTypeMap[typeName] + if (isObjectType(type) && !typeName.startsWith("__")) { + this.typeMap.set(typeName, type) + } + }) + + this.idPrefixToEntityName = joinerConfigs.reduce((acc, joinerConfig) => { + if (joinerConfig.idPrefixToEntityName) { + Object.entries(joinerConfig.idPrefixToEntityName).forEach( + ([idPrefix, entityName]) => { + acc[idPrefix] = entityName + } + ) + } + return acc + }, {} as Record) + } + + /** + * Parse an object to identify entities and their relationships + */ + parseObjectForEntities( + obj: any, + parentType?: string, + isInArray: boolean = false + ): EntityReference[] { + const entities: EntityReference[] = [] + + if (!obj || typeof obj !== "object") { + return entities + } + + // Check if this object matches any known GraphQL types + const detectedType = this.detectEntityType(obj, parentType) + if (detectedType && obj.id) { + entities.push({ + type: detectedType, + id: obj.id, + isInArray, + }) + } + + // Recursively parse nested objects and arrays + Object.keys(obj).forEach((key) => { + const value = obj[key] + + if (Array.isArray(value)) { + value.forEach((item) => { + entities.push( + ...this.parseObjectForEntities( + item, + this.getRelationshipType(detectedType, key), + true + ) + ) + }) + } else if (isObject(value)) { + entities.push( + ...this.parseObjectForEntities( + value, + this.getRelationshipType(detectedType, key), + false + ) + ) + } + }) + + return entities + } + + /** + * Detect entity type based on object structure and GraphQL type map + */ + private detectEntityType(obj: any, suggestedType?: string): string | null { 
+ if (obj.id) { + const idParts = obj.id.split("_") + if (idParts.length > 1 && this.idPrefixToEntityName[idParts[0]]) { + return this.idPrefixToEntityName[idParts[0]] + } + } + + if (suggestedType && this.typeMap.has(suggestedType)) { + const type = this.typeMap.get(suggestedType)! + if (this.objectMatchesType(obj, type)) { + return suggestedType + } + } + + // Try to match against all known types + for (const [typeName, type] of this.typeMap) { + if (this.objectMatchesType(obj, type)) { + return typeName + } + } + + return null + } + + /** + * Check if object structure matches GraphQL type fields + */ + private objectMatchesType(obj: any, type: GraphQLObjectType): boolean { + const fields = type.getFields() + const objKeys = Object.keys(obj) + + // Must have id field for entities + if (!obj.id || !fields.id) { + return false + } + + // Check if at least 50% of non-null object fields match type fields + const matchingFields = objKeys.filter((key) => fields[key]).length + return matchingFields >= Math.max(1, objKeys.length * 0.5) + } + + /** + * Get the expected type for a relationship field + */ + private getRelationshipType( + parentType: string | null, + fieldName: string + ): string | undefined { + if (!parentType || !this.typeMap.has(parentType)) { + return undefined + } + + const type = this.typeMap.get(parentType)! 
+ const field = type.getFields()[fieldName] + + if (!field) { + return undefined + } + + let fieldType = field.type + + // Unwrap NonNull and List wrappers + if (isNonNullType(fieldType)) { + fieldType = fieldType.ofType + } + if (isListType(fieldType)) { + fieldType = fieldType.ofType + } + if (isNonNullType(fieldType)) { + fieldType = fieldType.ofType + } + + if (isObjectType(fieldType)) { + return fieldType.name + } + + return undefined + } + + /** + * Build invalidation events based on parsed entities + */ + buildInvalidationEvents( + entities: EntityReference[], + operation: "created" | "updated" | "deleted" = "updated" + ): InvalidationEvent[] { + const events: InvalidationEvent[] = [] + const processedEntities = new Set() + + entities.forEach((entity) => { + const entityKey = `${entity.type}:${entity.id}` + + if (processedEntities.has(entityKey)) { + return + } + processedEntities.add(entityKey) + + const relatedEntities = entities.filter( + (e) => e.type !== entity.type || e.id !== entity.id + ) + + const affectedKeys = this.buildAffectedCacheKeys(entity, operation) + + events.push({ + entityType: entity.type, + entityId: entity.id, + relatedEntities, + cacheKeys: affectedKeys, + }) + }) + + return events + } + + /** + * Build list of cache keys that should be invalidated + */ + private buildAffectedCacheKeys( + entity: EntityReference, + operation: "created" | "updated" | "deleted" = "updated" + ): string[] { + const keys = new Set() + + keys.add(`${entity.type}:${entity.id}`) + + // Add list key only if entity was found in an array context or if an event of type created or + // deleted is triggered + if (entity.isInArray || ["created", "deleted"].includes(operation)) { + keys.add(`${entity.type}:list:*`) + } + + return Array.from(keys) + } +} diff --git a/packages/modules/caching/src/utils/strategy.ts b/packages/modules/caching/src/utils/strategy.ts new file mode 100644 index 0000000000..b0ae97a573 --- /dev/null +++ 
b/packages/modules/caching/src/utils/strategy.ts @@ -0,0 +1,133 @@ +import type { + Event, + ICachingModuleService, + ICachingStrategy, + ModuleJoinerConfig, +} from "@medusajs/framework/types" +import { + type GraphQLSchema, + Modules, + toCamelCase, + upperCaseFirst, +} from "@medusajs/framework/utils" +import { type CachingModuleService } from "@services" +import type { InjectedDependencies } from "@types" +import stringify from "fast-json-stable-stringify" +import { CacheInvalidationParser, EntityReference } from "./parser" + +export class DefaultCacheStrategy implements ICachingStrategy { + #cacheInvalidationParser: CacheInvalidationParser + #cacheModule: ICachingModuleService + #container: InjectedDependencies + #hasher: (data: string) => string + + constructor( + container: InjectedDependencies, + cacheModule: CachingModuleService + ) { + this.#cacheModule = cacheModule + this.#container = container + this.#hasher = container.hasher + } + + objectHash(input: any): string { + const str = stringify(input) + return this.#hasher(str) + } + + async onApplicationStart( + schema: GraphQLSchema, + joinerConfigs: ModuleJoinerConfig[] + ) { + this.#cacheInvalidationParser = new CacheInvalidationParser( + schema, + joinerConfigs + ) + + const eventBus = this.#container[Modules.EVENT_BUS] + + const handleEvent = async (data: Event) => { + try { + // We dont have to await anything here and the rest can be done in the background + return + } finally { + const eventName = data.name + const operation = eventName.split(".").pop() as + | "created" + | "updated" + | "deleted" + const entityType = eventName.split(".").slice(-2).shift()! + + const eventData = data.data as + | { id: string | string[] } + | { id: string | string[] }[] + + const normalizedEventData = Array.isArray(eventData) + ? eventData + : [eventData] + + const tags: string[] = [] + for (const item of normalizedEventData) { + const ids = Array.isArray(item.id) ? 
item.id : [item.id] + + for (const id of ids) { + const entityReference: EntityReference = { + type: upperCaseFirst(toCamelCase(entityType)), + id, + } + + const tags_ = await this.computeTags(item, { + entities: [entityReference], + operation, + }) + tags.push(...tags_) + } + } + + void this.#cacheModule.clear({ + tags, + options: { autoInvalidate: true }, + }) + } + } + + eventBus.subscribe("*", handleEvent) + eventBus.addInterceptor?.(handleEvent) + } + + async computeKey(input: object) { + return this.objectHash(input) + } + + async computeTags( + input: object, + options?: { + entities?: EntityReference[] + operation?: "created" | "updated" | "deleted" + } + ): Promise { + // Parse the input object to identify entities + const entities_ = + options?.entities || + this.#cacheInvalidationParser.parseObjectForEntities(input) + + if (entities_.length === 0) { + return [] + } + + // Build invalidation events to get comprehensive cache keys + const events = this.#cacheInvalidationParser.buildInvalidationEvents( + entities_, + options?.operation + ) + + // Collect all unique cache keys from all events as tags + const tags = new Set() + + events.forEach((event) => { + event.cacheKeys.forEach((key) => tags.add(key)) + }) + + return Array.from(tags) + } +} diff --git a/packages/modules/caching/tsconfig.json b/packages/modules/caching/tsconfig.json new file mode 100644 index 0000000000..389d61c6a3 --- /dev/null +++ b/packages/modules/caching/tsconfig.json @@ -0,0 +1,10 @@ +{ + "extends": "../../../_tsconfig.base.json", + "compilerOptions": { + "paths": { + "@services": ["./src/services"], + "@types": ["./src/types"], + "@utils": ["./src/utils"] + } + } +} diff --git a/packages/modules/event-bus-local/src/services/event-bus-local.ts b/packages/modules/event-bus-local/src/services/event-bus-local.ts index a4aeb2f56c..75364b9fa6 100644 --- a/packages/modules/event-bus-local/src/services/event-bus-local.ts +++ 
b/packages/modules/event-bus-local/src/services/event-bus-local.ts @@ -58,8 +58,9 @@ export default class LocalEventBusService extends AbstractEventBusModuleService const eventListenersCount = this.eventEmitter_.listenerCount( eventData.name ) + const startSubscribersCount = this.eventEmitter_.listenerCount("*") - if (eventListenersCount === 0) { + if (eventListenersCount === 0 && startSubscribersCount === 0) { continue } @@ -84,6 +85,7 @@ export default class LocalEventBusService extends AbstractEventBusModuleService private async groupOrEmitEvent(eventData: Message) { const { options, ...eventBody } = eventData const eventGroupId = eventBody.metadata?.eventGroupId + const hasStarSubscriber = this.eventEmitter_.listenerCount("*") > 0 if (eventGroupId) { await this.groupEvent(eventGroupId, eventData) @@ -91,9 +93,15 @@ export default class LocalEventBusService extends AbstractEventBusModuleService const options_ = eventData.options as { delay: number } const delay = (ms?: number) => (ms ? setTimeout(ms) : Promise.resolve()) - delay(options_?.delay).then(() => + delay(options_?.delay).then(async () => { + // Call interceptors before emitting + void this.callInterceptors(eventData, { isGrouped: false }) + this.eventEmitter_.emit(eventData.name, eventBody) - ) + if (hasStarSubscriber) { + this.eventEmitter_.emit("*", eventBody) + } + }) } } @@ -112,6 +120,7 @@ export default class LocalEventBusService extends AbstractEventBusModuleService async releaseGroupedEvents(eventGroupId: string) { let groupedEvents = this.groupedEventsMap_.get(eventGroupId) || [] groupedEvents = JSON.parse(JSON.stringify(groupedEvents)) + const hasStarSubscriber = this.eventEmitter_.listenerCount("*") > 0 for (const event of groupedEvents) { const { options, ...eventBody } = event @@ -119,9 +128,15 @@ export default class LocalEventBusService extends AbstractEventBusModuleService const options_ = options as { delay: number } const delay = (ms?: number) => (ms ? 
setTimeout(ms) : Promise.resolve()) - delay(options_?.delay).then(() => + delay(options_?.delay).then(async () => { + // Call interceptors before emitting grouped events + void this.callInterceptors(event, { isGrouped: true, eventGroupId }) + this.eventEmitter_.emit(event.name, eventBody) - ) + if (hasStarSubscriber) { + this.eventEmitter_.emit("*", eventBody) + } + }) } await this.clearGroupedEvents(eventGroupId) diff --git a/packages/modules/event-bus-redis/src/services/event-bus-redis.ts b/packages/modules/event-bus-redis/src/services/event-bus-redis.ts index 958bec654b..98397b72c4 100644 --- a/packages/modules/event-bus-redis/src/services/event-bus-redis.ts +++ b/packages/modules/event-bus-redis/src/services/event-bus-redis.ts @@ -158,6 +158,10 @@ export default class RedisEventBusService extends AbstractEventBusModuleService const promises: Promise[] = [] if (eventsToEmit.length) { + eventsToEmit.map((eventData) => + this.callInterceptors(eventData, { isGrouped: false }) + ) + const emitData = this.buildEvents(eventsToEmit, options) promises.push(this.queue_.addBulk(emitData)) @@ -213,6 +217,21 @@ export default class RedisEventBusService extends AbstractEventBusModuleService async releaseGroupedEvents(eventGroupId: string) { const groupedEvents = await this.getGroupedEvents(eventGroupId) + // Call interceptors before emitting grouped events + // Extract the original messages from the job data structure + groupedEvents.map((jobData) => { + // Reconstruct the message from the job data + const message = { + name: jobData.name, + data: jobData.data, + metadata: jobData.data.metadata, + } + this.callInterceptors(message as any, { + isGrouped: true, + eventGroupId, + }) + }) + await this.queue_.addBulk(groupedEvents) await this.clearGroupedEvents(eventGroupId) diff --git a/packages/modules/locking/mikro-orm.config.dev.ts b/packages/modules/locking/mikro-orm.config.dev.ts deleted file mode 100644 index e30d36a644..0000000000 --- 
a/packages/modules/locking/mikro-orm.config.dev.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { defineMikroOrmCliConfig, Modules } from "@medusajs/framework/utils" -import * as entities from "./src/models" - -export default defineMikroOrmCliConfig(Modules.LOCKING, { - entities: Object.values(entities), -}) diff --git a/packages/modules/providers/caching-redis/package.json b/packages/modules/providers/caching-redis/package.json new file mode 100644 index 0000000000..580acd1e11 --- /dev/null +++ b/packages/modules/providers/caching-redis/package.json @@ -0,0 +1,49 @@ +{ + "name": "@medusajs/caching-redis", + "version": "2.10.3", + "description": "Redis Caching for Medusa", + "main": "dist/index.js", + "repository": { + "type": "git", + "url": "https://github.com/medusajs/medusa", + "directory": "packages/modules/providers/caching-redis" + }, + "files": [ + "dist", + "!dist/**/__tests__", + "!dist/**/__mocks__", + "!dist/**/__fixtures__" + ], + "engines": { + "node": ">=20" + }, + "author": "Medusa", + "license": "MIT", + "devDependencies": { + "@medusajs/framework": "2.10.3", + "@swc/core": "^1.7.28", + "@swc/jest": "^0.2.36", + "jest": "^29.7.0", + "rimraf": "^5.0.1", + "typescript": "^5.6.2" + }, + "peerDependencies": { + "@medusajs/framework": "2.10.3" + }, + "dependencies": { + "ioredis": "^5.4.1", + "xxhash-wasm": "^1.1.0" + }, + "scripts": { + "watch": "tsc --build --watch", + "watch:test": "tsc --build tsconfig.spec.json --watch", + "resolve:aliases": "tsc --showConfig -p tsconfig.json > tsconfig.resolved.json && tsc-alias -p tsconfig.resolved.json && rimraf tsconfig.resolved.json", + "build": "rimraf dist && tsc --build && npm run resolve:aliases", + "test": "jest --passWithNoTests src", + "test:integration": "jest --forceExit --passWithNoTests" + }, + "keywords": [ + "medusa-providers", + "medusa-providers-cache" + ] +} diff --git a/packages/modules/providers/caching-redis/src/index.ts b/packages/modules/providers/caching-redis/src/index.ts new file mode 100644 
index 0000000000..7618a64c71 --- /dev/null +++ b/packages/modules/providers/caching-redis/src/index.ts @@ -0,0 +1,12 @@ +import { ModuleProvider, Modules } from "@medusajs/framework/utils" +import Connection from "./loaders/connection" +import Hash from "./loaders/hash" +import { RedisCachingProvider } from "./services/redis-cache" + +const services = [RedisCachingProvider] +const loaders = [Connection, Hash] + +export default ModuleProvider(Modules.CACHING, { + services, + loaders, +}) diff --git a/packages/modules/providers/caching-redis/src/loaders/connection.ts b/packages/modules/providers/caching-redis/src/loaders/connection.ts new file mode 100644 index 0000000000..5e1590a4db --- /dev/null +++ b/packages/modules/providers/caching-redis/src/loaders/connection.ts @@ -0,0 +1,66 @@ +import type { + InternalModuleDeclaration, + LoaderOptions, + ModulesSdkTypes, +} from "@medusajs/framework/types" +import { RedisCacheModuleOptions } from "@types" +import Redis from "ioredis" + +export default async ( + { + container, + logger, + options, + }: LoaderOptions< + ( + | ModulesSdkTypes.ModuleServiceInitializeOptions + | ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions + ) & { logger?: any } + >, + moduleDeclaration?: InternalModuleDeclaration +): Promise => { + const logger_ = logger || console + + const moduleOptions = (options ?? + moduleDeclaration?.options ?? 
+ {}) as RedisCacheModuleOptions & { + redisUrl?: string + } + + if (!moduleOptions.redisUrl) { + throw new Error("[caching-redis] redisUrl is required") + } + + let redisClient: Redis + + try { + redisClient = new Redis(moduleOptions.redisUrl!, { + connectTimeout: 10000, + lazyConnect: true, + retryDelayOnFailover: 100, + connectionName: "medusa-cache-redis", + ...moduleOptions, + }) + + // Test connection + await redisClient.ping() + logger_.info("Redis cache connection established successfully") + } catch (error) { + logger_.error(`Failed to connect to Redis cache: ${error.message}`) + redisClient = new Redis(moduleOptions.redisUrl!, { + connectTimeout: 10000, + lazyConnect: true, + retryDelayOnFailover: 100, + ...moduleOptions, + }) + } + + container.register({ + redisClient: { + resolve: () => redisClient, + }, + prefix: { + resolve: () => moduleOptions.prefix ?? "mc:", + }, + }) +} diff --git a/packages/modules/providers/caching-redis/src/loaders/hash.ts b/packages/modules/providers/caching-redis/src/loaders/hash.ts new file mode 100644 index 0000000000..194f142c97 --- /dev/null +++ b/packages/modules/providers/caching-redis/src/loaders/hash.ts @@ -0,0 +1,8 @@ +import { asValue } from "awilix" + +export default async ({ container }) => { + const xxhashhWasm = await import("xxhash-wasm") + const { h32ToString } = await xxhashhWasm.default() + + container.register("hasher", asValue(h32ToString)) +} diff --git a/packages/modules/providers/caching-redis/src/services/redis-cache.ts b/packages/modules/providers/caching-redis/src/services/redis-cache.ts new file mode 100644 index 0000000000..3c0959dbf0 --- /dev/null +++ b/packages/modules/providers/caching-redis/src/services/redis-cache.ts @@ -0,0 +1,683 @@ +import { RedisCacheModuleOptions } from "@types" +import { Redis } from "ioredis" +import { createGunzip, createGzip } from "zlib" + +export class RedisCachingProvider { + static identifier = "cache-redis" + + protected redisClient: Redis + protected 
keyNamePrefix: string + protected defaultTTL: number + protected compressionThreshold: number + protected hasher: (key: string) => string + + constructor( + { + redisClient, + prefix, + hasher, + }: { redisClient: Redis; prefix: string; hasher: (key: string) => string }, + options?: RedisCacheModuleOptions + ) { + this.redisClient = redisClient + this.keyNamePrefix = prefix + this.defaultTTL = options?.ttl ?? 3600 // 1 hour default + this.compressionThreshold = options?.compressionThreshold ?? 2048 // 2KB default + this.hasher = hasher + } + + #getKeyName(key: string): string { + return `${this.keyNamePrefix}${key}` + } + + #getTagKey( + tag: string, + { isHashed = false }: { isHashed?: boolean } = {} + ): string { + return `${this.keyNamePrefix}tag:${isHashed ? tag : this.hasher(tag)}` + } + + #getTagsKey(key: string): string { + return `${this.keyNamePrefix}tags:${key}` + } + + #getTagDictionaryKey(): string { + return `${this.keyNamePrefix}tag:dictionary` + } + + #getTagNextIdKey(): string { + return `${this.keyNamePrefix}tag:next_id` + } + + #getTagRefCountKey(): string { + return `${this.keyNamePrefix}tag:refs` + } + + #getTagReverseDictionaryKey(): string { + return `${this.keyNamePrefix}tag:reverse_dict` + } + + async #internTags(tags: string[]): Promise { + const pipeline = this.redisClient.pipeline() + const dictionaryKey = this.#getTagDictionaryKey() + + const hashedTags = tags.map((tag) => this.hasher(tag)) + + // Get existing tag IDs + hashedTags.forEach((tag) => { + pipeline.hget(dictionaryKey, tag) + }) + + const results = await pipeline.exec() + const tagIds: number[] = [] + const newTags: string[] = [] + + for (let i = 0; i < hashedTags.length; i++) { + const result = results?.[i] + if (result && result[1]) { + tagIds[i] = parseInt(result[1] as string) + } else { + const hashedTag = hashedTags[i] + newTags.push(hashedTag) + tagIds[i] = -1 // Placeholder for new tags + } + } + + // Create IDs for new tags + if (newTags.length) { + const nextIdKey = 
this.#getTagNextIdKey() + const reverseDictKey = this.#getTagReverseDictionaryKey() + const refCountKey = this.#getTagRefCountKey() + const startId = await this.redisClient.incrby(nextIdKey, newTags.length) + + const batchPipeline = this.redisClient.pipeline() + newTags.forEach((tag, index) => { + const newId = startId - newTags.length + index + 1 + + // Store in both forward and reverse dictionaries + batchPipeline.hset(dictionaryKey, tag, newId.toString()) + batchPipeline.hset(reverseDictKey, newId.toString(), tag) + + // Update the tagIds array + const originalIndex = hashedTags.indexOf(tag) + tagIds[originalIndex] = newId + }) + + // Add reference count increments to the same pipeline + tagIds.forEach((id) => { + if (id !== -1) { + batchPipeline.hincrby(refCountKey, id.toString(), 1) + } + }) + + await batchPipeline.exec() + } else { + // Only increment reference count for existing tags + const refCountKey = this.#getTagRefCountKey() + const refPipeline = this.redisClient.pipeline() + tagIds.forEach((id) => { + refPipeline.hincrby(refCountKey, id.toString(), 1) + }) + await refPipeline.exec() + } + + return tagIds + } + + async #resolveTagIds(tagIds: number[]): Promise { + if (tagIds.length === 0) return [] + + const reverseDictKey = this.#getTagReverseDictionaryKey() + const pipeline = this.redisClient.pipeline() + + tagIds.forEach((id) => { + pipeline.hget(reverseDictKey, id.toString()) + }) + + const results = await pipeline.exec() + return results?.map((result) => result?.[1] as string).filter(Boolean) || [] + } + + async #decrementTagRefs(tagIds: number[]): Promise { + if (tagIds.length === 0) return + + const refCountKey = this.#getTagRefCountKey() + const dictionaryKey = this.#getTagDictionaryKey() + + // Decrement reference counts and collect tags with zero refs + const pipeline = this.redisClient.pipeline() + tagIds.forEach((id) => { + pipeline.hincrby(refCountKey, id.toString(), -1) + }) + + const results = await pipeline.exec() + const tagsToCleanup: 
number[] = [] + + // Find tags that now have zero references + results?.forEach((result, index) => { + if (result && result[1] === 0) { + tagsToCleanup.push(tagIds[index]) + } + }) + + // Clean up tags with zero references + if (tagsToCleanup.length) { + const cleanupPipeline = this.redisClient.pipeline() + const reverseDictKey = this.#getTagReverseDictionaryKey() + + // Get tag names before deleting them + const tagNames = await this.#resolveTagIds(tagsToCleanup) + + tagsToCleanup.forEach((id, index) => { + const idStr = id.toString() + + // Remove from reference count hash + cleanupPipeline.hdel(refCountKey, idStr) + // Remove from reverse dictionary + cleanupPipeline.hdel(reverseDictKey, idStr) + // Remove from forward dictionary + if (tagNames[index]) { + cleanupPipeline.hdel(dictionaryKey, tagNames[index]) + } + }) + + await cleanupPipeline.exec() + } + } + + async #compressData(data: string): Promise { + if (data.length <= this.compressionThreshold) { + const buffer = Buffer.from(data, "utf8") + const prefix = Buffer.from([0]) // 0 = uncompressed + return Buffer.concat([prefix, buffer]) + } + + return new Promise((resolve, reject) => { + const chunks: Buffer[] = [] + const gzip = createGzip() + + gzip.on("data", (chunk) => chunks.push(chunk)) + gzip.on("end", () => { + const compressedBuffer = Buffer.concat(chunks) + const prefix = Buffer.from([1]) // 1 = compressed + resolve(Buffer.concat([prefix, compressedBuffer])) + }) + gzip.on("error", (error) => { + const buffer = Buffer.from(data, "utf8") + const prefix = Buffer.from([0]) + resolve(Buffer.concat([prefix, buffer])) + }) + + gzip.write(data, "utf8") + gzip.end() + }) + } + + async #decompressData(buffer: Buffer): Promise { + if (buffer.length === 0) { + return "" + } + + const formatByte = buffer[0] + const dataBuffer = buffer.subarray(1) + + if (formatByte === 0) { + // Uncompressed + return dataBuffer.toString("utf8") + } + + if (formatByte === 1) { + // Compressed with gzip + return new 
Promise((resolve, reject) => { + const chunks: Buffer[] = [] + const gunzip = createGunzip() + + gunzip.on("data", (chunk) => chunks.push(chunk)) + gunzip.on("end", () => { + const decompressed = Buffer.concat(chunks).toString("utf8") + resolve(decompressed) + }) + gunzip.on("error", (error) => { + // Fallback: return as-is if decompression fails + resolve(dataBuffer.toString("utf8")) + }) + + gunzip.write(dataBuffer) + gunzip.end() + }) + } + + // Unknown format, return as UTF-8 + return buffer.toString("utf8") + } + + async get({ key, tags }: { key?: string; tags?: string[] }): Promise { + if (key) { + const keyName = this.#getKeyName(key) + const buffer = await this.redisClient.hgetBuffer(keyName, "data") + if (!buffer) { + return null + } + + const finalData = await this.#decompressData(buffer) + return JSON.parse(finalData) + } + + if (tags?.length) { + // Get all keys associated with the tags + const pipeline = this.redisClient.pipeline() + tags.forEach((tag) => { + const tagKey = this.#getTagKey(tag) + pipeline.smembers(tagKey) + }) + + const tagResults = await pipeline.exec() + const allKeys = new Set() + + tagResults?.forEach((result, index) => { + if (result && result[1]) { + ;(result[1] as string[]).forEach((key) => allKeys.add(key)) + } + }) + + if (allKeys.size === 0) { + return [] + } + + // Get all hash data for the keys + const valuePipeline = this.redisClient.pipeline() + Array.from(allKeys).forEach((key) => { + valuePipeline.hgetBuffer(key, "data") + }) + + const valueResults = await valuePipeline.exec() + const results: any[] = [] + + const decompressionPromises = (valueResults || []).map(async (result) => { + if (result && result[1]) { + const buffer = result[1] as Buffer + try { + const finalData = await this.#decompressData(buffer) + return JSON.parse(finalData) + } catch (e) { + // If JSON parsing fails, skip this entry (corrupted data) + console.warn(`Skipping corrupted cache entry: ${e.message}`) + return null + } + } + return null + }) + + 
const decompressionResults = await Promise.all(decompressionPromises) + results.push(...decompressionResults.filter(Boolean)) + + return results + } + + return null + } + + async set({ + key, + data, + ttl, + tags, + options, + }: { + key: string + data: object + ttl?: number + tags?: string[] + options?: { + autoInvalidate?: boolean + } + }): Promise { + const keyName = this.#getKeyName(key) + const serializedData = JSON.stringify(data) + const effectiveTTL = ttl ?? this.defaultTTL + + const finalData = await this.#compressData(serializedData) + + let tagIds: number[] = [] + if (tags?.length) { + tagIds = await this.#internTags(tags) + } + + const setPipeline = this.redisClient.pipeline() + + // Main data with conditional operations + setPipeline.hsetnx(keyName, "data", finalData) + if (options && Object.keys(options).length) { + setPipeline.hset(keyName, "options", JSON.stringify(options)) + } + if (effectiveTTL) { + setPipeline.expire(keyName, effectiveTTL) + } + + // Store tag IDs if present + if (tags?.length && tagIds.length) { + const tagsKey = this.#getTagsKey(key) + const buffer = Buffer.alloc(tagIds.length * 4) + tagIds.forEach((id, index) => { + buffer.writeUInt32LE(id, index * 4) + }) + + if (effectiveTTL) { + setPipeline.set(tagsKey, buffer, "EX", effectiveTTL + 60, "NX") + } else { + setPipeline.setnx(tagsKey, buffer) + } + + // Add tag operations to the same pipeline + tags.forEach((tag) => { + const tagKey = this.#getTagKey(tag) + setPipeline.sadd(tagKey, keyName) + if (effectiveTTL) { + setPipeline.expire(tagKey, effectiveTTL + 60) + } + }) + } + + await setPipeline.exec() + } + + async clear({ + key, + tags, + options, + }: { + key?: string + tags?: string[] + options?: { + autoInvalidate?: boolean + } + }): Promise { + if (key) { + const keyName = this.#getKeyName(key) + const tagsKey = this.#getTagsKey(key) + + const clearPipeline = this.redisClient.pipeline() + + // Get tags for cleanup and delete main key in same pipeline + 
clearPipeline.getBuffer(tagsKey) + clearPipeline.unlink(keyName) + + const results = await clearPipeline.exec() + const tagsBuffer = results?.[0]?.[1] as Buffer + + if (tagsBuffer?.length) { + try { + // Binary format: array of 32-bit integers + const tagIds: number[] = [] + for (let i = 0; i < tagsBuffer.length; i += 4) { + tagIds.push(tagsBuffer.readUInt32LE(i)) + } + + if (tagIds.length) { + const entryTags = await this.#resolveTagIds(tagIds) + + const tagCleanupPipeline = this.redisClient.pipeline() + entryTags.forEach((tag) => { + const tagKey = this.#getTagKey(tag, { isHashed: true }) + tagCleanupPipeline.srem(tagKey, keyName) + }) + tagCleanupPipeline.unlink(tagsKey) + await tagCleanupPipeline.exec() + + // Decrement reference counts and cleanup unused tags + await this.#decrementTagRefs(tagIds) + } + } catch (e) { + // noop - corrupted tag data, skip cleanup + } + } + + return + } + + if (tags?.length) { + // Handle wildcard tag to clear all cache data + if (tags.includes("*")) { + await this.flush() + return + } + + // Get all keys associated with the tags + const pipeline = this.redisClient.pipeline() + tags.forEach((tag) => { + const tagKey = this.#getTagKey(tag) + pipeline.smembers(tagKey) + }) + + const tagResults = await pipeline.exec() + + const allKeys = new Set() + + tagResults?.forEach((result) => { + if (result && result[1]) { + ;(result[1] as string[]).forEach((key) => allKeys.add(key)) + } + }) + + if (allKeys.size) { + // If no options provided (user explicit call), clear everything + if (!options) { + const deletePipeline = this.redisClient.pipeline() + + // Delete main keys and options + Array.from(allKeys).forEach((key) => { + deletePipeline.unlink(key) + }) + + // Clean up tag references for each key + const tagDataPromises = Array.from(allKeys).map(async (key) => { + const keyWithoutPrefix = key.replace(this.keyNamePrefix, "") + const tagsKey = this.#getTagsKey(keyWithoutPrefix) + const tagsData = await this.redisClient.getBuffer(tagsKey) 
+ return { key, tagsKey, tagsData } + }) + + const tagResults = await Promise.all(tagDataPromises) + + // Build single pipeline for all tag cleanup operations + const tagCleanupPipeline = this.redisClient.pipeline() + const cleanupPromises = tagResults.map( + async ({ key, tagsKey, tagsData }) => { + if (tagsData) { + try { + // Binary format: array of 32-bit integers + const tagIds: number[] = [] + for (let i = 0; i < tagsData.length; i += 4) { + tagIds.push(tagsData.readUInt32LE(i)) + } + + if (tagIds.length) { + const entryTags = await this.#resolveTagIds(tagIds) + entryTags.forEach((tag) => { + const tagKey = this.#getTagKey(tag, { isHashed: true }) + tagCleanupPipeline.srem(tagKey, key) + }) + tagCleanupPipeline.unlink(tagsKey) + + // Decrement reference counts and cleanup unused tags + await this.#decrementTagRefs(tagIds) + } + } catch (e) { + // noop + } + } + } + ) + + await Promise.all(cleanupPromises) + await tagCleanupPipeline.exec() + await deletePipeline.exec() + + // Clean up empty tag sets + const allTagKeys = await this.redisClient.keys( + `${this.keyNamePrefix}tag:*` + ) + if (allTagKeys.length) { + const cardinalityPipeline = this.redisClient.pipeline() + allTagKeys.forEach((tagKey) => { + cardinalityPipeline.scard(tagKey) + }) + + const cardinalityResults = await cardinalityPipeline.exec() + + // Delete empty tag keys + const emptyTagPipeline = this.redisClient.pipeline() + cardinalityResults?.forEach((result, index) => { + if (result && result[1] === 0) { + emptyTagPipeline.unlink(allTagKeys[index]) + } + }) + + await emptyTagPipeline.exec() + } + + return + } + + // If autoInvalidate is true (strategy call), only clear entries with autoInvalidate=true (default) + if (options.autoInvalidate === true) { + const optionsPipeline = this.redisClient.pipeline() + + Array.from(allKeys).forEach((key) => { + optionsPipeline.hget(key, "options") + }) + + const optionsResults = await optionsPipeline.exec() + const keysToDelete: string[] = [] + + 
Array.from(allKeys).forEach((key, index) => { + const optionsResult = optionsResults?.[index] + + if (optionsResult && optionsResult[1]) { + try { + const entryOptions = JSON.parse(optionsResult[1] as string) + + // Delete if entry has autoInvalidate=true or no setting (default true) + const shouldAutoInvalidate = entryOptions.autoInvalidate ?? true + + if (shouldAutoInvalidate) { + keysToDelete.push(key) + } + } catch (e) { + // If can't parse options, assume it's safe to delete (default true) + keysToDelete.push(key) + } + } else { + // No options stored, default to true + keysToDelete.push(key) + } + }) + + if (keysToDelete.length) { + const deletePipeline = this.redisClient.pipeline() + + keysToDelete.forEach((key) => { + deletePipeline.unlink(key) + }) + + // Clean up tag references for each key to delete + const tagDataPromises = keysToDelete.map(async (key) => { + const keyWithoutPrefix = key.replace(this.keyNamePrefix, "") + const tagsKey = this.#getTagsKey(keyWithoutPrefix) + const tagsData = await this.redisClient.getBuffer(tagsKey) + return { key, tagsKey, tagsData } + }) + + // Wait for all tag data fetches + const tagResults = await Promise.all(tagDataPromises) + + // Build single pipeline for all tag cleanup operations + const tagCleanupPipeline = this.redisClient.pipeline() + + const cleanupPromises = tagResults.map( + async ({ key, tagsKey, tagsData }) => { + if (tagsData) { + try { + // Binary format: array of 32-bit integers + const tagIds: number[] = [] + for (let i = 0; i < tagsData.length; i += 4) { + tagIds.push(tagsData.readUInt32LE(i)) + } + + if (tagIds.length) { + const entryTags = await this.#resolveTagIds(tagIds) + entryTags.forEach((tag) => { + const tagKey = this.#getTagKey(tag, { isHashed: true }) + tagCleanupPipeline.srem(tagKey, key) + }) + tagCleanupPipeline.unlink(tagsKey) // Delete the tags key + + // Decrement reference counts and cleanup unused tags + await this.#decrementTagRefs(tagIds) + } + } catch (e) { + // noop + } + } + 
} + ) + + await Promise.all(cleanupPromises) + await tagCleanupPipeline.exec() + + await deletePipeline.exec() + + // Clean up empty tag sets + const allTagKeys = await this.redisClient.keys( + `${this.keyNamePrefix}tag:*` + ) + if (allTagKeys.length) { + const cleanupPipeline = this.redisClient.pipeline() + + allTagKeys.forEach((tagKey) => { + cleanupPipeline.scard(tagKey) + }) + + const cardinalityResults = await cleanupPipeline.exec() + + // Delete tag keys that are now empty + const emptyTagDeletePipeline = this.redisClient.pipeline() + cardinalityResults?.forEach((result, index) => { + if (result && result[1] === 0) { + emptyTagDeletePipeline.unlink(allTagKeys[index]) + } + }) + + await emptyTagDeletePipeline.exec() + } + + return + } + } + } + } + } + + async flush(): Promise { + // Use SCAN to find ALL keys with our prefix and delete them + // This includes main cache keys, tag keys (tag:*), and tags keys (tags:*) + const pattern = `${this.keyNamePrefix}*` + let cursor = "0" + + do { + const result = await this.redisClient.scan( + cursor, + "MATCH", + pattern, + "COUNT", + 1000 + ) + cursor = result[0] + const keys = result[1] + + if (keys.length) { + await this.redisClient.unlink(...keys) + } + } while (cursor !== "0") + } +} diff --git a/packages/modules/providers/caching-redis/src/types/index.ts b/packages/modules/providers/caching-redis/src/types/index.ts new file mode 100644 index 0000000000..c22fcf493a --- /dev/null +++ b/packages/modules/providers/caching-redis/src/types/index.ts @@ -0,0 +1,26 @@ +export interface RedisCacheModuleOptions { + /** + * TTL in milliseconds + */ + ttl?: number + /** + * Connection timeout in milliseconds + */ + connectTimeout?: number + /** + * Lazyload connections + */ + lazyConnect?: boolean + /** + * Connection retries + */ + retryDelayOnFailover?: number + /** + * Key prefix for all cache keys + */ + prefix?: string + /** + * Minimum size in bytes to compress data (default: 1024) + */ + compressionThreshold?: number +} 
diff --git a/packages/modules/providers/caching-redis/tsconfig.json b/packages/modules/providers/caching-redis/tsconfig.json new file mode 100644 index 0000000000..90f3a70b38 --- /dev/null +++ b/packages/modules/providers/caching-redis/tsconfig.json @@ -0,0 +1,12 @@ +{ + "extends": "../../../../_tsconfig.base.json", + "compilerOptions": { + "paths": { + "@models": ["./src/models"], + "@services": ["./src/services"], + "@repositories": ["./src/repositories"], + "@types": ["./src/types"], + "@utils": ["./src/utils"] + } + } +} diff --git a/yarn.lock b/yarn.lock index c8d2161f09..a1691aeb47 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6217,6 +6217,44 @@ __metadata: languageName: unknown linkType: soft +"@medusajs/caching-redis@2.10.3, @medusajs/caching-redis@workspace:packages/modules/providers/caching-redis": + version: 0.0.0-use.local + resolution: "@medusajs/caching-redis@workspace:packages/modules/providers/caching-redis" + dependencies: + "@medusajs/framework": 2.10.3 + "@swc/core": ^1.7.28 + "@swc/jest": ^0.2.36 + ioredis: ^5.4.1 + jest: ^29.7.0 + rimraf: ^5.0.1 + typescript: ^5.6.2 + xxhash-wasm: ^1.1.0 + peerDependencies: + "@medusajs/framework": 2.10.3 + languageName: unknown + linkType: soft + +"@medusajs/caching@2.10.3, @medusajs/caching@workspace:*, @medusajs/caching@workspace:packages/modules/caching": + version: 0.0.0-use.local + resolution: "@medusajs/caching@workspace:packages/modules/caching" + dependencies: + "@medusajs/framework": 2.10.3 + "@medusajs/test-utils": 2.10.3 + "@swc/core": ^1.7.28 + "@swc/jest": ^0.2.36 + fast-json-stable-stringify: ^2.1.0 + jest: ^29.7.0 + node-cache: ^5.1.2 + rimraf: ^3.0.2 + tsc-alias: ^1.8.6 + typescript: ^5.6.2 + xxhash-wasm: ^1.1.0 + peerDependencies: + "@medusajs/framework": 2.10.3 + awilix: ^8.0.1 + languageName: unknown + linkType: soft + "@medusajs/cart@2.10.3, @medusajs/cart@workspace:packages/modules/cart": version: 0.0.0-use.local resolution: "@medusajs/cart@workspace:packages/modules/cart" @@ -6823,6 +6861,8 
@@ __metadata: "@medusajs/auth-google": 2.10.3 "@medusajs/cache-inmemory": 2.10.3 "@medusajs/cache-redis": 2.10.3 + "@medusajs/caching": 2.10.3 + "@medusajs/caching-redis": 2.10.3 "@medusajs/cart": 2.10.3 "@medusajs/core-flows": 2.10.3 "@medusajs/currency": 2.10.3 @@ -18598,6 +18638,13 @@ __metadata: languageName: node linkType: hard +"clone@npm:2.x": + version: 2.1.2 + resolution: "clone@npm:2.1.2" + checksum: ed0601cd0b1606bc7d82ee7175b97e68d1dd9b91fd1250a3617b38d34a095f8ee0431d40a1a611122dcccb4f93295b4fdb94942aa763392b5fe44effa50c2d5e + languageName: node + linkType: hard + "clone@npm:^1.0.2": version: 1.0.4 resolution: "clone@npm:1.0.4" @@ -24172,6 +24219,7 @@ __metadata: "@medusajs/api-key": "workspace:^" "@medusajs/auth": "workspace:*" "@medusajs/cache-inmemory": "workspace:*" + "@medusajs/caching": "workspace:*" "@medusajs/core-flows": "workspace:^" "@medusajs/currency": "workspace:^" "@medusajs/customer": "workspace:^" @@ -27809,6 +27857,15 @@ __metadata: languageName: node linkType: hard +"node-cache@npm:^5.1.2": + version: 5.1.2 + resolution: "node-cache@npm:5.1.2" + dependencies: + clone: 2.x + checksum: 2f91907510a1276415ae5898269d0765934d5a4f3682c8b1b19964694a9b841c8bd791e1a125d1f89050f412e1da5dd982179d714252b3a7223abb05b8cb24d5 + languageName: node + linkType: hard + "node-domexception@npm:^1.0.0": version: 1.0.0 resolution: "node-domexception@npm:1.0.0" @@ -36325,6 +36382,13 @@ __metadata: languageName: node linkType: hard +"xxhash-wasm@npm:^1.1.0": + version: 1.1.0 + resolution: "xxhash-wasm@npm:1.1.0" + checksum: 35aa152fc7d775ae13364fe4fb20ebd89c6ac1f56cdb6060a6d2f1ed68d15180694467e63a4adb3d11936a4798ccd75a540979070e70d9b911e9981bbdd9cea6 + languageName: node + linkType: hard + "y18n@npm:^4.0.0": version: 4.0.3 resolution: "y18n@npm:4.0.3"