feat: Add necessary middlewares for tax inclusive pricing (#7827)

We are adding tax inclusive pricing calculation when listing products.

Two things to keep in mind:
- `region_id` will be required if you request calculated prices.
- We won't accept `currency_code` anymore, as that will come from the region info (since ultimately a cart and its currency are tied to a region)

REF CORE-2376
DEPENDS ON #8003
This commit is contained in:
Stevche Radevski
2024-07-09 11:37:13 +02:00
committed by GitHub
parent db6969578f
commit 1c3ef13371
22 changed files with 824 additions and 121 deletions

View File

@@ -1,7 +1,7 @@
import { HttpTypes } from "@medusajs/types"
export const getProductFixture = (
overrides: Partial<HttpTypes.AdminProduct>
overrides: Partial<HttpTypes.AdminCreateProduct>
) => ({
title: "Test fixture",
description: "test-product-description",

View File

@@ -892,7 +892,7 @@ medusaIntegrationTestRunner({
expect(error.response.status).toEqual(400)
expect(error.response.data).toEqual({
message:
"Missing required pricing context to calculate prices - currency_code or region_id",
"Missing required pricing context to calculate prices - region_id",
type: "invalid_data",
})
})
@@ -907,7 +907,7 @@ medusaIntegrationTestRunner({
).data.region
let response = await api.get(
`/store/products?fields=*variants.calculated_price&currency_code=usd`
`/store/products?fields=*variants.calculated_price&region_id=${region.id}`
)
const expectation = expect.arrayContaining([
@@ -948,12 +948,6 @@ medusaIntegrationTestRunner({
expect(response.data.count).toEqual(3)
expect(response.data.products).toEqual(expectation)
// Without calculated_price fields
response = await api.get(`/store/products?currency_code=usd`)
expect(response.status).toEqual(200)
expect(response.data.products).toEqual(expectation)
// with only region_id
response = await api.get(`/store/products?region_id=${region.id}`)
@@ -1205,7 +1199,7 @@ medusaIntegrationTestRunner({
expect(error.response.status).toEqual(400)
expect(error.response.data).toEqual({
message:
"Missing required pricing context to calculate prices - currency_code or region_id",
"Missing required pricing context to calculate prices - region_id",
type: "invalid_data",
})
})
@@ -1220,7 +1214,7 @@ medusaIntegrationTestRunner({
).data.region
let response = await api.get(
`/store/products/${product.id}?fields=*variants.calculated_price&currency_code=usd`
`/store/products/${product.id}?fields=*variants.calculated_price&region_id=${region.id}`
)
const expectation = expect.objectContaining({
@@ -1258,14 +1252,6 @@ medusaIntegrationTestRunner({
expect(response.status).toEqual(200)
expect(response.data.product).toEqual(expectation)
// Without calculated_price fields
response = await api.get(
`/store/products/${product.id}?currency_code=usd`
)
expect(response.status).toEqual(200)
expect(response.data.product).toEqual(expectation)
// with only region_id
response = await api.get(
`/store/products/${product.id}?region_id=${region.id}`
@@ -1275,5 +1261,317 @@ medusaIntegrationTestRunner({
expect(response.data.product).toEqual(expectation)
})
})
describe("Tax handling", () => {
let usRegion
let euRegion
let dkRegion
let euCart
beforeEach(async () => {
usRegion = (
await api.post(
"/admin/regions",
{
name: "Test Region",
currency_code: "usd",
countries: ["us"],
is_tax_inclusive: false,
automatic_taxes: false,
},
adminHeaders
)
).data.region
euRegion = (
await api.post(
"/admin/regions",
{
name: "Test Region",
currency_code: "eur",
countries: ["it", "de"],
is_tax_inclusive: true,
automatic_taxes: true,
},
adminHeaders
)
).data.region
dkRegion = (
await api.post(
"/admin/regions",
{
name: "Test Region",
currency_code: "dkk",
countries: ["dk"],
is_tax_inclusive: false,
automatic_taxes: true,
},
adminHeaders
)
).data.region
product1 = (
await api.post(
"/admin/products",
getProductFixture({
title: "test1",
status: "published",
variants: [
{
title: "Test taxes",
prices: [
{
amount: 45,
currency_code: "eur",
rules: { region_id: euRegion.id },
},
{
amount: 100,
currency_code: "usd",
rules: { region_id: usRegion.id },
},
{
amount: 30,
currency_code: "dkk",
rules: { region_id: dkRegion.id },
},
],
},
],
}),
adminHeaders
)
).data.product
euCart = (await api.post("/store/carts", { region_id: euRegion.id }))
.data.cart
await api.post(
`/admin/tax-regions`,
{
country_code: "us",
default_tax_rate: {
code: "default",
rate: 5,
name: "default rate",
},
},
adminHeaders
)
await api.post(
`/admin/tax-regions`,
{
country_code: "it",
default_tax_rate: {
code: "default",
rate: 10,
name: "default rate",
},
},
adminHeaders
)
await api.post(
`/admin/tax-regions`,
{
country_code: "dk",
default_tax_rate: {
code: "default",
rate: 20,
name: "default rate",
},
},
adminHeaders
)
})
it("should not return tax pricing if the context is not sufficient when listing products", async () => {
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&region_id=${usRegion.id}`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_with_tax"
)
expect(products[0].variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_without_tax"
)
})
it("should not return tax pricing if automatic taxes are off when listing products", async () => {
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&region_id=${usRegion.id}&country_code=us`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_with_tax"
)
expect(products[0].variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_without_tax"
)
})
it("should return prices with and without tax for a tax inclusive region when listing products", async () => {
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&region_id=${euRegion.id}&country_code=it`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants).toEqual(
expect.arrayContaining([
expect.objectContaining({
calculated_price: expect.objectContaining({
currency_code: "eur",
calculated_amount: 45,
calculated_amount_with_tax: 45,
}),
}),
])
)
// TODO: Return an integer instead of a float for the pricing
expect(
products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed(
1
)
).toEqual("40.9")
})
it("should return prices with and without tax for a tax exclusive region when listing products", async () => {
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&region_id=${dkRegion.id}&country_code=dk`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants).toEqual(
expect.arrayContaining([
expect.objectContaining({
calculated_price: expect.objectContaining({
currency_code: "dkk",
calculated_amount: 30,
calculated_amount_with_tax: 36,
calculated_amount_without_tax: 30,
}),
}),
])
)
})
it("should return prices with and without tax when the cart is available and a country is passed when listing products", async () => {
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&cart_id=${euCart.id}&country_code=it`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants).toEqual(
expect.arrayContaining([
expect.objectContaining({
calculated_price: expect.objectContaining({
currency_code: "eur",
calculated_amount: 45,
calculated_amount_with_tax: 45,
}),
}),
])
)
// TODO: Return an integer instead of a float for the pricing
expect(
products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed(
1
)
).toEqual("40.9")
})
it("should return prices with and without tax when the cart context is available when listing products", async () => {
await api.post(`/store/carts/${euCart.id}`, {
shipping_address: {
country_code: "it",
},
})
const products = (
await api.get(
`/store/products?fields=id,*variants.calculated_price&cart_id=${euCart.id}`
)
).data.products
expect(products.length).toBe(1)
expect(products[0].variants).toEqual(
expect.arrayContaining([
expect.objectContaining({
calculated_price: expect.objectContaining({
currency_code: "eur",
calculated_amount: 45,
calculated_amount_with_tax: 45,
}),
}),
])
)
// TODO: Return an integer instead of a float for the pricing
expect(
products[0].variants[0].calculated_price.calculated_amount_without_tax.toFixed(
1
)
).toEqual("40.9")
})
it("should not return tax pricing if the context is not sufficient when fetching a single product", async () => {
const product = (
await api.get(
`/store/products/${product1.id}?fields=id,*variants.calculated_price&region_id=${usRegion.id}`
)
).data.product
expect(product.variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_with_tax"
)
expect(product.variants[0].calculated_price).not.toHaveProperty(
"calculated_amount_without_tax"
)
})
it("should return prices with and without tax for a tax inclusive region when fetching a single product", async () => {
const product = (
await api.get(
`/store/products/${product1.id}?fields=id,*variants.calculated_price&region_id=${euRegion.id}&country_code=it`
)
).data.product
expect(product.variants).toEqual(
expect.arrayContaining([
expect.objectContaining({
calculated_price: expect.objectContaining({
currency_code: "eur",
calculated_amount: 45,
calculated_amount_with_tax: 45,
}),
}),
])
)
// TODO: Return an integer instead of a float for the pricing
expect(
product.variants[0].calculated_price.calculated_amount_without_tax.toFixed(
1
)
).toEqual("40.9")
})
})
},
})

View File

@@ -0,0 +1,105 @@
export interface BaseCalculatedPriceSet {
/**
* The ID of the price set.
*/
id: string
/**
* Whether the calculated price is associated with a price list. During the calculation process, if no valid price list is found,
* the calculated price is set to the original price, which doesn't belong to a price list. In that case, the value of this property is `false`.
*/
is_calculated_price_price_list?: boolean
/**
* Whether the calculated price is tax inclusive.
*/
is_calculated_price_tax_inclusive?: boolean
/**
* The amount of the calculated price, or `null` if there isn't a calculated price.
*/
calculated_amount: number | null
/**
* The amount of the calculated price with taxes included. If the calculated price is tax inclusive, this field will be the same as `calculated_amount`.
*/
calculated_amount_with_tax?: number | null
/**
* The amount of the calculated price without taxes included. If the calculated price is tax exclusive, this field will be the same as `calculated_amount`.
*/
calculated_amount_without_tax?: number | null
/**
* Whether the original price is associated with a price list. During the calculation process, if the price list of the calculated price is of type override,
* the original price will be the same as the calculated price. In that case, the value of this property is `true`.
*/
is_original_price_price_list?: boolean
/**
* Whether the original price is tax inclusive.
*/
is_original_price_tax_inclusive?: boolean
/**
* The amount of the original price, or `null` if there isn't a calculated price.
*/
original_amount: number | null
/**
* The currency code of the calculated price, or null if there isn't a calculated price.
*/
currency_code: string | null
/**
* The details of the calculated price.
*/
calculated_price?: {
/**
* The ID of the price selected as the calculated price.
*/
id: string | null
/**
* The ID of the associated price list, if any.
*/
price_list_id: string | null
/**
* The type of the associated price list, if any.
*/
price_list_type: string | null
/**
* The `min_quantity` field defined on a price.
*/
min_quantity: number | null
/**
* The `max_quantity` field defined on a price.
*/
max_quantity: number | null
}
/**
* The details of the original price.
*/
original_price?: {
/**
* The ID of the price selected as the original price.
*/
id: string | null
/**
* The ID of the associated price list, if any.
*/
price_list_id: string | null
/**
* The type of the associated price list, if any.
*/
price_list_type: string | null
/**
* The `min_quantity` field defined on a price.
*/
min_quantity: number | null
/**
* The `max_quantity` field defined on a price.
*/
max_quantity: number | null
}
}

View File

@@ -22,6 +22,8 @@ export interface AdminCreateProductVariantPrice {
amount: number
min_quantity?: number | null
max_quantity?: number | null
// Note: Although the BE is generic, we only use region_id for price rules for now, so it's better to keep the typings stricter.
rules?: { region_id: string } | null
}
export interface AdminCreateProductVariant {

View File

@@ -1,6 +1,7 @@
import { BaseFilterable, OperatorMap } from "../../dal"
import { BaseCollection } from "../collection/common"
import { FindParams } from "../common"
import { BaseCalculatedPriceSet } from "../pricing/common"
import { BaseProductCategory } from "../product-category/common"
import { BaseProductType } from "../product-type/common"
@@ -56,10 +57,11 @@ export interface BaseProductVariant {
length: number | null
height: number | null
width: number | null
variant_rank?: number | null
options: BaseProductOptionValue[] | null
product?: BaseProduct | null
product_id?: string
variant_rank?: number | null
calculated_price?: BaseCalculatedPriceSet
created_at: string
updated_at: string
deleted_at: string | null

View File

@@ -305,11 +305,11 @@ export interface ProductCategoryDTO {
*
* @expandable
*/
parent_category?: ProductCategoryDTO | null
parent_category: ProductCategoryDTO | null
/**
* The associated parent category id.
*/
parent_category_id?: string | null
parent_category_id: string | null
/**
* The associated child categories.
*
@@ -330,6 +330,10 @@ export interface ProductCategoryDTO {
* When the product category was updated.
*/
updated_at: string | Date
/**
* When the product category was deleted.
*/
deleted_at?: string | Date
}
/**

View File

@@ -513,7 +513,7 @@ interface TaxLineDTO {
/**
* The rate of the tax line.
*/
rate: number | null
rate: number
/**
* The code of the tax line.

View File

@@ -2,6 +2,7 @@ export * from "./cart"
export * from "./create-raw-properties-from-bignumber"
export * from "./line-item"
export * from "./math"
export * from "./tax"
export * from "./promotion"
export * from "./shipping-method"
export * from "./transform-properties-to-bignumber"

View File

@@ -31,3 +31,24 @@ export function calculateTaxTotal({
return taxTotal
}
export function calculateAmountsWithTax({
taxLines,
amount,
includesTax,
}: {
taxLines: Pick<TaxLineDTO, "rate">[]
amount: number
includesTax?: boolean
}) {
const tax = calculateTaxTotal({
taxLines,
includesTax,
taxableAmount: amount,
})
return {
priceWithTax: includesTax ? amount : MathBN.add(tax, amount).toNumber(),
priceWithoutTax: includesTax ? MathBN.sub(amount, tax).toNumber() : amount,
}
}

View File

@@ -1,11 +1,15 @@
import { MedusaError, isPresent } from "@medusajs/utils"
import { MedusaRequest, MedusaResponse } from "../../../../types/routing"
import { MedusaResponse } from "../../../../types/routing"
import { wrapVariantsWithInventoryQuantity } from "../../../utils/middlewares"
import { refetchProduct } from "../helpers"
import { StoreGetProductsParamsType } from "../validators"
import {
RequestWithContext,
refetchProduct,
wrapProductsWithTaxPrices,
} from "../helpers"
import { StoreGetProductParamsType } from "../validators"
export const GET = async (
req: MedusaRequest<StoreGetProductsParamsType>,
req: RequestWithContext<StoreGetProductParamsType>,
res: MedusaResponse
) => {
const withInventoryQuantity = req.remoteQueryConfig.fields.some((field) =>
@@ -46,5 +50,6 @@ export const GET = async (
await wrapVariantsWithInventoryQuantity(req, product.variants || [])
}
await wrapProductsWithTaxPrices(req, [product])
res.json({ product })
}

View File

@@ -1,6 +1,25 @@
import { MedusaContainer } from "@medusajs/types"
import {
ModuleRegistrationName,
calculateAmountsWithTax,
} from "@medusajs/utils"
import { MedusaRequest } from "../../../types/routing"
import { refetchEntities, refetchEntity } from "../../utils/refetch-entity"
import {
MedusaContainer,
HttpTypes,
TaxableItemDTO,
ItemTaxLineDTO,
TaxCalculationContext,
} from "@medusajs/types"
export type RequestWithContext<T> = MedusaRequest<T> & {
taxContext: {
taxLineContext?: TaxCalculationContext
taxInclusivityContext?: {
automaticTaxes: boolean
}
}
}
export const refetchProduct = async (
idOrFilter: string | object,
@@ -30,3 +49,81 @@ export const maybeApplyStockLocationId = async (req: MedusaRequest, ctx) => {
return entities.map((entity) => entity.stock_location_id)
}
export const wrapProductsWithTaxPrices = async <T>(
req: RequestWithContext<T>,
products: HttpTypes.StoreProduct[]
) => {
// If we are missing the necessary context, we can't calculate the tax, so only `calculated_amount` will be available
if (
!req.taxContext?.taxInclusivityContext ||
!req.taxContext?.taxLineContext
) {
return
}
// If automatic taxes are not enabled, we should skip calculating any tax
if (!req.taxContext.taxInclusivityContext.automaticTaxes) {
return
}
const taxService = req.scope.resolve(ModuleRegistrationName.TAX)
const taxRates = (await taxService.getTaxLines(
products.map(asTaxItem).flat(),
req.taxContext.taxLineContext
)) as unknown as ItemTaxLineDTO[]
const taxRatesMap = new Map<string, ItemTaxLineDTO[]>()
taxRates.forEach((taxRate) => {
if (!taxRatesMap.has(taxRate.line_item_id)) {
taxRatesMap.set(taxRate.line_item_id, [])
}
taxRatesMap.get(taxRate.line_item_id)?.push(taxRate)
})
products.forEach((product) => {
product.variants?.forEach((variant) => {
if (!variant.calculated_price) {
return
}
const taxRatesForVariant = taxRatesMap.get(variant.id) || []
const { priceWithTax, priceWithoutTax } = calculateAmountsWithTax({
taxLines: taxRatesForVariant,
amount: variant.calculated_price!.calculated_amount!,
includesTax:
variant.calculated_price!.is_calculated_price_tax_inclusive!,
})
variant.calculated_price.calculated_amount_with_tax = priceWithTax
variant.calculated_price.calculated_amount_without_tax = priceWithoutTax
})
})
}
const asTaxItem = (product: HttpTypes.StoreProduct): TaxableItemDTO[] => {
return product.variants
?.map((variant) => {
if (!variant.calculated_price) {
return
}
return {
id: variant.id,
product_id: product.id,
product_name: product.title,
product_categories: product.categories?.map((c) => c.name),
// TODO: It is strange that we only accept a single category, revisit the tax module implementation
product_category_id: product.categories?.[0]?.id,
product_sku: variant.sku,
product_type: product.type,
product_type_id: product.type_id,
quantity: 1,
unit_price: variant.calculated_price.calculated_amount,
currency_code: variant.calculated_price.currency_code,
}
})
.filter((v) => !!v) as unknown as TaxableItemDTO[]
}

View File

@@ -3,8 +3,11 @@ import { MiddlewareRoute } from "../../../loaders/helpers/routing/types"
import { maybeApplyLinkFilter } from "../../utils/maybe-apply-link-filter"
import {
applyDefaultFilters,
clearFiltersByKey,
filterByValidSalesChannels,
normalizeDataForContext,
setPricingContext,
setTaxContext,
} from "../../utils/middlewares"
import { setContext } from "../../utils/middlewares/common/set-context"
import { validateAndTransformQuery } from "../../utils/validate-query"
@@ -47,7 +50,10 @@ export const storeProductRoutesMiddlewares: MiddlewareRoute[] = [
return { id: categoryIds, is_internal: false, is_active: true }
},
}),
normalizeDataForContext(),
setPricingContext(),
setTaxContext(),
clearFiltersByKey(["region_id", "country_code", "province", "cart_id"]),
],
},
{
@@ -78,7 +84,10 @@ export const storeProductRoutesMiddlewares: MiddlewareRoute[] = [
return { is_internal: false, is_active: true }
},
}),
normalizeDataForContext(),
setPricingContext(),
setTaxContext(),
clearFiltersByKey(["region_id", "country_code", "province", "cart_id"]),
],
},
]

View File

@@ -3,12 +3,13 @@ import {
isPresent,
remoteQueryObjectFromString,
} from "@medusajs/utils"
import { MedusaRequest, MedusaResponse } from "../../../types/routing"
import { MedusaResponse } from "../../../types/routing"
import { wrapVariantsWithInventoryQuantity } from "../../utils/middlewares"
import { StoreGetProductsParamsType } from "./validators"
import { RequestWithContext, wrapProductsWithTaxPrices } from "./helpers"
export const GET = async (
req: MedusaRequest<StoreGetProductsParamsType>,
req: RequestWithContext<StoreGetProductsParamsType>,
res: MedusaResponse
) => {
const remoteQuery = req.scope.resolve(ContainerRegistrationKeys.REMOTE_QUERY)
@@ -48,6 +49,7 @@ export const GET = async (
)
}
await wrapProductsWithTaxPrices(req, products)
res.json({
products,
count: metadata.count,

View File

@@ -9,7 +9,17 @@ import {
createSelectParams,
} from "../../utils/validators"
export const StoreGetProductParams = createSelectParams()
export type StoreGetProductParamsType = z.infer<typeof StoreGetProductParams>
export const StoreGetProductParams = createSelectParams().merge(
// These are used to populate the tax and pricing context
z.object({
region_id: z.string().optional(),
country_code: z.string().optional(),
province: z.string().optional(),
cart_id: z.string().optional(),
})
)
export type StoreGetProductVariantsParamsType = z.infer<
typeof StoreGetProductVariantsParams
@@ -38,8 +48,12 @@ export const StoreGetProductsParams = createFindParams({
}).merge(
z
.object({
// These are used to populate the tax and pricing context
region_id: z.string().optional(),
currency_code: z.string().optional(),
country_code: z.string().optional(),
province: z.string().optional(),
cart_id: z.string().optional(),
variants: z
.object({
status: ProductStatusEnum.array().optional(),

View File

@@ -0,0 +1,12 @@
import { NextFunction } from "express"
import { MedusaRequest } from "../../../../types/routing"
export function clearFiltersByKey(keys: string[]) {
return async (req: MedusaRequest, _, next: NextFunction) => {
keys.forEach((key) => {
delete req.filterableFields[key]
})
return next()
}
}

View File

@@ -1 +1,2 @@
export * from "./apply-default-filters"
export * from "./clear-filters-by-key"

View File

@@ -1,4 +1,5 @@
export * from "./filter-by-valid-sales-channels"
export * from "./normalize-data-for-context"
export * from "./set-pricing-context"
export * from "./set-tax-context"
export * from "./variant-inventory-quantity"

View File

@@ -0,0 +1,74 @@
import { MedusaError } from "@medusajs/utils"
import { NextFunction } from "express"
import { AuthenticatedMedusaRequest } from "../../../../types/routing"
import { refetchEntities, refetchEntity } from "../../refetch-entity"
export function normalizeDataForContext() {
return async (req: AuthenticatedMedusaRequest, _, next: NextFunction) => {
// If the product pricing is not requested, we don't need region information
let withCalculatedPrice = req.remoteQueryConfig.fields.some((field) =>
field.startsWith("variants.calculated_price")
)
// If the region is passed, we calculate the prices without requesting them.
// TODO: This seems a bit messy, reconsider if we want to keep this logic.
if (!withCalculatedPrice && req.filterableFields.region_id) {
req.remoteQueryConfig.fields.push("variants.calculated_price.*")
withCalculatedPrice = true
}
if (!withCalculatedPrice) {
return next()
}
// Region ID is required to calculate prices correctly.
// Country code, and optionally province, are needed to calculate taxes.
let regionId = req.filterableFields.region_id
let countryCode = req.filterableFields.country_code
let province = req.filterableFields.province
// If the cart is passed, get the information from it
if (req.filterableFields.cart_id) {
const cart = await refetchEntity(
"cart",
req.filterableFields.cart_id,
req.scope,
["region_id", "shipping_address.*"]
)
if (cart?.region_id) {
regionId = cart.region_id
}
if (cart?.shipping_address) {
countryCode = cart.shipping_address.country_code
province = cart.shipping_address.province
}
}
// Finally, try to get it from the store defaults if not available
if (!regionId) {
const stores = await refetchEntities("store", {}, req.scope, [
"default_region_id",
])
regionId = stores[0]?.default_region_id
}
if (!regionId) {
try {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Missing required pricing context to calculate prices - region_id`
)
} catch (e) {
return next(e)
}
}
req.filterableFields.region_id = regionId
req.filterableFields.country_code = countryCode
req.filterableFields.province = province
return next()
}
}

View File

@@ -1,5 +1,5 @@
import { MedusaPricingContext } from "@medusajs/types"
import { isPresent, MedusaError } from "@medusajs/utils"
import { MedusaError } from "@medusajs/utils"
import { NextFunction } from "express"
import { AuthenticatedMedusaRequest } from "../../../../types/routing"
import { refetchEntities, refetchEntity } from "../../refetch-entity"
@@ -9,106 +9,47 @@ export function setPricingContext() {
const withCalculatedPrice = req.remoteQueryConfig.fields.some((field) =>
field.startsWith("variants.calculated_price")
)
// If the endpoint doesn't pass region_id and currency_code, we can exit early
if (
!withCalculatedPrice &&
!req.filterableFields.region_id &&
!req.filterableFields.currency_code
) {
if (!withCalculatedPrice) {
return next()
}
// If the endpoint requested the field variants.calculated_price, we should throw
// an error if region or currency is not passed
if (
withCalculatedPrice &&
!req.filterableFields.region_id &&
!req.filterableFields.currency_code
) {
// We validate the region ID in the previous middleware
const region = await refetchEntity(
"region",
req.filterableFields.region_id!,
req.scope,
["id", "currency_code"]
)
if (!region) {
try {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Missing required pricing context to calculate prices - currency_code or region_id`
`Region with id ${req.filterableFields.region_id} not found when populating the pricing context`
)
} catch (e) {
return next(e)
}
}
const query = req.filterableFields || {}
const pricingContext: MedusaPricingContext = {}
const customerId = req.user?.customer_id
if (query.region_id) {
const region = await refetchEntity("region", query.region_id, req.scope, [
"id",
"currency_code",
])
if (region) {
pricingContext.region_id = region.id
}
if (region?.currency_code) {
pricingContext.currency_code = region.currency_code
}
delete req.filterableFields.region_id
}
// If a currency code is explicitly passed, we should be using that instead of the
// regions currency code
if (query.currency_code) {
const currency = await refetchEntity(
"currency",
{ code: query.currency_code },
req.scope,
["code"]
)
if (currency) {
pricingContext.currency_code = currency.code
}
delete req.filterableFields.currency_code
const pricingContext: MedusaPricingContext = {
region_id: region.id,
currency_code: region.currency_code,
}
// Find all the customer groups the customer is a part of and set
if (customerId) {
if (req.user?.customer_id) {
const customerGroups = await refetchEntities(
"customer_group",
{ customer_id: customerId },
{ customer_id: req.user?.customer_id },
req.scope,
["id"]
)
pricingContext.customer_group_id = customerGroups.map((cg) => cg.id)
delete req.filterableFields.customer_id
}
// If a currency_code is not present in the context, we will not be able to calculate prices
if (
!isPresent(pricingContext) ||
!isPresent(pricingContext.currency_code)
) {
try {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Valid pricing parameters (currency_code or region_id) are required to calculate prices`
)
} catch (e) {
return next(e)
}
}
req.pricingContext = pricingContext
if (!withCalculatedPrice) {
req.remoteQueryConfig.fields.push("variants.calculated_price.*")
}
return next()
}
}

View File

@@ -0,0 +1,73 @@
import { TaxCalculationContext } from "@medusajs/types"
import { NextFunction } from "express"
import {
AuthenticatedMedusaRequest,
MedusaRequest,
} from "../../../../types/routing"
import { refetchEntity } from "../../refetch-entity"
import { MedusaError } from "@medusajs/utils"
import { RequestWithContext } from "../../../store/products/helpers"
export function setTaxContext() {
return async (req: AuthenticatedMedusaRequest, _, next: NextFunction) => {
const withCalculatedPrice = req.remoteQueryConfig.fields.some((field) =>
field.startsWith("variants.calculated_price")
)
if (!withCalculatedPrice) {
return next()
}
try {
const inclusivity = await getTaxInclusivityInfo(req)
if (!inclusivity || !inclusivity.automaticTaxes) {
return next()
}
const taxLinesContext = await getTaxLinesContext(req)
// TODO: Allow passing a context typings param to AuthenticatedMedusaRequest
;(req as unknown as RequestWithContext<any>).taxContext = {
taxLineContext: taxLinesContext,
taxInclusivityContext: inclusivity,
}
return next()
} catch (e) {
next(e)
}
}
}
const getTaxInclusivityInfo = async (req: MedusaRequest) => {
const region = await refetchEntity(
"region",
req.filterableFields.region_id as string,
req.scope,
["automatic_taxes"]
)
if (!region) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Region with id ${req.filterableFields.region_id} not found when populating the tax context`
)
}
return {
automaticTaxes: region.automatic_taxes,
}
}
const getTaxLinesContext = async (req: MedusaRequest) => {
if (!req.filterableFields.country_code) {
return
}
const taxContext = {
address: {
country_code: req.filterableFields.country_code as string,
province_code: req.filterableFields.province as string,
},
} as TaxCalculationContext
return taxContext
}

View File

@@ -1449,6 +1449,39 @@ moduleIntegrationTestRunner<IPricingModuleService>({
])
})
it("should return the region tax inclusivity for the selected price when there are multiple region preferences", async () => {
await (service as any).createPricePreferences([
{
attribute: "region_id",
value: "DE",
is_tax_inclusive: false,
},
{
attribute: "region_id",
value: "PL",
is_tax_inclusive: true,
},
])
const priceSetsResult = await service.calculatePrices(
{ id: ["price-set-PLN"] },
{
context: { currency_code: "PLN", region_id: "PL" },
}
)
expect(priceSetsResult).toEqual([
expect.objectContaining({
id: "price-set-PLN",
is_calculated_price_tax_inclusive: true,
calculated_amount: 300,
is_original_price_tax_inclusive: true,
original_amount: 300,
currency_code: "PLN",
}),
])
})
it("should return the appropriate tax inclusive setting for each calculated and original price", async () => {
await createPriceLists(service, {}, {})
await (service as any).createPricePreferences([

View File

@@ -334,7 +334,9 @@ export default class PricingModuleService
is_calculated_price_price_list: !!calculatedPrice?.price_list_id,
is_calculated_price_tax_inclusive: isTaxInclusive(
priceRulesPriceMap.get(calculatedPrice.id),
pricingPreferences
pricingPreferences,
calculatedPrice.currency_code!,
pricingContext.context?.region_id as string
),
calculated_amount: parseInt(calculatedPrice?.amount || "") || null,
@@ -342,7 +344,9 @@ export default class PricingModuleService
is_original_price_tax_inclusive: originalPrice?.id
? isTaxInclusive(
priceRulesPriceMap.get(originalPrice.id),
pricingPreferences
pricingPreferences,
originalPrice.currency_code || calculatedPrice.currency_code!,
pricingContext.context?.region_id as string
)
: false,
original_amount: parseInt(originalPrice?.amount || "") || null,
@@ -1451,19 +1455,23 @@ export default class PricingModuleService
const isTaxInclusive = (
priceRules: PriceRule[],
preferences: PricePreference[]
preferences: PricePreference[],
currencyCode: string,
regionId?: string
) => {
const regionPreference = preferences.find((p) => p.attribute === "region_id")
const currencyPreference = preferences.find(
(p) => p.attribute === "currency_code"
const regionRule = priceRules?.find(
(rule) => rule.attribute === "region_id" && rule.value === regionId
)
const regionRule = priceRules?.find((rule) => rule.attribute === "region_id")
if (
regionRule &&
regionPreference &&
regionRule.value === regionPreference.value
) {
const regionPreference = preferences.find(
(p) => p.attribute === "region_id" && p.value === regionId
)
const currencyPreference = preferences.find(
(p) => p.attribute === "currency_code" && p.value === currencyCode
)
if (regionRule && regionPreference) {
return regionPreference.is_tax_inclusive
}