From 1c3ef13371aab0b19d0493b84063b50ff8e4116d Mon Sep 17 00:00:00 2001 From: Stevche Radevski Date: Tue, 9 Jul 2024 11:37:13 +0200 Subject: [PATCH] feat: Add necessary middlewares for tax inclusive pricing (#7827) We are adding tax inclusive pricing calculation when listing products. Two things to keep in mind: - `region_id` will be required if you request calculated prices. - We won't accept `currency_code` anymore, as that will come from the region info (since ultimately a cart and its currency are tied to a region) REF CORE-2376 DEPENDS ON #8003 --- integration-tests/helpers/fixtures.ts | 2 +- .../__tests__/product/store/product.spec.ts | 334 +++++++++++++++++- .../core/types/src/http/pricing/common.ts | 105 ++++++ .../types/src/http/product/admin/payloads.ts | 2 + .../core/types/src/http/product/common.ts | 4 +- packages/core/types/src/product/common.ts | 8 +- packages/core/types/src/tax/common.ts | 2 +- packages/core/utils/src/totals/index.ts | 1 + packages/core/utils/src/totals/tax/index.ts | 21 ++ .../src/api/store/products/[id]/route.ts | 13 +- .../medusa/src/api/store/products/helpers.ts | 99 +++++- .../src/api/store/products/middlewares.ts | 9 + .../medusa/src/api/store/products/route.ts | 6 +- .../src/api/store/products/validators.ts | 18 +- .../common/clear-filters-by-key.ts | 12 + .../src/api/utils/middlewares/common/index.ts | 1 + .../api/utils/middlewares/products/index.ts | 3 +- .../products/normalize-data-for-context.ts | 74 ++++ .../products/set-pricing-context.ts | 93 +---- .../middlewares/products/set-tax-context.ts | 73 ++++ .../pricing-module/calculate-price.spec.ts | 33 ++ .../pricing/src/services/pricing-module.ts | 32 +- 22 files changed, 824 insertions(+), 121 deletions(-) create mode 100644 packages/core/types/src/http/pricing/common.ts create mode 100644 packages/medusa/src/api/utils/middlewares/common/clear-filters-by-key.ts create mode 100644 packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts create mode 100644 packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts diff --git a/integration-tests/helpers/fixtures.ts b/integration-tests/helpers/fixtures.ts index 1f1c2ae4a8..6bd8c777c0 100644 --- a/integration-tests/helpers/fixtures.ts +++ b/integration-tests/helpers/fixtures.ts @@ -1,7 +1,7 @@ import { HttpTypes } from "@medusajs/types" export const getProductFixture = ( - overrides: Partial + overrides: Partial ) => ({ title: "Test fixture", description: "test-product-description", diff --git a/integration-tests/http/__tests__/product/store/product.spec.ts b/integration-tests/http/__tests__/product/store/product.spec.ts index b3ebd2ca41..22e1ebdc71 100644 --- a/integration-tests/http/__tests__/product/store/product.spec.ts +++ b/integration-tests/http/__tests__/product/store/product.spec.ts @@ -892,7 +892,7 @@ medusaIntegrationTestRunner({ expect(error.response.status).toEqual(400) expect(error.response.data).toEqual({ message: - "Missing required pricing context to calculate prices - currency_code or region_id", + "Missing required pricing context to calculate prices - region_id", type: "invalid_data", }) }) @@ -907,7 +907,7 @@ medusaIntegrationTestRunner({ ).data.region let response = await api.get( - `/store/products?fields=*variants.calculated_price¤cy_code=usd` + `/store/products?fields=*variants.calculated_price®ion_id=${region.id}` ) const expectation = expect.arrayContaining([ @@ -948,12 +948,6 @@ medusaIntegrationTestRunner({ expect(response.data.count).toEqual(3) expect(response.data.products).toEqual(expectation) - // Without calculated_price fields - response = await api.get(`/store/products?currency_code=usd`) - - expect(response.status).toEqual(200) - expect(response.data.products).toEqual(expectation) - // with only region_id response = await api.get(`/store/products?region_id=${region.id}`) @@ -1205,7 +1199,7 @@ medusaIntegrationTestRunner({ expect(error.response.status).toEqual(400) expect(error.response.data).toEqual({ message: - "Missing required pricing context to calculate prices - currency_code or region_id", + "Missing required pricing context to calculate prices - region_id", type: "invalid_data", }) }) @@ -1220,7 +1214,7 @@ medusaIntegrationTestRunner({ ).data.region let response = await api.get( - `/store/products/${product.id}?fields=*variants.calculated_price¤cy_code=usd` + `/store/products/${product.id}?fields=*variants.calculated_price®ion_id=${region.id}` ) const expectation = expect.objectContaining({ @@ -1258,14 +1252,6 @@ medusaIntegrationTestRunner({ expect(response.status).toEqual(200) expect(response.data.product).toEqual(expectation) - // Without calculated_price fields - response = await api.get( - `/store/products/${product.id}?currency_code=usd` - ) - - expect(response.status).toEqual(200) - expect(response.data.product).toEqual(expectation) - // with only region_id response = await api.get( `/store/products/${product.id}?region_id=${region.id}` @@ -1275,5 +1261,317 @@ medusaIntegrationTestRunner({ expect(response.data.product).toEqual(expectation) }) }) + + describe("Tax handling", () => { + let usRegion + let euRegion + let dkRegion + let euCart + + beforeEach(async () => { + usRegion = ( + await api.post( + "/admin/regions", + { + name: "Test Region", + currency_code: "usd", + countries: ["us"], + is_tax_inclusive: false, + automatic_taxes: false, + }, + adminHeaders + ) + ).data.region + + euRegion = ( + await api.post( + "/admin/regions", + { + name: "Test Region", + currency_code: "eur", + countries: ["it", "de"], + is_tax_inclusive: true, + automatic_taxes: true, + }, + adminHeaders + ) + ).data.region + + dkRegion = ( + await api.post( + "/admin/regions", + { + name: "Test Region", + currency_code: "dkk", + countries: ["dk"], + is_tax_inclusive: false, + automatic_taxes: true, + }, + adminHeaders + ) + ).data.region + + product1 = ( + await api.post( + "/admin/products", + getProductFixture({ + title: "test1", + status: "published", + variants: [ + { + title: "Test taxes", + prices: [ + { + amount: 45, + currency_code: "eur", + rules: { region_id: euRegion.id }, + }, + { + amount: 100, + currency_code: "usd", + rules: { region_id: usRegion.id }, + }, + { + amount: 30, + currency_code: "dkk", + rules: { region_id: dkRegion.id }, + }, + ], + }, + ], + }), + adminHeaders + ) + ).data.product + + euCart = (await api.post("/store/carts", { region_id: euRegion.id })) + .data.cart + + await api.post( + `/admin/tax-regions`, + { + country_code: "us", + default_tax_rate: { + code: "default", + rate: 5, + name: "default rate", + }, + }, + adminHeaders + ) + + await api.post( + `/admin/tax-regions`, + { + country_code: "it", + default_tax_rate: { + code: "default", + rate: 10, + name: "default rate", + }, + }, + adminHeaders + ) + + await api.post( + `/admin/tax-regions`, + { + country_code: "dk", + default_tax_rate: { + code: "default", + rate: 20, + name: "default rate", + }, + }, + adminHeaders + ) + }) + + it("should not return tax pricing if the context is not sufficient when listing products", async () => { + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price®ion_id=${usRegion.id}` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_with_tax" + ) + expect(products[0].variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_without_tax" + ) + }) + + it("should not return tax pricing if automatic taxes are off when listing products", async () => { + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price®ion_id=${usRegion.id}&country_code=us` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_with_tax" + ) + expect(products[0].variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_without_tax" + ) + }) + + it("should return prices with and without tax for a tax inclusive region when listing products", async () => { + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price®ion_id=${euRegion.id}&country_code=it` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + calculated_price: expect.objectContaining({ + currency_code: "eur", + calculated_amount: 45, + calculated_amount_with_tax: 45, + }), + }), + ]) + ) + + // TODO: Return an integer instead of a float for the pricing + expect( + products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed( + 1 + ) + ).toEqual("40.9") + }) + + it("should return prices with and without tax for a tax exclusive region when listing products", async () => { + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price®ion_id=${dkRegion.id}&country_code=dk` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + calculated_price: expect.objectContaining({ + currency_code: "dkk", + calculated_amount: 30, + calculated_amount_with_tax: 36, + calculated_amount_without_tax: 30, + }), + }), + ]) + ) + }) + + it("should return prices with and without tax when the cart is available and a country is passed when listing products", async () => { + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price&cart_id=${euCart.id}&country_code=it` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + calculated_price: expect.objectContaining({ + currency_code: "eur", + calculated_amount: 45, + calculated_amount_with_tax: 45, + }), + }), + ]) + ) + + // TODO: Return an integer instead of a float for the pricing + expect( + products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed( + 1 + ) + ).toEqual("40.9") + }) + + it("should return prices with and without tax when the cart context is available when listing products", async () => { + await api.post(`/store/carts/${euCart.id}`, { + shipping_address: { + country_code: "it", + }, + }) + + const products = ( + await api.get( + `/store/products?fields=id,*variants.calculated_price&cart_id=${euCart.id}` + ) + ).data.products + + expect(products.length).toBe(1) + expect(products[0].variants).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + calculated_price: expect.objectContaining({ + currency_code: "eur", + calculated_amount: 45, + calculated_amount_with_tax: 45, + }), + }), + ]) + ) + + // TODO: Return an integer instead of a float for the pricing + expect( + products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed( + 1 + ) + ).toEqual("40.9") + }) + + it("should not return tax pricing if the context is not sufficient when fetching a single product", async () => { + const product = ( + await api.get( + `/store/products/${product1.id}?fields=id,*variants.calculated_price®ion_id=${usRegion.id}` + ) + ).data.product + + expect(product.variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_with_tax" + ) + expect(product.variants[0].calculated_price).not.toHaveProperty( + "calculated_amount_without_tax" + ) + }) + + it("should return prices with and without tax for a tax inclusive region when fetching a single product", async () => { + const product = ( + await api.get( + `/store/products/${product1.id}?fields=id,*variants.calculated_price®ion_id=${euRegion.id}&country_code=it` + ) + ).data.product + + expect(product.variants).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + calculated_price: expect.objectContaining({ + currency_code: "eur", + calculated_amount: 45, + calculated_amount_with_tax: 45, + }), + }), + ]) + ) + + // TODO: Return an integer instead of a float for the pricing + expect( + product.variants[0].calculated_price.calculated_amount_without_tax.toFixed( + 1 + ) + ).toEqual("40.9") + }) + }) }, }) diff --git a/packages/core/types/src/http/pricing/common.ts b/packages/core/types/src/http/pricing/common.ts new file mode 100644 index 0000000000..a5aaf771b5 --- /dev/null +++ b/packages/core/types/src/http/pricing/common.ts @@ -0,0 +1,105 @@ +export interface BaseCalculatedPriceSet { + /** + * The ID of the price set. + */ + id: string + + /** + * Whether the calculated price is associated with a price list. During the calculation process, if no valid price list is found, + * the calculated price is set to the original price, which doesn't belong to a price list. In that case, the value of this property is `false`. + */ + is_calculated_price_price_list?: boolean + + /** + * Whether the calculated price is tax inclusive. + */ + is_calculated_price_tax_inclusive?: boolean + + /** + * The amount of the calculated price, or `null` if there isn't a calculated price. + */ + calculated_amount: number | null + + /** + * The amount of the calculated price with taxes included. If the calculated price is tax inclusive, this field will be the same as `calculated_amount`. + */ + calculated_amount_with_tax?: number | null + + /** + * The amount of the calculated price without taxes included. If the calculated price is tax exclusive, this field will be the same as `calculated_amount`. + */ + calculated_amount_without_tax?: number | null + + /** + * Whether the original price is associated with a price list. During the calculation process, if the price list of the calculated price is of type override, + * the original price will be the same as the calculated price. In that case, the value of this property is `true`. + */ + is_original_price_price_list?: boolean + + /** + * Whether the original price is tax inclusive. + */ + is_original_price_tax_inclusive?: boolean + + /** + * The amount of the original price, or `null` if there isn't a calculated price. + */ + original_amount: number | null + + /** + * The currency code of the calculated price, or null if there isn't a calculated price. + */ + currency_code: string | null + + /** + * The details of the calculated price. + */ + calculated_price?: { + /** + * The ID of the price selected as the calculated price. + */ + id: string | null + /** + * The ID of the associated price list, if any. + */ + price_list_id: string | null + /** + * The type of the associated price list, if any. + */ + price_list_type: string | null + /** + * The `min_quantity` field defined on a price. + */ + min_quantity: number | null + /** + * The `max_quantity` field defined on a price. + */ + max_quantity: number | null + } + + /** + * The details of the original price. + */ + original_price?: { + /** + * The ID of the price selected as the original price. + */ + id: string | null + /** + * The ID of the associated price list, if any. + */ + price_list_id: string | null + /** + * The type of the associated price list, if any. + */ + price_list_type: string | null + /** + * The `min_quantity` field defined on a price. + */ + min_quantity: number | null + /** + * The `max_quantity` field defined on a price. + */ + max_quantity: number | null + } +} diff --git a/packages/core/types/src/http/product/admin/payloads.ts b/packages/core/types/src/http/product/admin/payloads.ts index 4cef18d5df..d87cc0c6bd 100644 --- a/packages/core/types/src/http/product/admin/payloads.ts +++ b/packages/core/types/src/http/product/admin/payloads.ts @@ -22,6 +22,8 @@ export interface AdminCreateProductVariantPrice { amount: number min_quantity?: number | null max_quantity?: number | null + // Note: Although the BE is generic, we only use region_id for price rules for now, so it's better to keep the typings stricter. + rules?: { region_id: string } | null } export interface AdminCreateProductVariant { diff --git a/packages/core/types/src/http/product/common.ts b/packages/core/types/src/http/product/common.ts index 128fcc6263..8727cbcca0 100644 --- a/packages/core/types/src/http/product/common.ts +++ b/packages/core/types/src/http/product/common.ts @@ -1,6 +1,7 @@ import { BaseFilterable, OperatorMap } from "../../dal" import { BaseCollection } from "../collection/common" import { FindParams } from "../common" +import { BaseCalculatedPriceSet } from "../pricing/common" import { BaseProductCategory } from "../product-category/common" import { BaseProductType } from "../product-type/common" @@ -56,10 +57,11 @@ export interface BaseProductVariant { length: number | null height: number | null width: number | null + variant_rank?: number | null options: BaseProductOptionValue[] | null product?: BaseProduct | null product_id?: string - variant_rank?: number | null + calculated_price?: BaseCalculatedPriceSet created_at: string updated_at: string deleted_at: string | null diff --git a/packages/core/types/src/product/common.ts b/packages/core/types/src/product/common.ts index d71c9999bf..8b439210a0 100644 --- a/packages/core/types/src/product/common.ts +++ b/packages/core/types/src/product/common.ts @@ -305,11 +305,11 @@ export interface ProductCategoryDTO { * * @expandable */ - parent_category?: ProductCategoryDTO | null + parent_category: ProductCategoryDTO | null /** * The associated parent category id. */ - parent_category_id?: string | null + parent_category_id: string | null /** * The associated child categories. * @@ -330,6 +330,10 @@ export interface ProductCategoryDTO { * When the product category was updated. */ updated_at: string | Date + /** + * When the product category was deleted. + */ + deleted_at?: string | Date } /** diff --git a/packages/core/types/src/tax/common.ts b/packages/core/types/src/tax/common.ts index dcc73511bd..eae5e75874 100644 --- a/packages/core/types/src/tax/common.ts +++ b/packages/core/types/src/tax/common.ts @@ -513,7 +513,7 @@ interface TaxLineDTO { /** * The rate of the tax line. */ - rate: number | null + rate: number /** * The code of the tax line. diff --git a/packages/core/utils/src/totals/index.ts b/packages/core/utils/src/totals/index.ts index f14979d933..8600021e94 100644 --- a/packages/core/utils/src/totals/index.ts +++ b/packages/core/utils/src/totals/index.ts @@ -2,6 +2,7 @@ export * from "./cart" export * from "./create-raw-properties-from-bignumber" export * from "./line-item" export * from "./math" +export * from "./tax" export * from "./promotion" export * from "./shipping-method" export * from "./transform-properties-to-bignumber" diff --git a/packages/core/utils/src/totals/tax/index.ts b/packages/core/utils/src/totals/tax/index.ts index b477005ede..dec517981f 100644 --- a/packages/core/utils/src/totals/tax/index.ts +++ b/packages/core/utils/src/totals/tax/index.ts @@ -31,3 +31,24 @@ export function calculateTaxTotal({ return taxTotal } + +export function calculateAmountsWithTax({ + taxLines, + amount, + includesTax, +}: { + taxLines: Pick[] + amount: number + includesTax?: boolean +}) { + const tax = calculateTaxTotal({ + taxLines, + includesTax, + taxableAmount: amount, + }) + + return { + priceWithTax: includesTax ? amount : MathBN.add(tax, amount).toNumber(), + priceWithoutTax: includesTax ? MathBN.sub(amount, tax).toNumber() : amount, + } +} diff --git a/packages/medusa/src/api/store/products/[id]/route.ts b/packages/medusa/src/api/store/products/[id]/route.ts index d3488d977a..9e1f5f05ec 100644 --- a/packages/medusa/src/api/store/products/[id]/route.ts +++ b/packages/medusa/src/api/store/products/[id]/route.ts @@ -1,11 +1,15 @@ import { MedusaError, isPresent } from "@medusajs/utils" -import { MedusaRequest, MedusaResponse } from "../../../../types/routing" +import { MedusaResponse } from "../../../../types/routing" import { wrapVariantsWithInventoryQuantity } from "../../../utils/middlewares" -import { refetchProduct } from "../helpers" -import { StoreGetProductsParamsType } from "../validators" +import { + RequestWithContext, + refetchProduct, + wrapProductsWithTaxPrices, +} from "../helpers" +import { StoreGetProductParamsType } from "../validators" export const GET = async ( - req: MedusaRequest, + req: RequestWithContext, res: MedusaResponse ) => { const withInventoryQuantity = req.remoteQueryConfig.fields.some((field) => @@ -46,5 +50,6 @@ export const GET = async ( await wrapVariantsWithInventoryQuantity(req, product.variants || []) } + await wrapProductsWithTaxPrices(req, [product]) res.json({ product }) } diff --git a/packages/medusa/src/api/store/products/helpers.ts b/packages/medusa/src/api/store/products/helpers.ts index 774ffd3de1..63327593ff 100644 --- a/packages/medusa/src/api/store/products/helpers.ts +++ b/packages/medusa/src/api/store/products/helpers.ts @@ -1,6 +1,25 @@ -import { MedusaContainer } from "@medusajs/types" +import { + ModuleRegistrationName, + calculateAmountsWithTax, +} from "@medusajs/utils" import { MedusaRequest } from "../../../types/routing" import { refetchEntities, refetchEntity } from "../../utils/refetch-entity" +import { + MedusaContainer, + HttpTypes, + TaxableItemDTO, + ItemTaxLineDTO, + TaxCalculationContext, +} from "@medusajs/types" + +export type RequestWithContext = MedusaRequest & { + taxContext: { + taxLineContext?: TaxCalculationContext + taxInclusivityContext?: { + automaticTaxes: boolean + } + } +} export const refetchProduct = async ( idOrFilter: string | object, @@ -30,3 +49,81 @@ export const maybeApplyStockLocationId = async (req: MedusaRequest, ctx) => { return entities.map((entity) => entity.stock_location_id) } + +export const wrapProductsWithTaxPrices = async ( + req: RequestWithContext, + products: HttpTypes.StoreProduct[] +) => { + // If we are missing the necessary context, we can't calculate the tax, so only `calculated_amount` will be available + if ( + !req.taxContext?.taxInclusivityContext || + !req.taxContext?.taxLineContext + ) { + return + } + + // If automatic taxes are not enabled, we should skip calculating any tax + if (!req.taxContext.taxInclusivityContext.automaticTaxes) { + return + } + + const taxService = req.scope.resolve(ModuleRegistrationName.TAX) + + const taxRates = (await taxService.getTaxLines( + products.map(asTaxItem).flat(), + req.taxContext.taxLineContext + )) as unknown as ItemTaxLineDTO[] + + const taxRatesMap = new Map() + taxRates.forEach((taxRate) => { + if (!taxRatesMap.has(taxRate.line_item_id)) { + taxRatesMap.set(taxRate.line_item_id, []) + } + + taxRatesMap.get(taxRate.line_item_id)?.push(taxRate) + }) + + products.forEach((product) => { + product.variants?.forEach((variant) => { + if (!variant.calculated_price) { + return + } + + const taxRatesForVariant = taxRatesMap.get(variant.id) || [] + const { priceWithTax, priceWithoutTax } = calculateAmountsWithTax({ + taxLines: taxRatesForVariant, + amount: variant.calculated_price!.calculated_amount!, + includesTax: + variant.calculated_price!.is_calculated_price_tax_inclusive!, + }) + + variant.calculated_price.calculated_amount_with_tax = priceWithTax + variant.calculated_price.calculated_amount_without_tax = priceWithoutTax + }) + }) +} + +const asTaxItem = (product: HttpTypes.StoreProduct): TaxableItemDTO[] => { + return product.variants + ?.map((variant) => { + if (!variant.calculated_price) { + return + } + + return { + id: variant.id, + product_id: product.id, + product_name: product.title, + product_categories: product.categories?.map((c) => c.name), + // TODO: It is strange that we only accept a single category, revisit the tax module implementation + product_category_id: product.categories?.[0]?.id, + product_sku: variant.sku, + product_type: product.type, + product_type_id: product.type_id, + quantity: 1, + unit_price: variant.calculated_price.calculated_amount, + currency_code: variant.calculated_price.currency_code, + } + }) + .filter((v) => !!v) as unknown as TaxableItemDTO[] +} diff --git a/packages/medusa/src/api/store/products/middlewares.ts b/packages/medusa/src/api/store/products/middlewares.ts index bc4a46f4fe..e90b15844d 100644 --- a/packages/medusa/src/api/store/products/middlewares.ts +++ b/packages/medusa/src/api/store/products/middlewares.ts @@ -3,8 +3,11 @@ import { MiddlewareRoute } from "../../../loaders/helpers/routing/types" import { maybeApplyLinkFilter } from "../../utils/maybe-apply-link-filter" import { applyDefaultFilters, + clearFiltersByKey, filterByValidSalesChannels, + normalizeDataForContext, setPricingContext, + setTaxContext, } from "../../utils/middlewares" import { setContext } from "../../utils/middlewares/common/set-context" import { validateAndTransformQuery } from "../../utils/validate-query" @@ -47,7 +50,10 @@ export const storeProductRoutesMiddlewares: MiddlewareRoute[] = [ return { id: categoryIds, is_internal: false, is_active: true } }, }), + normalizeDataForContext(), setPricingContext(), + setTaxContext(), + clearFiltersByKey(["region_id", "country_code", "province", "cart_id"]), ], }, { @@ -78,7 +84,10 @@ export const storeProductRoutesMiddlewares: MiddlewareRoute[] = [ return { is_internal: false, is_active: true } }, }), + normalizeDataForContext(), setPricingContext(), + setTaxContext(), + clearFiltersByKey(["region_id", "country_code", "province", "cart_id"]), ], }, ] diff --git a/packages/medusa/src/api/store/products/route.ts b/packages/medusa/src/api/store/products/route.ts index 54d4af32e9..9a5f4ee52a 100644 --- a/packages/medusa/src/api/store/products/route.ts +++ b/packages/medusa/src/api/store/products/route.ts @@ -3,12 +3,13 @@ import { isPresent, remoteQueryObjectFromString, } from "@medusajs/utils" -import { MedusaRequest, MedusaResponse } from "../../../types/routing" +import { MedusaResponse } from "../../../types/routing" import { wrapVariantsWithInventoryQuantity } from "../../utils/middlewares" import { StoreGetProductsParamsType } from "./validators" +import { RequestWithContext, wrapProductsWithTaxPrices } from "./helpers" export const GET = async ( - req: MedusaRequest, + req: RequestWithContext, res: MedusaResponse ) => { const remoteQuery = req.scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY) @@ -48,6 +49,7 @@ export const GET = async ( ) } + await wrapProductsWithTaxPrices(req, products) res.json({ products, count: metadata.count, diff --git a/packages/medusa/src/api/store/products/validators.ts b/packages/medusa/src/api/store/products/validators.ts index c3d8385362..2d3c4ae09d 100644 --- a/packages/medusa/src/api/store/products/validators.ts +++ b/packages/medusa/src/api/store/products/validators.ts @@ -9,7 +9,17 @@ import { createSelectParams, } from "../../utils/validators" -export const StoreGetProductParams = createSelectParams() +export type StoreGetProductParamsType = z.infer + +export const StoreGetProductParams = createSelectParams().merge( + // These are used to populate the tax and pricing context + z.object({ + region_id: z.string().optional(), + country_code: z.string().optional(), + province: z.string().optional(), + cart_id: z.string().optional(), + }) +) export type StoreGetProductVariantsParamsType = z.infer< typeof StoreGetProductVariantsParams @@ -38,8 +48,12 @@ export const StoreGetProductsParams = createFindParams({ }).merge( z .object({ + // These are used to populate the tax and pricing context region_id: z.string().optional(), - currency_code: z.string().optional(), + country_code: z.string().optional(), + province: z.string().optional(), + cart_id: z.string().optional(), + variants: z .object({ status: ProductStatusEnum.array().optional(), diff --git a/packages/medusa/src/api/utils/middlewares/common/clear-filters-by-key.ts b/packages/medusa/src/api/utils/middlewares/common/clear-filters-by-key.ts new file mode 100644 index 0000000000..804e30d2a7 --- /dev/null +++ b/packages/medusa/src/api/utils/middlewares/common/clear-filters-by-key.ts @@ -0,0 +1,12 @@ +import { NextFunction } from "express" +import { MedusaRequest } from "../../../../types/routing" + +export function clearFiltersByKey(keys: string[]) { + return async (req: MedusaRequest, _, next: NextFunction) => { + keys.forEach((key) => { + delete req.filterableFields[key] + }) + + return next() + } +} diff --git a/packages/medusa/src/api/utils/middlewares/common/index.ts b/packages/medusa/src/api/utils/middlewares/common/index.ts index a0bd86ad38..420f4ad32f 100644 --- a/packages/medusa/src/api/utils/middlewares/common/index.ts +++ b/packages/medusa/src/api/utils/middlewares/common/index.ts @@ -1 +1,2 @@ export * from "./apply-default-filters" +export * from "./clear-filters-by-key" diff --git a/packages/medusa/src/api/utils/middlewares/products/index.ts b/packages/medusa/src/api/utils/middlewares/products/index.ts index 599fbbbb9c..ef38f7eb77 100644 --- a/packages/medusa/src/api/utils/middlewares/products/index.ts +++ b/packages/medusa/src/api/utils/middlewares/products/index.ts @@ -1,4 +1,5 @@ export * from "./filter-by-valid-sales-channels" +export * from "./normalize-data-for-context" export * from "./set-pricing-context" +export * from "./set-tax-context" export * from "./variant-inventory-quantity" - diff --git a/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts b/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts new file mode 100644 index 0000000000..c2b90bd8f3 --- /dev/null +++ b/packages/medusa/src/api/utils/middlewares/products/normalize-data-for-context.ts @@ -0,0 +1,74 @@ +import { MedusaError } from "@medusajs/utils" +import { NextFunction } from "express" +import { AuthenticatedMedusaRequest } from "../../../../types/routing" +import { refetchEntities, refetchEntity } from "../../refetch-entity" + +export function normalizeDataForContext() { + return async (req: AuthenticatedMedusaRequest, _, next: NextFunction) => { + // If the product pricing is not requested, we don't need region information + let withCalculatedPrice = req.remoteQueryConfig.fields.some((field) => + field.startsWith("variants.calculated_price") + ) + + // If the region is passed, we calculate the prices without requesting them. + // TODO: This seems a bit messy, reconsider if we want to keep this logic. + if (!withCalculatedPrice && req.filterableFields.region_id) { + req.remoteQueryConfig.fields.push("variants.calculated_price.*") + withCalculatedPrice = true + } + + if (!withCalculatedPrice) { + return next() + } + + // Region ID is required to calculate prices correctly. + // Country code, and optionally province, are needed to calculate taxes. + let regionId = req.filterableFields.region_id + let countryCode = req.filterableFields.country_code + let province = req.filterableFields.province + + // If the cart is passed, get the information from it + if (req.filterableFields.cart_id) { + const cart = await refetchEntity( + "cart", + req.filterableFields.cart_id, + req.scope, + ["region_id", "shipping_address.*"] + ) + + if (cart?.region_id) { + regionId = cart.region_id + } + + if (cart?.shipping_address) { + countryCode = cart.shipping_address.country_code + province = cart.shipping_address.province + } + } + + // Finally, try to get it from the store defaults if not available + if (!regionId) { + const stores = await refetchEntities("store", {}, req.scope, [ + "default_region_id", + ]) + regionId = stores[0]?.default_region_id + } + + if (!regionId) { + try { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `Missing required pricing context to calculate prices - region_id` + ) + } catch (e) { + return next(e) + } + } + + req.filterableFields.region_id = regionId + req.filterableFields.country_code = countryCode + req.filterableFields.province = province + + return next() + } +} diff --git a/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts b/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts index 8903359110..02a58cb962 100644 --- a/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts +++ b/packages/medusa/src/api/utils/middlewares/products/set-pricing-context.ts @@ -1,5 +1,5 @@ import { MedusaPricingContext } from "@medusajs/types" -import { isPresent, MedusaError } from "@medusajs/utils" +import { MedusaError } from "@medusajs/utils" import { NextFunction } from "express" import { AuthenticatedMedusaRequest } from "../../../../types/routing" import { refetchEntities, refetchEntity } from "../../refetch-entity" @@ -9,106 +9,47 @@ export function setPricingContext() { const withCalculatedPrice = req.remoteQueryConfig.fields.some((field) => field.startsWith("variants.calculated_price") ) - - // If the endpoint doesn't pass region_id and currency_code, we can exit early - if ( - !withCalculatedPrice && - !req.filterableFields.region_id && - !req.filterableFields.currency_code - ) { + if (!withCalculatedPrice) { return next() } - // If the endpoint requested the field variants.calculated_price, we should throw - // an error if region or currency is not passed - if ( - withCalculatedPrice && - !req.filterableFields.region_id && - !req.filterableFields.currency_code - ) { + // We validate the region ID in the previous middleware + const region = await refetchEntity( + "region", + req.filterableFields.region_id!, + req.scope, + ["id", "currency_code"] + ) + + if (!region) { try { throw new MedusaError( MedusaError.Types.INVALID_DATA, - `Missing required pricing context to calculate prices - currency_code or region_id` + `Region with id ${req.filterableFields.region_id} not found when populating the pricing context` ) } catch (e) { return next(e) } } - const query = req.filterableFields || {} - const pricingContext: MedusaPricingContext = {} - const customerId = req.user?.customer_id - - if (query.region_id) { - const region = await refetchEntity("region", query.region_id, req.scope, [ - "id", - "currency_code", - ]) - - if (region) { - pricingContext.region_id = region.id - } - - if (region?.currency_code) { - pricingContext.currency_code = region.currency_code - } - - delete req.filterableFields.region_id - } - - // If a currency code is explicitly passed, we should be using that instead of the - // regions currency code - if (query.currency_code) { - const currency = await refetchEntity( - "currency", - { code: query.currency_code }, - req.scope, - ["code"] - ) - - if (currency) { - pricingContext.currency_code = currency.code - } - - delete req.filterableFields.currency_code + const pricingContext: MedusaPricingContext = { + region_id: region.id, + currency_code: region.currency_code, } // Find all the customer groups the customer is a part of and set - if (customerId) { + if (req.user?.customer_id) { const customerGroups = await refetchEntities( "customer_group", - { customer_id: customerId }, + { customer_id: req.user?.customer_id }, req.scope, ["id"] ) pricingContext.customer_group_id = customerGroups.map((cg) => cg.id) - - delete req.filterableFields.customer_id - } - - // If a currency_code is not present in the context, we will not be able to calculate prices - if ( - !isPresent(pricingContext) || - !isPresent(pricingContext.currency_code) - ) { - try { - throw new MedusaError( - MedusaError.Types.INVALID_DATA, - `Valid pricing parameters (currency_code or region_id) are required to calculate prices` - ) - } catch (e) { - return next(e) - } } req.pricingContext = pricingContext - - if (!withCalculatedPrice) { - req.remoteQueryConfig.fields.push("variants.calculated_price.*") - } - return next() } } diff --git a/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts b/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts new file mode 100644 index 0000000000..b87335a92d --- /dev/null +++ b/packages/medusa/src/api/utils/middlewares/products/set-tax-context.ts @@ -0,0 +1,73 @@ +import { TaxCalculationContext } from "@medusajs/types" +import { NextFunction } from "express" +import { + AuthenticatedMedusaRequest, + MedusaRequest, +} from "../../../../types/routing" +import { refetchEntity } from "../../refetch-entity" +import { MedusaError } from "@medusajs/utils" +import { RequestWithContext } from "../../../store/products/helpers" + +export function setTaxContext() { + return async (req: AuthenticatedMedusaRequest, _, next: NextFunction) => { + const withCalculatedPrice = req.remoteQueryConfig.fields.some((field) => + field.startsWith("variants.calculated_price") + ) + if (!withCalculatedPrice) { + return next() + } + + try { + const inclusivity = await getTaxInclusivityInfo(req) + if (!inclusivity || !inclusivity.automaticTaxes) { + return next() + } + + const taxLinesContext = await getTaxLinesContext(req) + + // TODO: Allow passing a context typings param to AuthenticatedMedusaRequest + ;(req as unknown as RequestWithContext).taxContext = { + taxLineContext: taxLinesContext, + taxInclusivityContext: inclusivity, + } + return next() + } catch (e) { + next(e) + } + } +} + +const getTaxInclusivityInfo = async (req: MedusaRequest) => { + const region = await refetchEntity( + "region", + req.filterableFields.region_id as string, + req.scope, + ["automatic_taxes"] + ) + + if (!region) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `Region with id ${req.filterableFields.region_id} not found when populating the tax context` + ) + } + + return { + automaticTaxes: region.automatic_taxes, + } +} + +const getTaxLinesContext = async (req: MedusaRequest) => { + if (!req.filterableFields.country_code) { + return + } + + const taxContext = { + address: { + country_code: req.filterableFields.country_code as string, + province_code: req.filterableFields.province as string, + }, + } as TaxCalculationContext + + return taxContext +} diff --git a/packages/modules/pricing/integration-tests/__tests__/services/pricing-module/calculate-price.spec.ts b/packages/modules/pricing/integration-tests/__tests__/services/pricing-module/calculate-price.spec.ts index dbcb92992f..8c5e3116d1 100644 --- a/packages/modules/pricing/integration-tests/__tests__/services/pricing-module/calculate-price.spec.ts +++ b/packages/modules/pricing/integration-tests/__tests__/services/pricing-module/calculate-price.spec.ts @@ -1449,6 +1449,39 @@ moduleIntegrationTestRunner({ ]) }) + it("should return the region tax inclusivity for the selected price when there are multiple region preferences", async () => { + await (service as any).createPricePreferences([ + { + attribute: "region_id", + value: "DE", + is_tax_inclusive: false, + }, + { + attribute: "region_id", + value: "PL", + is_tax_inclusive: true, + }, + ]) + + const priceSetsResult = await service.calculatePrices( + { id: ["price-set-PLN"] }, + { + context: { currency_code: "PLN", region_id: "PL" }, + } + ) + + expect(priceSetsResult).toEqual([ + expect.objectContaining({ + id: "price-set-PLN", + is_calculated_price_tax_inclusive: true, + calculated_amount: 300, + is_original_price_tax_inclusive: true, + original_amount: 300, + currency_code: "PLN", + }), + ]) + }) + it("should return the appropriate tax inclusive setting for each calculated and original price", async () => { await createPriceLists(service, {}, {}) await (service as any).createPricePreferences([ diff --git a/packages/modules/pricing/src/services/pricing-module.ts b/packages/modules/pricing/src/services/pricing-module.ts index 89ae2ea8cb..0101d2497b 100644 --- a/packages/modules/pricing/src/services/pricing-module.ts +++ b/packages/modules/pricing/src/services/pricing-module.ts @@ -334,7 +334,9 @@ export default class PricingModuleService is_calculated_price_price_list: !!calculatedPrice?.price_list_id, is_calculated_price_tax_inclusive: isTaxInclusive( priceRulesPriceMap.get(calculatedPrice.id), - pricingPreferences + pricingPreferences, + calculatedPrice.currency_code!, + pricingContext.context?.region_id as string ), calculated_amount: parseInt(calculatedPrice?.amount || "") || null, @@ -342,7 +344,9 @@ export default class PricingModuleService is_original_price_tax_inclusive: originalPrice?.id ? isTaxInclusive( priceRulesPriceMap.get(originalPrice.id), - pricingPreferences + pricingPreferences, + originalPrice.currency_code || calculatedPrice.currency_code!, + pricingContext.context?.region_id as string ) : false, original_amount: parseInt(originalPrice?.amount || "") || null, @@ -1451,19 +1455,23 @@ export default class PricingModuleService const isTaxInclusive = ( priceRules: PriceRule[], - preferences: PricePreference[] + preferences: PricePreference[], + currencyCode: string, + regionId?: string ) => { - const regionPreference = preferences.find((p) => p.attribute === "region_id") - const currencyPreference = preferences.find( - (p) => p.attribute === "currency_code" + const regionRule = priceRules?.find( + (rule) => rule.attribute === "region_id" && rule.value === regionId ) - const regionRule = priceRules?.find((rule) => rule.attribute === "region_id") - if ( - regionRule && - regionPreference && - regionRule.value === regionPreference.value - ) { + const regionPreference = preferences.find( + (p) => p.attribute === "region_id" && p.value === regionId + ) + + const currencyPreference = preferences.find( + (p) => p.attribute === "currency_code" && p.value === currencyCode + ) + + if (regionRule && regionPreference) { return regionPreference.is_tax_inclusive }