feat(medusa): Cart and totals computational optimizations (#2475)

This commit is contained in:
Adrien de Peretti
2022-11-14 15:45:47 +01:00
committed by GitHub
parent 755ba90c05
commit d2b1848e52
29 changed files with 3005 additions and 666 deletions

View File

@@ -1866,25 +1866,28 @@ describe("/admin/products", () => {
expect(res.status).toEqual(200)
expect(insertedVariant.prices).toEqual([
expect.objectContaining({
currency_code: "usd",
amount: 100,
min_quantity: null,
max_quantity: null,
variant_id: insertedVariant.id,
region_id: null,
}),
expect.objectContaining({
currency_code: "usd",
amount: 200,
min_quantity: null,
max_quantity: null,
price_list_id: null,
variant_id: insertedVariant.id,
region_id: "test-region",
}),
])
expect(insertedVariant.prices).toHaveLength(2)
expect(insertedVariant.prices).toEqual(
expect.arrayContaining([
expect.objectContaining({
currency_code: "usd",
amount: 100,
min_quantity: null,
max_quantity: null,
variant_id: insertedVariant.id,
region_id: null,
}),
expect.objectContaining({
currency_code: "usd",
amount: 200,
min_quantity: null,
max_quantity: null,
price_list_id: null,
variant_id: insertedVariant.id,
region_id: "test-region",
}),
])
)
})
})

View File

@@ -66,7 +66,9 @@ describe("/store/carts", () => {
tax_rate: 0,
})
await manager.query(
`UPDATE "country" SET region_id='region' WHERE iso_2 = 'us'`
`UPDATE "country"
SET region_id='region'
WHERE iso_2 = 'us'`
)
})
@@ -88,9 +90,12 @@ describe("/store/carts", () => {
const api = useApi()
await dbConnection.manager.query(
`UPDATE "country" SET region_id=null WHERE iso_2 = 'us'`
`UPDATE "country"
SET region_id=null
WHERE iso_2 = 'us'`
)
await dbConnection.manager.query(`DELETE from region`)
await dbConnection.manager.query(`DELETE
from region`)
try {
await api.post("/store/carts")
@@ -1679,7 +1684,9 @@ describe("/store/carts", () => {
const manager = dbConnection.manager
const api = useApi()
await manager.query(
`UPDATE "cart" SET completed_at=current_timestamp WHERE id = 'test-cart-2'`
`UPDATE "cart"
SET completed_at=current_timestamp
WHERE id = 'test-cart-2'`
)
try {
await api.post(`/store/carts/test-cart-2/complete-cart`)
@@ -1982,7 +1989,8 @@ describe("/store/carts", () => {
try {
await cartSeeder(dbConnection)
await dbConnection.manager.query(
`INSERT INTO "cart_discounts" (cart_id, discount_id) VALUES ('test-cart', 'free-shipping')`
`INSERT INTO "cart_discounts" (cart_id, discount_id)
VALUES ('test-cart', 'free-shipping')`
)
} catch (err) {
console.log(err)

View File

@@ -432,7 +432,6 @@ Object {
"claim_order_id": null,
"created_at": Any<Date>,
"description": "",
"discount_total": 0,
"fulfilled_quantity": 2,
"has_shipping": null,
"id": Any<String>,
@@ -440,13 +439,10 @@ Object {
"is_return": false,
"metadata": null,
"order_id": Any<String>,
"original_tax_total": 400,
"original_total": 2400,
"quantity": 2,
"returned_quantity": 1,
"shipped_quantity": 2,
"should_merge": true,
"subtotal": 2000,
"swap_id": null,
"tax_lines": Array [
Object {
@@ -460,10 +456,8 @@ Object {
"updated_at": Any<Date>,
},
],
"tax_total": 400,
"thumbnail": "",
"title": "Intelligent Plastic Chips",
"total": 2400,
"unit_price": 1000,
"updated_at": Any<Date>,
"variant": Object {
@@ -774,7 +768,6 @@ Object {
"claim_order_id": null,
"created_at": Any<Date>,
"description": "",
"discount_total": 0,
"fulfilled_quantity": null,
"has_shipping": null,
"id": "test-item",
@@ -782,14 +775,11 @@ Object {
"is_return": false,
"metadata": null,
"order_id": Any<String>,
"original_tax_total": 400,
"original_total": 2400,
"price": "10.00 USD",
"quantity": 2,
"returned_quantity": null,
"shipped_quantity": null,
"should_merge": true,
"subtotal": 2000,
"swap_id": null,
"tax_lines": Array [
Object {
@@ -803,10 +793,8 @@ Object {
"updated_at": Any<Date>,
},
],
"tax_total": 400,
"thumbnail": null,
"title": "Intelligent Plastic Chips",
"total": 2400,
"unit_price": 1000,
"updated_at": Any<Date>,
"variant": Object {
@@ -1003,7 +991,6 @@ Object {
"claim_order_id": null,
"created_at": Any<Date>,
"description": "",
"discount_total": 0,
"discounted_price": "12.00 USD",
"fulfilled_quantity": 2,
"has_shipping": null,
@@ -1012,14 +999,11 @@ Object {
"is_return": false,
"metadata": null,
"order_id": Any<String>,
"original_tax_total": 400,
"original_total": 2400,
"price": "12.00 USD",
"quantity": 2,
"returned_quantity": null,
"shipped_quantity": 2,
"should_merge": true,
"subtotal": 2000,
"swap_id": null,
"tax_lines": Array [
Object {
@@ -1033,10 +1017,8 @@ Object {
"updated_at": Any<Date>,
},
],
"tax_total": 400,
"thumbnail": null,
"title": "Intelligent Plastic Chips",
"total": 2400,
"totals": Object {
"discount_total": 0,
"original_tax_total": 400,
@@ -1276,7 +1258,6 @@ Object {
"claim_order_id": null,
"created_at": Any<Date>,
"description": "",
"discount_total": 0,
"fulfilled_quantity": 2,
"has_shipping": null,
"id": "test-item",
@@ -1284,13 +1265,10 @@ Object {
"is_return": false,
"metadata": null,
"order_id": Any<String>,
"original_tax_total": 400,
"original_total": 2400,
"quantity": 2,
"returned_quantity": null,
"shipped_quantity": 2,
"should_merge": true,
"subtotal": 2000,
"swap_id": null,
"tax_lines": Array [
Object {
@@ -1304,10 +1282,8 @@ Object {
"updated_at": Any<Date>,
},
],
"tax_total": 400,
"thumbnail": "",
"title": "Intelligent Plastic Chips",
"total": 2400,
"unit_price": 1000,
"updated_at": Any<Date>,
"variant": Object {
@@ -1603,7 +1579,6 @@ Object {
"claim_order_id": null,
"created_at": Any<Date>,
"description": "",
"discount_total": 0,
"fulfilled_quantity": 2,
"has_shipping": null,
"id": Any<String>,
@@ -1611,13 +1586,10 @@ Object {
"is_return": false,
"metadata": null,
"order_id": Any<String>,
"original_tax_total": 400,
"original_total": 2400,
"quantity": 2,
"returned_quantity": null,
"shipped_quantity": 2,
"should_merge": true,
"subtotal": 2000,
"swap_id": null,
"tax_lines": Array [
Object {
@@ -1631,10 +1603,8 @@ Object {
"updated_at": Any<Date>,
},
],
"tax_total": 400,
"thumbnail": "",
"title": "Intelligent Plastic Chips",
"total": 2400,
"unit_price": 1000,
"updated_at": Any<Date>,
"variant": Object {

View File

@@ -71,9 +71,15 @@ export default async (req, res) => {
relations: defaultAdminDraftOrdersRelations,
})
draftOrder.cart = await cartService.retrieveWithTotals(draftOrder.cart_id, {
relations: defaultAdminDraftOrdersCartRelations,
})
draftOrder.cart = await cartService.retrieveWithTotals(
draftOrder.cart_id,
{
relations: defaultAdminDraftOrdersCartRelations,
},
{
force_taxes: true,
}
)
res.json({ draft_order: draftOrder })
}

View File

@@ -83,16 +83,7 @@ export default async (req, res) => {
const cart = await cartService
.withTransaction(manager)
.retrieve(draftOrder.cart_id, {
select: ["total"],
relations: [
"discounts",
"discounts.rule",
"shipping_methods",
"region",
"items",
],
})
.retrieveWithTotals(draftOrder.cart_id)
await paymentProviderService
.withTransaction(manager)

View File

@@ -26,11 +26,23 @@ describe("GET /admin/orders", () => {
})
it("calls orderService retrieve", () => {
expect(OrderServiceMock.retrieve).toHaveBeenCalledTimes(1)
expect(OrderServiceMock.retrieve).toHaveBeenCalledWith(
expect(OrderServiceMock.retrieveWithTotals).toHaveBeenCalledTimes(1)
expect(OrderServiceMock.retrieveWithTotals).toHaveBeenCalledWith(
IdMap.getId("test-order"),
{
select: defaultAdminOrdersFields,
select: defaultAdminOrdersFields.filter((field) => {
return ![
"shipping_total",
"discount_total",
"tax_total",
"refunded_total",
"total",
"subtotal",
"refundable_amount",
"gift_card_total",
"gift_card_tax_total",
].includes(field)
}),
relations: defaultAdminOrdersRelations,
}
)

View File

@@ -56,7 +56,7 @@ export default async (req, res) => {
const orderService: OrderService = req.scope.resolve("orderService")
const order = await orderService.retrieve(id, req.retrieveConfig)
const order = await orderService.retrieveWithTotals(id, req.retrieveConfig)
res.json({ order })
}

View File

@@ -43,7 +43,19 @@ export default (app, featureFlagRouter: FlagRouter) => {
"/:id",
transformQuery(FindParams, {
defaultRelations: relations,
defaultFields: defaultAdminOrdersFields,
defaultFields: defaultAdminOrdersFields.filter((field) => {
return ![
"shipping_total",
"discount_total",
"tax_total",
"refunded_total",
"total",
"subtotal",
"refundable_amount",
"gift_card_total",
"gift_card_tax_total",
].includes(field)
}),
allowedFields: allowedAdminOrdersFields,
allowedRelations: allowedAdminOrdersRelations,
isList: false,

View File

@@ -0,0 +1,72 @@
import { asClass, asValue, createContainer } from "awilix"
import { IdMap, MockManager } from "medusa-test-utils"
import { taxProviderServiceMock } from "../__mocks__/tax-provider"
import { FlagRouter } from "../../utils/flag-router"
import {
GiftCard,
LineItem,
LineItemTaxLine,
ShippingMethod,
ShippingMethodTaxLine,
} from "../../models"
import TaxCalculationStrategy from "../../strategies/tax-calculation"
export const defaultContainerMock = createContainer()
defaultContainerMock.register("manager", asValue(MockManager))
defaultContainerMock.register(
"taxProviderService",
asValue(taxProviderServiceMock)
)
defaultContainerMock.register("featureFlagRouter", asValue(new FlagRouter({})))
defaultContainerMock.register(
"taxCalculationStrategy",
asClass(TaxCalculationStrategy)
)
export const lineItems = [
{
id: IdMap.getId("item_1_with_tax_lines"),
cart_id: "",
order_id: "",
swap_id: "",
claim_order_id: "",
title: "title",
description: "description",
unit_price: 1000,
quantity: 1,
tax_lines: [
{
id: IdMap.getId("item_1_with_tax_lines_tax_line_1"),
item_id: IdMap.getId("item_1_with_tax_lines"),
rate: 20,
name: "default",
code: "default",
},
] as LineItemTaxLine[],
},
] as LineItem[]
export const shippingMethods = [
{
id: IdMap.getId("sm_1_with_tax_lines"),
price: 1000,
tax_lines: [
{
id: IdMap.getId("sm_1_with_tax_lines_tax_line_1"),
shipping_method_id: IdMap.getId("sm_1_with_tax_lines"),
rate: 20,
name: "default",
code: "default",
},
] as ShippingMethodTaxLine[],
},
] as ShippingMethod[]
export const giftCards = [
{
id: IdMap.getId("gift_card_1"),
code: "CODE",
value: 10000,
balance: 10000,
},
] as GiftCard[]

View File

@@ -0,0 +1,25 @@
export const newTotalsServiceMock = {
withTransaction: function () {
return this
},
getLineItemTotals: jest.fn().mockImplementation(() => {
return Promise.resolve({})
}),
getGiftCardTotals: jest.fn().mockImplementation((order, lineItems) => {
return Promise.resolve({})
}),
getGiftCardTransactionsTotals: jest
.fn()
.mockImplementation((order, lineItems) => {
return Promise.resolve({})
}),
getShippingMethodTotals: jest.fn().mockImplementation((order, lineItems) => {
return Promise.resolve({})
}),
}
const mock = jest.fn().mockImplementation(() => {
return newTotalsServiceMock
})
export default mock

View File

@@ -170,6 +170,15 @@ export const OrderServiceMock = {
}
return Promise.resolve(undefined)
}),
retrieveWithTotals: jest.fn().mockImplementation((orderId) => {
if (orderId === IdMap.getId("test-order")) {
return Promise.resolve(orders.testOrder)
}
if (orderId === IdMap.getId("processed-order")) {
return Promise.resolve(orders.processedOrder)
}
return Promise.resolve(undefined)
}),
retrieveByCartId: jest.fn().mockImplementation((cartId) => {
return Promise.resolve({ id: IdMap.getId("test-order") })
}),

View File

@@ -8,6 +8,15 @@ export const taxProviderServiceMock = {
clearLineItemsTaxLines: jest.fn().mockImplementation((_) => {
return Promise.resolve()
}),
getTaxLines: jest.fn().mockImplementation((_) => {
return Promise.resolve([])
}),
getTaxLinesMap: jest.fn().mockImplementation((_) => {
return Promise.resolve({
lineItemsTaxLines: {},
shippingMethodsTaxLines: {},
})
}),
}
const mock = jest.fn().mockImplementation(() => {

View File

@@ -6,6 +6,7 @@ import { InventoryServiceMock } from "../__mocks__/inventory"
import { LineItemAdjustmentServiceMock } from "../__mocks__/line-item-adjustment"
import { FlagRouter } from "../../utils/flag-router"
import { taxProviderServiceMock } from "../__mocks__/tax-provider"
import { newTotalsServiceMock } from "../__mocks__/new-totals"
const eventBusService = {
emit: jest.fn(),
@@ -59,6 +60,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
result = await cartService.retrieve(IdMap.getId("emptyCart"))
@@ -93,6 +95,7 @@ describe("CartService", () => {
cartRepository,
eventBusService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -180,6 +183,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
customerService,
newTotalsService: newTotalsServiceMock,
regionService,
eventBusService,
taxProviderService: taxProviderServiceMock,
@@ -351,6 +355,7 @@ describe("CartService", () => {
cartRepository,
lineItemService,
lineItemRepository: MockRepository(),
newTotalsService: newTotalsServiceMock,
eventBusService,
shippingOptionService,
inventoryService,
@@ -585,6 +590,7 @@ describe("CartService", () => {
cartRepository,
lineItemService,
lineItemRepository: MockRepository(),
newTotalsService: newTotalsServiceMock,
eventBusService,
shippingOptionService,
inventoryService,
@@ -683,6 +689,7 @@ describe("CartService", () => {
cartRepository,
lineItemService,
lineItemRepository: MockRepository(),
newTotalsService: newTotalsServiceMock,
shippingOptionService,
eventBusService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
@@ -794,6 +801,7 @@ describe("CartService", () => {
totalsService,
eventBusService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -880,6 +888,7 @@ describe("CartService", () => {
cartRepository,
lineItemService,
eventBusService,
newTotalsService: newTotalsServiceMock,
inventoryService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
taxProviderService: taxProviderServiceMock,
@@ -965,6 +974,7 @@ describe("CartService", () => {
cartRepository,
eventBusService,
customerService,
newTotalsService: newTotalsServiceMock,
taxProviderService: taxProviderServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -1041,6 +1051,7 @@ describe("CartService", () => {
cartRepository,
addressRepository,
eventBusService,
newTotalsService: newTotalsServiceMock,
taxProviderService: taxProviderServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -1101,6 +1112,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
newTotalsService: newTotalsServiceMock,
taxProviderService: taxProviderServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -1247,6 +1259,7 @@ describe("CartService", () => {
addressRepository,
totalsService,
cartRepository,
newTotalsService: newTotalsServiceMock,
regionService,
lineItemService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
@@ -1343,6 +1356,7 @@ describe("CartService", () => {
cartRepository,
eventBusService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -1467,6 +1481,7 @@ describe("CartService", () => {
paymentProviderService,
eventBusService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -1658,6 +1673,7 @@ describe("CartService", () => {
lineItemService,
eventBusService,
customShippingOptionService,
newTotalsService: newTotalsServiceMock,
taxProviderService: taxProviderServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -2015,6 +2031,7 @@ describe("CartService", () => {
eventBusService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})
@@ -2289,6 +2306,7 @@ describe("CartService", () => {
cartRepository,
eventBusService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
featureFlagRouter: new FlagRouter({}),
})

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,8 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils"
import OrderService from "../order"
import { InventoryServiceMock } from "../__mocks__/inventory"
import { LineItemServiceMock } from "../__mocks__/line-item"
import { newTotalsServiceMock } from "../__mocks__/new-totals"
import { taxProviderServiceMock } from "../__mocks__/tax-provider"
describe("OrderService", () => {
const totalsService = {
@@ -141,6 +143,7 @@ describe("OrderService", () => {
paymentProviderService,
shippingOptionService,
totalsService,
newTotalsService: newTotalsServiceMock,
discountService,
regionService,
eventBusService,
@@ -184,6 +187,8 @@ describe("OrderService", () => {
{ id: "item_2", variant_id: "variant-2", quantity: 1 },
],
total: 100,
subtotal: 100,
discount_total: 0,
}
orderService.cartService_.retrieveWithTotals = jest.fn(() =>
@@ -209,17 +214,7 @@ describe("OrderService", () => {
expect(cartService.retrieveWithTotals).toHaveBeenCalledTimes(1)
expect(cartService.retrieveWithTotals).toHaveBeenCalledWith("cart_id", {
relations: [
"region",
"payment",
"items",
"discounts",
"discounts.rule",
"gift_cards",
"shipping_methods",
"items",
"items.adjustments",
],
relations: ["region", "payment"],
})
expect(paymentProviderService.updatePayment).toHaveBeenCalledTimes(1)
@@ -288,6 +283,7 @@ describe("OrderService", () => {
],
subtotal: 100,
total: 100,
discount_total: 0,
}
orderService.cartService_.retrieveWithTotals = () => {
@@ -380,9 +376,11 @@ describe("OrderService", () => {
{ id: "item_2", variant_id: "variant-2", quantity: 1 },
],
total: 0,
subtotal: 0,
discount_total: 0,
}
orderService.cartService_.retrieveWithTotals = () => Promise.resolve(cart)
await orderService.createFromCart(cart)
await orderService.createFromCart("cart_id")
const order = {
payment_status: "awaiting",
email: cart.email,
@@ -462,6 +460,7 @@ describe("OrderService", () => {
manager: MockManager,
orderRepository: orderRepo,
totalsService,
newTotalsService: newTotalsServiceMock,
})
beforeAll(async () => {
@@ -485,6 +484,7 @@ describe("OrderService", () => {
})
const orderService = new OrderService({
totalsService,
newTotalsService: newTotalsServiceMock,
manager: MockManager,
orderRepository: orderRepo,
})
@@ -527,6 +527,7 @@ describe("OrderService", () => {
})
const orderService = new OrderService({
totalsService,
newTotalsService: newTotalsServiceMock,
manager: MockManager,
orderRepository: orderRepo,
eventBusService,
@@ -638,6 +639,7 @@ describe("OrderService", () => {
const orderService = new OrderService({
totalsService,
newTotalsService: newTotalsServiceMock,
manager: MockManager,
orderRepository: orderRepo,
paymentProviderService,
@@ -738,6 +740,7 @@ describe("OrderService", () => {
orderRepository: orderRepo,
paymentProviderService,
totalsService,
newTotalsService: newTotalsServiceMock,
eventBusService,
})
@@ -857,6 +860,7 @@ describe("OrderService", () => {
fulfillmentService,
lineItemService,
totalsService,
newTotalsService: newTotalsServiceMock,
eventBusService,
})
@@ -1092,6 +1096,7 @@ describe("OrderService", () => {
orderRepository: orderRepo,
paymentProviderService,
totalsService,
newTotalsService: newTotalsServiceMock,
eventBusService,
})
@@ -1234,6 +1239,8 @@ describe("OrderService", () => {
eventBusService: eventBusService,
shippingOptionService: optionService,
totalsService,
taxProviderService: taxProviderServiceMock,
newTotalsService: newTotalsServiceMock,
})
beforeEach(async () => {
@@ -1254,8 +1261,14 @@ describe("OrderService", () => {
{ some: "data" },
{
order: {
discount_total: 0,
gift_card_tax_total: 0,
gift_card_total: 0,
id: IdMap.getId("order"),
items: [],
paid_total: 0,
refundable_amount: 0,
refunded_total: 0,
shipping_methods: [
{
shipping_option: {
@@ -1263,7 +1276,10 @@ describe("OrderService", () => {
},
},
],
shipping_total: 0,
subtotal: 0,
tax_total: 0,
total: 0,
},
}
)
@@ -1284,8 +1300,14 @@ describe("OrderService", () => {
{ some: "data" },
{
order: {
discount_total: 0,
gift_card_tax_total: 0,
gift_card_total: 0,
id: IdMap.getId("order"),
items: [],
paid_total: 0,
refundable_amount: 0,
refunded_total: 0,
shipping_methods: [
{
shipping_option: {
@@ -1293,7 +1315,10 @@ describe("OrderService", () => {
},
},
],
shipping_total: 0,
subtotal: 0,
tax_total: 0,
total: 0,
},
}
)
@@ -1391,6 +1416,7 @@ describe("OrderService", () => {
manager: MockManager,
orderRepository: orderRepo,
totalsService,
newTotalsService: newTotalsServiceMock,
fulfillmentService,
lineItemService,
eventBusService,
@@ -1511,6 +1537,7 @@ describe("OrderService", () => {
orderRepository: orderRepo,
paymentProviderService,
totalsService,
newTotalsService: newTotalsServiceMock,
eventBusService,
})

View File

@@ -1,4 +1,4 @@
import { IdMap, MockRepository, MockManager } from "medusa-test-utils"
import { IdMap, MockManager, MockRepository } from "medusa-test-utils"
import SwapService from "../swap"
import { InventoryServiceMock } from "../__mocks__/inventory"
@@ -16,7 +16,7 @@ import {
TotalsService,
} from "../index"
import CartService from "../cart"
import { Order, ReturnItem, Swap } from "../../models"
import { Order, Swap } from "../../models"
import { SwapRepository } from "../../repositories/swap"
import LineItemAdjustmentService from "../line-item-adjustment"
@@ -49,6 +49,9 @@ const cartService = {
withTransaction: function () {
return this
},
retrieveWithTotals: jest
.fn()
.mockReturnValue(Promise.resolve({ id: "cart" })),
} as unknown as CartService
const customShippingOptionService = {
@@ -826,6 +829,12 @@ describe("SwapService", () => {
withTransaction: function () {
return this
},
retrieveWithTotals: jest.fn().mockReturnValue(
Promise.resolve({
id: "cart",
items: [{ id: "test-item", variant_id: "variant" }],
})
),
} as unknown as CartService
const paymentProviderService = {
@@ -864,7 +873,8 @@ describe("SwapService", () => {
other: "data",
}
cartService.retrieve = (() => cart) as unknown as CartService["retrieve"]
cartService.retrieveWithTotals = (() =>
cart) as unknown as CartService["retrieveWithTotals"]
const swapRepo = MockRepository({
findOneWithRelations: () => Promise.resolve(existing),

View File

@@ -24,6 +24,7 @@ import {
CartCreateProps,
CartUpdateProps,
FilterableCartProps,
isCart,
LineItemUpdate,
} from "../types/cart"
import { AddressPayload, FindConfig, TotalField } from "../types/common"
@@ -35,7 +36,7 @@ import CustomerService from "./customer"
import DiscountService from "./discount"
import EventBusService from "./event-bus"
import GiftCardService from "./gift-card"
import { SalesChannelService } from "./index"
import { NewTotalsService, SalesChannelService } from "./index"
import InventoryService from "./inventory"
import LineItemService from "./line-item"
import LineItemAdjustmentService from "./line-item-adjustment"
@@ -70,6 +71,7 @@ type InjectedDependencies = {
discountService: DiscountService
giftCardService: GiftCardService
totalsService: TotalsService
newTotalsService: NewTotalsService
inventoryService: InventoryService
customShippingOptionService: CustomShippingOptionService
lineItemAdjustmentService: LineItemAdjustmentService
@@ -112,6 +114,7 @@ class CartService extends TransactionBaseService {
protected readonly giftCardService_: GiftCardService
protected readonly taxProviderService_: TaxProviderService
protected readonly totalsService_: TotalsService
protected readonly newTotalsService_: NewTotalsService
protected readonly inventoryService_: InventoryService
protected readonly customShippingOptionService_: CustomShippingOptionService
protected readonly priceSelectionStrategy_: IPriceSelectionStrategy
@@ -135,6 +138,7 @@ class CartService extends TransactionBaseService {
discountService,
giftCardService,
totalsService,
newTotalsService,
addressRepository,
paymentSessionRepository,
inventoryService,
@@ -163,6 +167,7 @@ class CartService extends TransactionBaseService {
this.discountService_ = discountService
this.giftCardService_ = giftCardService
this.totalsService_ = totalsService
this.newTotalsService_ = newTotalsService
this.addressRepository_ = addressRepository
this.paymentSessionRepository_ = paymentSessionRepository
this.inventoryService_ = inventoryService
@@ -175,126 +180,6 @@ class CartService extends TransactionBaseService {
this.storeService_ = storeService
}
private getTotalsRelations(config: FindConfig<Cart>): string[] {
const relationSet = new Set(config.relations)
relationSet.add("items")
relationSet.add("items.tax_lines")
relationSet.add("items.adjustments")
relationSet.add("gift_cards")
relationSet.add("discounts")
relationSet.add("discounts.rule")
relationSet.add("shipping_methods")
relationSet.add("shipping_methods.tax_lines")
relationSet.add("shipping_address")
relationSet.add("region")
relationSet.add("region.tax_rates")
return Array.from(relationSet.values())
}
protected transformQueryForTotals_(
config: FindConfig<Cart>
): FindConfig<Cart> & { totalsToSelect: TotalField[] } {
let { select, relations } = config
if (!select) {
return {
select,
relations,
totalsToSelect: [],
}
}
const totalFields = [
"subtotal",
"tax_total",
"shipping_total",
"discount_total",
"gift_card_total",
"total",
]
const totalsToSelect = select.filter((v) =>
totalFields.includes(v)
) as TotalField[]
if (totalsToSelect.length > 0) {
const relationSet = new Set(relations)
relationSet.add("items")
relationSet.add("items.tax_lines")
relationSet.add("gift_cards")
relationSet.add("discounts")
relationSet.add("discounts.rule")
// relationSet.add("discounts.parent_discount")
// relationSet.add("discounts.parent_discount.rule")
// relationSet.add("discounts.parent_discount.regions")
relationSet.add("shipping_methods")
relationSet.add("shipping_address")
relationSet.add("region")
relationSet.add("region.tax_rates")
relations = Array.from(relationSet.values())
select = select.filter((v) => !totalFields.includes(v))
}
return {
relations,
select,
totalsToSelect,
}
}
protected async decorateTotals_(
cart: Cart,
totalsToSelect: TotalField[],
options: TotalsConfig = { force_taxes: false }
): Promise<Cart> {
const totals: { [K in TotalField]?: number | null } = {}
for (const key of totalsToSelect) {
switch (key) {
case "total": {
totals.total = await this.totalsService_.getTotal(cart, {
force_taxes: options.force_taxes,
})
break
}
case "shipping_total": {
totals.shipping_total = await this.totalsService_.getShippingTotal(
cart
)
break
}
case "discount_total":
totals.discount_total = await this.totalsService_.getDiscountTotal(
cart
)
break
case "tax_total":
totals.tax_total = await this.totalsService_.getTaxTotal(
cart,
options.force_taxes
)
break
case "gift_card_total": {
const giftCardBreakdown = await this.totalsService_.getGiftCardTotal(
cart
)
totals.gift_card_total = giftCardBreakdown.total
totals.gift_card_tax_total = giftCardBreakdown.tax_total
break
}
case "subtotal":
totals.subtotal = await this.totalsService_.getSubtotal(cart)
break
default:
break
}
}
return Object.assign(cart, totals)
}
/**
* @param selector - the query object for find
* @param config - config object
@@ -315,12 +200,55 @@ class CartService extends TransactionBaseService {
* Gets a cart by id.
* @param cartId - the id of the cart to get.
* @param options - the options to get a cart
* @param totalsConfig
* @return the cart document.
*/
async retrieve(
cartId: string,
options: FindConfig<Cart> = {},
totalsConfig: TotalsConfig = {}
): Promise<Cart> {
const { totalsToSelect } = this.transformQueryForTotals_(options)
if (totalsToSelect.length) {
return await this.retrieveLegacy(cartId, options, totalsConfig)
}
const manager = this.manager_
const cartRepo = manager.getCustomRepository(this.cartRepository_)
const query = buildQuery({ id: cartId }, options)
if ((options.select || []).length === 0) {
query.select = undefined
}
const queryRelations = query.relations
query.relations = undefined
const raw = await cartRepo.findOneWithRelations(queryRelations, query)
if (!raw) {
throw new MedusaError(
MedusaError.Types.NOT_FOUND,
`Cart with ${cartId} was not found`
)
}
return raw
}
/**
* @deprecated
* @param cartId
* @param options
* @param totalsConfig
* @protected
*/
protected async retrieveLegacy(
cartId: string,
options: FindConfig<Cart> = {},
totalsConfig: TotalsConfig = {}
): Promise<Cart> {
const manager = this.manager_
const cartRepo = manager.getCustomRepository(this.cartRepository_)
@@ -349,34 +277,6 @@ class CartService extends TransactionBaseService {
return await this.decorateTotals_(raw, totalsToSelect, totalsConfig)
}
private async retrieveNew(
cartId: string,
options: FindConfig<Cart> = {}
): Promise<Cart> {
const manager = this.manager_
const cartRepo = manager.getCustomRepository(this.cartRepository_)
const query = buildQuery({ id: cartId }, options)
if ((options.select || []).length <= 0) {
query.select = undefined
}
const queryRelations = query.relations
query.relations = undefined
const raw = await cartRepo.findOneWithRelations(queryRelations, query)
if (!raw) {
throw new MedusaError(
MedusaError.Types.NOT_FOUND,
`Cart with ${cartId} was not found`
)
}
return raw
}
async retrieveWithTotals(
cartId: string,
options: FindConfig<Cart> = {},
@@ -384,7 +284,7 @@ class CartService extends TransactionBaseService {
): Promise<Cart> {
const relations = this.getTotalsRelations(options)
const cart = await this.retrieveNew(cartId, {
const cart = await this.retrieve(cartId, {
...options,
relations,
})
@@ -1413,7 +1313,9 @@ class CartService extends TransactionBaseService {
*/
async authorizePayment(
cartId: string,
context: Record<string, unknown> = {}
context: Record<string, unknown> & {
cart_id: string
} = { cart_id: "" }
): Promise<Cart> {
return await this.atomicPhase_(
async (transactionManager: EntityManager) => {
@@ -1421,27 +1323,23 @@ class CartService extends TransactionBaseService {
this.cartRepository_
)
const cart = await this.retrieve(cartId, {
select: ["total"],
relations: [
"items",
"items.adjustments",
"region",
"payment_sessions",
],
const cart = await this.retrieveWithTotals(cartId, {
relations: ["payment_sessions"],
})
if (typeof cart.total === "undefined") {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"cart.total should be defined"
)
}
// If cart total is 0, we don't perform anything payment related
if (cart.total <= 0) {
if (cart.total! <= 0) {
cart.payment_authorized_at = new Date()
return await cartRepository.save(cart)
await cartRepository.save({
id: cart.id,
payment_authorized_at: cart.payment_authorized_at,
})
await this.eventBus_
.withTransaction(transactionManager)
.emit(CartService.Events.UPDATED, cart)
return cart
}
if (!cart.payment_session) {
@@ -1456,21 +1354,27 @@ class CartService extends TransactionBaseService {
.authorizePayment(cart.payment_session, context)) as PaymentSession
const freshCart = (await this.retrieve(cart.id, {
select: ["total"],
relations: ["payment_sessions", "items", "items.adjustments"],
relations: ["payment_sessions"],
})) as Cart & { payment_session: PaymentSession }
if (session.status === "authorized") {
freshCart.payment = await this.paymentProviderService_
.withTransaction(transactionManager)
.createPayment(freshCart)
.createPayment({
cart_id: cart.id,
currency_code: cart.region.currency_code,
amount: cart.total!,
payment_session: freshCart.payment_session,
})
freshCart.payment_authorized_at = new Date()
}
const updatedCart = await cartRepository.save(freshCart)
await this.eventBus_
.withTransaction(transactionManager)
.emit(CartService.Events.UPDATED, updatedCart)
return updatedCart
}
)
@@ -2147,23 +2051,25 @@ class CartService extends TransactionBaseService {
)
}
async createTaxLines(id: string): Promise<Cart> {
async createTaxLines(cartOrId: string | Cart): Promise<void> {
return await this.atomicPhase_(
async (transactionManager: EntityManager) => {
const cart = await this.retrieve(id, {
relations: [
"customer",
"discounts",
"discounts.rule",
"gift_cards",
"items",
"items.adjustments",
"region",
"region.tax_rates",
"shipping_address",
"shipping_methods",
],
})
const cart = isCart(cartOrId)
? cartOrId
: await this.retrieve(cartOrId, {
relations: [
"customer",
"discounts",
"discounts.rule",
"gift_cards",
"items",
"items.adjustments",
"region",
"region.tax_rates",
"shipping_address",
"shipping_methods",
],
})
const calculationContext = await this.totalsService_
.withTransaction(transactionManager)
@@ -2172,8 +2078,6 @@ class CartService extends TransactionBaseService {
await this.taxProviderService_
.withTransaction(transactionManager)
.createTaxLines(cart, calculationContext)
return cart
}
)
}
@@ -2197,63 +2101,86 @@ class CartService extends TransactionBaseService {
)
}
async decorateTotals(cart: Cart, totalsConfig?: TotalsConfig): Promise<Cart> {
const totalsService = this.totalsService_
async decorateTotals(
cart: Cart,
totalsConfig: TotalsConfig = {}
): Promise<Cart> {
const manager = this.transactionManager_ ?? this.manager_
const newTotalsServiceTx = this.newTotalsService_.withTransaction(manager)
const calculationContext = await totalsService.getCalculationContext(cart, {
exclude_shipping: true,
})
const calculationContext = await this.totalsService_.getCalculationContext(
cart
)
const includeTax = totalsConfig?.force_taxes || cart.region?.automatic_taxes
const cartItems = [...(cart.items ?? [])]
const cartShippingMethods = [...(cart.shipping_methods ?? [])]
cart.items = await Promise.all(
(cart.items || []).map(async (item) => {
const itemTotals = await totalsService.getLineItemTotals(item, cart, {
include_tax: totalsConfig?.force_taxes || cart.region.automatic_taxes,
calculation_context: calculationContext,
})
if (includeTax) {
const taxLinesMaps = await this.taxProviderService_
.withTransaction(manager)
.getTaxLinesMap(cartItems, calculationContext)
return Object.assign(item, itemTotals)
cartItems.forEach((item) => {
if (item.is_return) {
return
}
item.tax_lines = taxLinesMaps.lineItemsTaxLines[item.id] ?? []
})
cartShippingMethods.forEach((method) => {
method.tax_lines = taxLinesMaps.shippingMethodsTaxLines[method.id] ?? []
})
}
const itemsTotals = await newTotalsServiceTx.getLineItemTotals(cartItems, {
includeTax,
calculationContext,
})
const shippingTotals = await newTotalsServiceTx.getShippingMethodTotals(
cartShippingMethods,
{
discounts: cart.discounts,
includeTax,
calculationContext,
}
)
cart.shipping_methods = await Promise.all(
(cart.shipping_methods || []).map(async (shippingMethod) => {
const shippingTotals = await totalsService.getShippingMethodTotals(
cart.subtotal = 0
cart.discount_total = 0
cart.item_tax_total = 0
cart.shipping_total = 0
cart.shipping_tax_total = 0
cart.items = (cart.items || []).map((item) => {
const itemWithTotals = Object.assign(item, itemsTotals[item.id] ?? {})
cart.subtotal! += itemWithTotals.subtotal ?? 0
cart.discount_total! += itemWithTotals.discount_total ?? 0
cart.item_tax_total! += itemWithTotals.tax_total ?? 0
return itemWithTotals
})
cart.shipping_methods = (cart.shipping_methods || []).map(
(shippingMethod) => {
const methodWithTotals = Object.assign(
shippingMethod,
cart,
{
include_tax:
totalsConfig?.force_taxes || cart.region.automatic_taxes,
calculation_context: calculationContext,
}
shippingTotals[shippingMethod.id] ?? {}
)
return Object.assign(shippingMethod, shippingTotals)
})
cart.shipping_total! += methodWithTotals.subtotal ?? 0
cart.shipping_tax_total! += methodWithTotals.tax_total ?? 0
return methodWithTotals
}
)
cart.shipping_total = cart.shipping_methods.reduce((acc, method) => {
return acc + (method.subtotal ?? 0)
}, 0)
cart.subtotal = cart.items.reduce((acc, item) => {
return acc + (item.subtotal ?? 0)
}, 0)
cart.discount_total = cart.items.reduce((acc, item) => {
return acc + (item.discount_total ?? 0)
}, 0)
cart.item_tax_total = cart.items.reduce((acc, item) => {
return acc + (item.tax_total ?? 0)
}, 0)
cart.shipping_tax_total = cart.shipping_methods.reduce((acc, method) => {
return acc + (method.tax_total ?? 0)
}, 0)
const giftCardTotal = await totalsService.getGiftCardTotal(cart, {
gift_cardable: cart.subtotal - cart.discount_total,
})
const giftCardTotal = await this.newTotalsService_.getGiftCardTotals(
cart.subtotal - cart.discount_total,
{
region: cart.region,
giftCards: cart.gift_cards,
}
)
cart.gift_card_total = giftCardTotal.total || 0
cart.gift_card_tax_total = giftCardTotal.tax_total || 0
@@ -2287,6 +2214,133 @@ class CartService extends TransactionBaseService {
.withTransaction(transactionManager)
.createAdjustments(cart)
}
protected transformQueryForTotals_(
config: FindConfig<Cart>
): FindConfig<Cart> & { totalsToSelect: TotalField[] } {
let { select, relations } = config
if (!select) {
return {
select,
relations,
totalsToSelect: [],
}
}
const totalFields = [
"subtotal",
"tax_total",
"shipping_total",
"discount_total",
"gift_card_total",
"total",
]
const totalsToSelect = select.filter((v) =>
totalFields.includes(v)
) as TotalField[]
if (totalsToSelect.length > 0) {
const relationSet = new Set(relations)
relationSet.add("items")
relationSet.add("items.tax_lines")
relationSet.add("gift_cards")
relationSet.add("discounts")
relationSet.add("discounts.rule")
// relationSet.add("discounts.parent_discount")
// relationSet.add("discounts.parent_discount.rule")
// relationSet.add("discounts.parent_discount.regions")
relationSet.add("shipping_methods")
relationSet.add("shipping_address")
relationSet.add("region")
relationSet.add("region.tax_rates")
relations = Array.from(relationSet.values())
select = select.filter((v) => !totalFields.includes(v))
}
return {
relations,
select,
totalsToSelect,
}
}
/**
* @deprecated Use decorateTotals instead
* @param cart
* @param totalsToSelect
* @param options
* @protected
*/
protected async decorateTotals_(
cart: Cart,
totalsToSelect: TotalField[],
options: TotalsConfig = { force_taxes: false }
): Promise<Cart> {
const totals: { [K in TotalField]?: number | null } = {}
for (const key of totalsToSelect) {
switch (key) {
case "total": {
totals.total = await this.totalsService_.getTotal(cart, {
force_taxes: options.force_taxes,
})
break
}
case "shipping_total": {
totals.shipping_total = await this.totalsService_.getShippingTotal(
cart
)
break
}
case "discount_total":
totals.discount_total = await this.totalsService_.getDiscountTotal(
cart
)
break
case "tax_total":
totals.tax_total = await this.totalsService_.getTaxTotal(
cart,
options.force_taxes
)
break
case "gift_card_total": {
const giftCardBreakdown = await this.totalsService_.getGiftCardTotal(
cart
)
totals.gift_card_total = giftCardBreakdown.total
totals.gift_card_tax_total = giftCardBreakdown.tax_total
break
}
case "subtotal":
totals.subtotal = await this.totalsService_.getSubtotal(cart)
break
default:
break
}
}
return Object.assign(cart, totals)
}
private getTotalsRelations(config: FindConfig<Cart>): string[] {
const relationSet = new Set(config.relations)
relationSet.add("items")
relationSet.add("items.tax_lines")
relationSet.add("items.adjustments")
relationSet.add("gift_cards")
relationSet.add("discounts")
relationSet.add("discounts.rule")
relationSet.add("shipping_methods")
relationSet.add("shipping_methods.tax_lines")
relationSet.add("shipping_address")
relationSet.add("region")
relationSet.add("region.tax_rates")
return Array.from(relationSet.values())
}
}
export default CartService

View File

@@ -49,4 +49,5 @@ export { default as SystemPaymentProviderService } from "./system-payment-provid
export { default as TaxProviderService } from "./tax-provider"
export { default as TaxRateService } from "./tax-rate"
export { default as TotalsService } from "./totals"
export { default as NewTotalsService } from "./new-totals"
export { default as UserService } from "./user"

View File

@@ -71,24 +71,22 @@ class InventoryService extends TransactionBaseService {
return true
}
return await this.atomicPhase_(async (manager) => {
const variant = await this.productVariantService_
.withTransaction(manager)
.retrieve(variantId)
const { inventory_quantity, allow_backorder, manage_inventory } = variant
const isCovered =
!manage_inventory || allow_backorder || inventory_quantity >= quantity
const variant = await this.productVariantService_
.withTransaction(this.manager_)
.retrieve(variantId)
const { inventory_quantity, allow_backorder, manage_inventory } = variant
const isCovered =
!manage_inventory || allow_backorder || inventory_quantity >= quantity
if (!isCovered) {
throw new MedusaError(
MedusaError.Types.NOT_ALLOWED,
`Variant with id: ${variant.id} does not have the required inventory`,
MedusaError.Codes.INSUFFICIENT_INVENTORY
)
}
if (!isCovered) {
throw new MedusaError(
MedusaError.Types.NOT_ALLOWED,
`Variant with id: ${variant.id} does not have the required inventory`,
MedusaError.Codes.INSUFFICIENT_INVENTORY
)
}
return isCovered
})
return isCovered
}
}

View File

@@ -0,0 +1,737 @@
import {
ITaxCalculationStrategy,
TaxCalculationContext,
TransactionBaseService,
} from "../interfaces"
import { EntityManager } from "typeorm"
import {
Discount,
DiscountRuleType,
GiftCard,
LineItem,
LineItemTaxLine,
Region,
ShippingMethod,
ShippingMethodTaxLine,
} from "../models"
import { TaxProviderService } from "./index"
import { LineAllocationsMap } from "../types/totals"
import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing"
import { FlagRouter } from "../utils/flag-router"
import { calculatePriceTaxAmount, isDefined } from "../utils"
import { MedusaError } from "medusa-core-utils"
type LineItemTotals = {
unit_price: number
quantity: number
subtotal: number
tax_total: number
total: number
original_total: number
original_tax_total: number
tax_lines: LineItemTaxLine[]
discount_total: number
}
type ShippingMethodTotals = {
price: number
tax_total: number
total: number
subtotal: number
original_total: number
original_tax_total: number
tax_lines: ShippingMethodTaxLine[]
}
type InjectedDependencies = {
manager: EntityManager
taxProviderService: TaxProviderService
taxCalculationStrategy: ITaxCalculationStrategy
featureFlagRouter: FlagRouter
}
export default class NewTotalsService extends TransactionBaseService {
protected readonly manager_: EntityManager
protected readonly transactionManager_: EntityManager | undefined
protected readonly taxProviderService_: TaxProviderService
protected readonly featureFlagRouter_: FlagRouter
protected readonly taxCalculationStrategy_: ITaxCalculationStrategy
constructor({
manager,
taxProviderService,
featureFlagRouter,
taxCalculationStrategy,
}: InjectedDependencies) {
super(arguments[0])
this.manager_ = manager
this.taxProviderService_ = taxProviderService
this.featureFlagRouter_ = featureFlagRouter
this.taxCalculationStrategy_ = taxCalculationStrategy
}
/**
* Calculate and return the items totals for either the legacy calculation or the new calculation
* @param items
* @param includeTax
* @param calculationContext
* @param taxRate
*/
async getLineItemTotals(
items: LineItem | LineItem[],
{
includeTax,
calculationContext,
taxRate,
}: {
includeTax?: boolean
calculationContext: TaxCalculationContext
taxRate?: number | null
}
): Promise<{ [lineItemId: string]: LineItemTotals }> {
items = Array.isArray(items) ? items : [items]
const manager = this.transactionManager_ ?? this.manager_
let lineItemsTaxLinesMap: { [lineItemId: string]: LineItemTaxLine[] } = {}
if (!taxRate && includeTax) {
// Use existing tax lines if they are present
const itemContainsTaxLines = items.some((item) => item.tax_lines?.length)
if (itemContainsTaxLines) {
items.forEach((item) => {
lineItemsTaxLinesMap[item.id] = item.tax_lines ?? []
})
} else {
const { lineItemsTaxLines } = await this.taxProviderService_
.withTransaction(manager)
.getTaxLinesMap(items, calculationContext)
lineItemsTaxLinesMap = lineItemsTaxLines
}
}
const calculationMethod = taxRate
? this.getLineItemTotalsLegacy.bind(this)
: this.getLineItemTotals_.bind(this)
const itemsTotals: { [lineItemId: string]: LineItemTotals } = {}
for (const item of items) {
const lineItemAllocation =
calculationContext.allocation_map[item.id] || {}
itemsTotals[item.id] = await calculationMethod(item, {
taxRate,
includeTax,
lineItemAllocation,
taxLines: lineItemsTaxLinesMap[item.id],
calculationContext,
})
}
return itemsTotals
}
/**
* Calculate and return the totals for an item
* @param item
* @param includeTax
* @param lineItemAllocation
* @param taxLines Only needed to force the usage of the specified tax lines, often in the case where the item does not hold the tax lines
* @param calculationContext
*/
protected async getLineItemTotals_(
item: LineItem,
{
includeTax,
lineItemAllocation,
taxLines,
calculationContext,
}: {
includeTax?: boolean
lineItemAllocation: LineAllocationsMap[number]
taxLines?: LineItemTaxLine[]
calculationContext: TaxCalculationContext
}
): Promise<LineItemTotals> {
let subtotal = item.unit_price * item.quantity
if (
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) &&
item.includes_tax
) {
subtotal = 0 // in that case we need to know the tax rate to compute it later
}
const discount_total =
(lineItemAllocation.discount?.unit_amount || 0) * item.quantity
const totals: LineItemTotals = {
unit_price: item.unit_price,
quantity: item.quantity,
subtotal,
discount_total,
total: subtotal - discount_total,
original_total: subtotal,
original_tax_total: 0,
tax_total: 0,
tax_lines: item.tax_lines ?? [],
}
if (includeTax) {
totals.tax_lines = totals.tax_lines.length
? totals.tax_lines
: (taxLines as LineItemTaxLine[])
if (!totals.tax_lines) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Tax Lines must be joined to calculate taxes"
)
}
}
if (item.is_return) {
if (!isDefined(item.tax_lines)) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Return Line Items must join tax lines"
)
}
}
if (totals.tax_lines.length > 0) {
totals.tax_total = await this.taxCalculationStrategy_.calculate(
[item],
totals.tax_lines,
calculationContext
)
const noDiscountContext = {
...calculationContext,
allocation_map: {}, // Don't account for discounts
}
totals.original_tax_total = await this.taxCalculationStrategy_.calculate(
[item],
totals.tax_lines,
noDiscountContext
)
if (
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) &&
item.includes_tax
) {
totals.subtotal +=
totals.unit_price * totals.quantity - totals.original_tax_total
totals.total += totals.subtotal
totals.original_total += totals.subtotal
}
totals.total += totals.tax_total
totals.original_total += totals.original_tax_total
}
return totals
}
/**
* Calculate and return the legacy calculated totals using the tax rate
* @param item
* @param taxRate
* @param lineItemAllocation
* @param calculationContext
*/
protected async getLineItemTotalsLegacy(
item: LineItem,
{
taxRate,
lineItemAllocation,
calculationContext,
}: {
lineItemAllocation: LineAllocationsMap[number]
calculationContext: TaxCalculationContext
taxRate: number
}
): Promise<LineItemTotals> {
let subtotal = item.unit_price * item.quantity
if (
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) &&
item.includes_tax
) {
subtotal = 0 // in that case we need to know the tax rate to compute it later
}
const discount_total =
(lineItemAllocation.discount?.unit_amount || 0) * item.quantity
const totals: LineItemTotals = {
unit_price: item.unit_price,
quantity: item.quantity,
subtotal,
discount_total,
total: subtotal - discount_total,
original_total: subtotal,
original_tax_total: 0,
tax_total: 0,
tax_lines: [],
}
taxRate = taxRate / 100
const includesTax =
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) && item.includes_tax
const taxIncludedInPrice = !item.includes_tax
? 0
: Math.round(
calculatePriceTaxAmount({
price: item.unit_price,
taxRate: taxRate,
includesTax,
})
)
totals.subtotal = Math.round(
(item.unit_price - taxIncludedInPrice) * item.quantity
)
totals.total = totals.subtotal
totals.original_tax_total = Math.round(totals.subtotal * taxRate)
totals.tax_total = Math.round((totals.subtotal - discount_total) * taxRate)
totals.total += totals.tax_total
if (includesTax) {
totals.original_total += totals.subtotal
}
totals.original_total += totals.original_tax_total
return totals
}
/**
* Return the amount that can be refund on a line item
* @param lineItem
* @param calculationContext
* @param taxRate
*/
getLineItemRefund(
lineItem: {
id: string
unit_price: number
includes_tax: boolean
quantity: number
tax_lines: LineItemTaxLine[]
},
{
calculationContext,
taxRate,
}: { calculationContext: TaxCalculationContext; taxRate?: number | null }
): number {
/*
* Used for backcompat with old tax system
*/
if (taxRate != null) {
return this.getLineItemRefundLegacy(lineItem, {
calculationContext,
taxRate,
})
}
const includesTax =
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) && lineItem.includes_tax
const discountAmount =
(calculationContext.allocation_map[lineItem.id]?.discount?.unit_amount ||
0) * lineItem.quantity
if (!isDefined(lineItem.tax_lines)) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Cannot compute line item refund amount, tax lines are missing from the line item"
)
}
const totalTaxRate = lineItem.tax_lines.reduce((acc, next) => {
return acc + next.rate / 100
}, 0)
const taxAmountIncludedInPrice = !includesTax
? 0
: Math.round(
calculatePriceTaxAmount({
price: lineItem.unit_price,
taxRate: totalTaxRate,
includesTax,
})
)
const lineSubtotal =
(lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity -
discountAmount
const taxTotal = lineItem.tax_lines.reduce((acc, next) => {
return acc + Math.round(lineSubtotal * (next.rate / 100))
}, 0)
return lineSubtotal + taxTotal
}
/**
* @param lineItem
* @param calculationContext
* @param taxRate
* @protected
*/
protected getLineItemRefundLegacy(
lineItem: {
id: string
unit_price: number
includes_tax: boolean
quantity: number
},
{
calculationContext,
taxRate,
}: { calculationContext: TaxCalculationContext; taxRate: number }
): number {
const includesTax =
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) && lineItem.includes_tax
const taxAmountIncludedInPrice = !includesTax
? 0
: Math.round(
calculatePriceTaxAmount({
price: lineItem.unit_price,
taxRate: taxRate / 100,
includesTax,
})
)
const discountAmount =
(calculationContext.allocation_map[lineItem.id]?.discount?.unit_amount ||
0) * lineItem.quantity
const lineSubtotal =
(lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity -
discountAmount
return Math.round(lineSubtotal * (1 + taxRate / 100))
}
/**
* Calculate and return the gift cards totals
* @param giftCardableAmount
* @param giftCardTransactions
* @param region
* @param giftCards
*/
async getGiftCardTotals(
giftCardableAmount: number,
{
giftCardTransactions,
region,
giftCards,
}: {
region: Region
giftCardTransactions?: {
tax_rate: number | null
is_taxable: boolean | null
amount: number
}[]
giftCards?: GiftCard[]
}
): Promise<{
total: number
tax_total: number
}> {
if (!giftCards && !giftCardTransactions) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Cannot calculate the gift cart totals. Neither the gift cards or gift card transactions have been provided"
)
}
if (giftCardTransactions) {
return this.getGiftCardTransactionsTotals({
giftCardTransactions,
region,
})
}
const result = {
total: 0,
tax_total: 0,
}
if (!giftCards?.length) {
return result
}
const giftAmount = giftCards.reduce((acc, next) => acc + next.balance, 0)
result.total = Math.min(giftCardableAmount, giftAmount)
if (region?.gift_cards_taxable) {
result.tax_total = Math.round(result.total * (region.tax_rate / 100))
return result
}
return result
}
/**
* Calculate and return the gift cards totals based on their transactions
* @param gift_card_transactions
* @param region
*/
getGiftCardTransactionsTotals({
giftCardTransactions,
region,
}: {
giftCardTransactions: {
tax_rate: number | null
is_taxable: boolean | null
amount: number
}[]
region: { gift_cards_taxable: boolean; tax_rate: number }
}): { total: number; tax_total: number } {
return giftCardTransactions.reduce(
(acc, next) => {
let taxMultiplier = (next.tax_rate || 0) / 100
// Previously we did not record whether a gift card was taxable or not.
// All gift cards where is_taxable === null are from the old system,
// where we defaulted to taxable gift cards.
//
// This is a backwards compatability fix for orders that were created
// before we added the gift card tax rate.
if (next.is_taxable === null && region?.gift_cards_taxable) {
taxMultiplier = region.tax_rate / 100
}
return {
total: acc.total + next.amount,
tax_total: acc.tax_total + next.amount * taxMultiplier,
}
},
{
total: 0,
tax_total: 0,
}
)
}
/**
* Calculate and return the shipping methods totals for either the legacy calculation or the new calculation
* @param shippingMethods
* @param includeTax
* @param discounts
* @param taxRate
* @param calculationContext
*/
async getShippingMethodTotals(
shippingMethods: ShippingMethod | ShippingMethod[],
{
includeTax,
discounts,
taxRate,
calculationContext,
}: {
includeTax?: boolean
calculationContext: TaxCalculationContext
discounts?: Discount[]
taxRate?: number | null
}
): Promise<{ [shippingMethodId: string]: ShippingMethodTotals }> {
shippingMethods = Array.isArray(shippingMethods)
? shippingMethods
: [shippingMethods]
const manager = this.transactionManager_ ?? this.manager_
let shippingMethodsTaxLinesMap: {
[shippingMethodId: string]: ShippingMethodTaxLine[]
} = {}
if (!taxRate && includeTax) {
// Use existing tax lines if they are present
const shippingMethodContainsTaxLines = shippingMethods.some(
(method) => method.tax_lines?.length
)
if (shippingMethodContainsTaxLines) {
shippingMethods.forEach((sm) => {
shippingMethodsTaxLinesMap[sm.id] = sm.tax_lines ?? []
})
} else {
const calculationContextWithGivenMethod = {
...calculationContext,
shipping_methods: shippingMethods,
}
const { shippingMethodsTaxLines } = await this.taxProviderService_
.withTransaction(manager)
.getTaxLinesMap([], calculationContextWithGivenMethod)
shippingMethodsTaxLinesMap = shippingMethodsTaxLines
}
}
const calculationMethod = taxRate
? this.getShippingMethodTotalsLegacy.bind(this)
: this.getShippingMethodTotals_.bind(this)
const shippingMethodsTotals: {
[lineItemId: string]: ShippingMethodTotals
} = {}
for (const shippingMethod of shippingMethods) {
shippingMethodsTotals[shippingMethod.id] = await calculationMethod(
shippingMethod,
{
includeTax,
calculationContext,
taxLines: shippingMethodsTaxLinesMap[shippingMethod.id],
discounts,
taxRate,
}
)
}
return shippingMethodsTotals
}
/**
* Calculate and return the shipping method totals
* @param shippingMethod
* @param includeTax
* @param calculationContext
* @param taxLines
* @param discounts
*/
protected async getShippingMethodTotals_(
shippingMethod: ShippingMethod,
{
includeTax,
calculationContext,
taxLines,
discounts,
}: {
includeTax?: boolean
calculationContext: TaxCalculationContext
taxLines?: ShippingMethodTaxLine[]
discounts?: Discount[]
}
) {
const totals: ShippingMethodTotals = {
price: shippingMethod.price,
original_total: shippingMethod.price,
total: shippingMethod.price,
subtotal: shippingMethod.price,
original_tax_total: 0,
tax_total: 0,
tax_lines: shippingMethod.tax_lines ?? [],
}
if (includeTax) {
totals.tax_lines = totals.tax_lines.length
? totals.tax_lines
: (taxLines as ShippingMethodTaxLine[])
if (!totals.tax_lines) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Tax Lines must be joined to calculate taxes"
)
}
}
const calculationContext_: TaxCalculationContext = {
...calculationContext,
shipping_methods: [shippingMethod],
}
if (totals.tax_lines.length) {
const includesTax =
this.featureFlagRouter_.isFeatureEnabled(
TaxInclusivePricingFeatureFlag.key
) && shippingMethod.includes_tax
totals.original_tax_total = await this.taxCalculationStrategy_.calculate(
[],
totals.tax_lines,
calculationContext_
)
totals.tax_total = totals.original_tax_total
if (includesTax) {
totals.subtotal -= totals.tax_total
} else {
totals.original_total += totals.original_tax_total
totals.total += totals.tax_total
}
}
const hasFreeShipping = discounts?.some(
(d) => d.rule.type === DiscountRuleType.FREE_SHIPPING
)
if (hasFreeShipping) {
totals.total = 0
totals.subtotal = 0
totals.tax_total = 0
}
return totals
}
/**
* Calculate and return the shipping method totals legacy using teh tax rate
* @param shippingMethod
* @param calculationContext
* @param taxRate
* @param discounts
*/
protected async getShippingMethodTotalsLegacy(
shippingMethod: ShippingMethod,
{
calculationContext,
discounts,
taxRate,
}: {
calculationContext: TaxCalculationContext
discounts?: Discount[]
taxRate: number
}
): Promise<ShippingMethodTotals> {
const totals: ShippingMethodTotals = {
price: shippingMethod.price,
original_total: shippingMethod.price,
total: shippingMethod.price,
subtotal: shippingMethod.price,
original_tax_total: 0,
tax_total: 0,
tax_lines: [],
}
totals.original_tax_total = Math.round(totals.price * (taxRate / 100))
totals.tax_total = Math.round(totals.price * (taxRate / 100))
const hasFreeShipping = discounts?.some(
(d) => d.rule.type === DiscountRuleType.FREE_SHIPPING
)
if (hasFreeShipping) {
totals.total = 0
totals.subtotal = 0
totals.tax_total = 0
}
return totals
}
}

View File

@@ -4,6 +4,7 @@ import { TransactionBaseService } from "../interfaces"
import SalesChannelFeatureFlag from "../loaders/feature-flags/sales-channels"
import {
Address,
Cart,
ClaimOrder,
Fulfillment,
FulfillmentItem,
@@ -26,7 +27,7 @@ import {
} from "../types/fulfillment"
import { UpdateOrderInput } from "../types/orders"
import { CreateShippingMethodDto } from "../types/shipping-options"
import { buildQuery, setMetadata } from "../utils"
import { buildQuery, isDefined, isString, setMetadata } from "../utils"
import { FlagRouter } from "../utils/flag-router"
import CartService from "./cart"
import CustomerService from "./customer"
@@ -43,6 +44,9 @@ import RegionService from "./region"
import ShippingOptionService from "./shipping-option"
import ShippingProfileService from "./shipping-profile"
import TotalsService from "./totals"
import { NewTotalsService, TaxProviderService } from "./index"
export const ORDER_CART_ALREADY_EXISTS_ERROR = "Order from cart already exists"
type InjectedDependencies = {
manager: EntityManager
@@ -56,6 +60,8 @@ type InjectedDependencies = {
fulfillmentService: FulfillmentService
lineItemService: LineItemService
totalsService: TotalsService
newTotalsService: NewTotalsService
taxProviderService: TaxProviderService
regionService: RegionService
cartService: CartService
addressRepository: typeof AddressRepository
@@ -66,6 +72,10 @@ type InjectedDependencies = {
featureFlagRouter: FlagRouter
}
type TotalsConfig = {
force_taxes?: boolean
}
class OrderService extends TransactionBaseService {
static readonly Events = {
GIFT_CARD_CREATED: "order.gift_card_created",
@@ -99,6 +109,8 @@ class OrderService extends TransactionBaseService {
protected readonly fulfillmentService_: FulfillmentService
protected readonly lineItemService_: LineItemService
protected readonly totalsService_: TotalsService
protected readonly newTotalsService_: NewTotalsService
protected readonly taxProviderService_: TaxProviderService
protected readonly regionService_: RegionService
protected readonly cartService_: CartService
protected readonly addressRepository_: typeof AddressRepository
@@ -120,6 +132,8 @@ class OrderService extends TransactionBaseService {
fulfillmentService,
lineItemService,
totalsService,
newTotalsService,
taxProviderService,
regionService,
cartService,
addressRepository,
@@ -140,6 +154,8 @@ class OrderService extends TransactionBaseService {
this.fulfillmentProviderService_ = fulfillmentProviderService
this.lineItemService_ = lineItemService
this.totalsService_ = totalsService
this.newTotalsService_ = newTotalsService
this.taxProviderService_ = taxProviderService
this.regionService_ = regionService
this.fulfillmentService_ = fulfillmentService
this.discountService_ = discountService
@@ -225,7 +241,7 @@ class OrderService extends TransactionBaseService {
this.transformQueryForTotals(config)
query.select = select
const rels = relations
const rels = this.getTotalsRelations({ relations })
delete query.relations
@@ -309,22 +325,57 @@ class OrderService extends TransactionBaseService {
/**
* Gets an order by id.
* @param orderId - id of order to retrieve
* @param orderId - id or selector of order to retrieve
* @param config - config of order to retrieve
* @return the order document
*/
async retrieve(
orderId: string,
config: FindConfig<Order> = {}
): Promise<Order> {
const { totalsToSelect } = this.transformQueryForTotals(config)
if (totalsToSelect?.length) {
return await this.retrieveLegacy(orderId, config)
}
const manager = this.manager_
const orderRepo = manager.getCustomRepository(this.orderRepository_)
const query = buildQuery({ id: orderId }, config)
if (!(config.select || []).length) {
query.select = undefined
}
const queryRelations = query.relations
query.relations = undefined
const raw = await orderRepo.findOneWithRelations(queryRelations, query)
if (!raw) {
throw new MedusaError(
MedusaError.Types.NOT_FOUND,
`Order with id ${orderId} was not found`
)
}
return raw
}
protected async retrieveLegacy(
orderIdOrSelector: string | Selector<Order>,
config: FindConfig<Order> = {}
): Promise<Order> {
const orderRepo = this.manager_.getCustomRepository(this.orderRepository_)
const { select, relations, totalsToSelect } =
this.transformQueryForTotals(config)
const query = {
where: { id: orderId },
} as FindConfig<Order>
const selector = isString(orderIdOrSelector)
? { id: orderIdOrSelector }
: orderIdOrSelector
const query = buildQuery(selector, config)
if (relations && relations.length > 0) {
query.relations = relations
@@ -334,17 +385,33 @@ class OrderService extends TransactionBaseService {
const rels = query.relations
delete query.relations
const raw = await orderRepo.findOneWithRelations(rels, query)
if (!raw) {
const selectorConstraints = Object.entries(selector)
.map((key, value) => `${key}: ${value}`)
.join(", ")
throw new MedusaError(
MedusaError.Types.NOT_FOUND,
`Order with ${orderId} was not found`
`Order with ${selectorConstraints} was not found`
)
}
return await this.decorateTotals(raw, totalsToSelect)
}
async retrieveWithTotals(
orderId: string,
options: FindConfig<Order> = {},
totalsConfig: TotalsConfig = {}
): Promise<Order> {
const relations = this.getTotalsRelations(options)
const order = await this.retrieve(orderId, { ...options, relations })
return await this.decorateTotals(order, totalsConfig)
}
/**
* Gets an order by cart id.
* @param cartId - cart id to find order
@@ -379,6 +446,10 @@ class OrderService extends TransactionBaseService {
)
}
if (!totalsToSelect?.length) {
return raw
}
return await this.decorateTotals(raw, totalsToSelect)
}
@@ -404,6 +475,7 @@ class OrderService extends TransactionBaseService {
if (relations && relations.length > 0) {
query.relations = relations
}
query.relations = this.getTotalsRelations({ relations: query.relations })
query.select = select?.length ? select : undefined
@@ -420,16 +492,6 @@ class OrderService extends TransactionBaseService {
return await this.decorateTotals(raw, totalsToSelect)
}
/**
* Checks the existence of an order by cart id.
* @param cartId - cart id to find order
* @return the order document
*/
async existsByCartId(cartId: string): Promise<boolean> {
const order = await this.retrieveByCartId(cartId).catch(() => undefined)
return !!order
}
/**
* @param orderId - id of the order to complete
* @return the result of the find operation
@@ -459,27 +521,33 @@ class OrderService extends TransactionBaseService {
/**
* Creates an order from a cart
* @param cartId - id of the cart to create an order from
* @return resolves to the creation result.
* @param cartOrId
*/
async createFromCart(cartId: string): Promise<Order | never> {
async createFromCart(cartOrId: string | Cart): Promise<Order | never> {
return await this.atomicPhase_(async (manager) => {
const cartServiceTx = this.cartService_.withTransaction(manager)
const inventoryServiceTx = this.inventoryService_.withTransaction(manager)
const cart = await cartServiceTx.retrieveWithTotals(cartId, {
relations: [
"region",
"payment",
"items",
"discounts",
"discounts.rule",
"gift_cards",
"shipping_methods",
"items",
"items.adjustments",
],
})
const exists = !!(await this.retrieveByCartId(
isString(cartOrId) ? cartOrId : cartOrId?.id,
{
select: ["id"],
}
).catch(() => void 0))
if (exists) {
throw new MedusaError(
MedusaError.Types.DUPLICATE_ERROR,
ORDER_CART_ALREADY_EXISTS_ERROR
)
}
const cart = isString(cartOrId)
? await cartServiceTx.retrieveWithTotals(cartOrId, {
relations: ["region", "payment"],
})
: cartOrId
if (cart.items.length === 0) {
throw new MedusaError(
@@ -490,30 +558,22 @@ class OrderService extends TransactionBaseService {
const { payment, region, total } = cart
for (const item of cart.items) {
try {
await inventoryServiceTx.confirmInventory(
await Promise.all(
cart.items.map(async (item) => {
return await inventoryServiceTx.confirmInventory(
item.variant_id,
item.quantity
)
} catch (err) {
if (payment) {
await this.paymentProviderService_
.withTransaction(manager)
.cancelPayment(payment)
}
await cartServiceTx.update(cart.id, { payment_authorized_at: null })
throw err
})
).catch(async (err) => {
if (payment) {
await this.paymentProviderService_
.withTransaction(manager)
.cancelPayment(payment)
}
}
const exists = await this.existsByCartId(cart.id)
if (exists) {
throw new MedusaError(
MedusaError.Types.DUPLICATE_ERROR,
"Order from cart already exists"
)
}
await cartServiceTx.update(cart.id, { payment_authorized_at: null })
throw err
})
// Would be the case if a discount code is applied that covers the item
// total
@@ -539,11 +599,19 @@ class OrderService extends TransactionBaseService {
const orderRepo = manager.getCustomRepository(this.orderRepository_)
// TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals.
// Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating?
// We normally should only pass what is needed?
const shippingMethods = cart.shipping_methods.map((method) => {
;(method.tax_lines as any) = undefined
return method
})
const toCreate = {
payment_status: "awaiting",
discounts: cart.discounts,
gift_cards: cart.gift_cards,
shipping_methods: cart.shipping_methods,
shipping_methods: shippingMethods,
shipping_address_id: cart.shipping_address_id,
billing_address_id: cart.billing_address_id,
region_id: cart.region_id,
@@ -570,18 +638,28 @@ class OrderService extends TransactionBaseService {
toCreate.no_notification = draft.no_notification_order
}
const o = orderRepo.create(toCreate)
const result = await orderRepo.save(o)
const rawOrder = orderRepo.create(toCreate)
const order = await orderRepo.save(rawOrder)
if (total !== 0 && payment) {
await this.paymentProviderService_
.withTransaction(manager)
.updatePayment(payment.id, {
order_id: result.id,
order_id: order.id,
})
}
let gcBalance = await this.totalsService_.getGiftCardableAmount(cart)
if (!isDefined(cart.subtotal) || !isDefined(cart.discount_total)) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Unable to compute gift cardable amount during order creation from cart. The cart is missing the subtotal and/or discount_total"
)
}
let gcBalance =
(cart.region?.gift_cards_taxable
? cart.subtotal! - cart.discount_total!
: cart.total! + cart.gift_card_total!) || 0
const gcService = this.giftCardService_.withTransaction(manager)
for (const g of cart.gift_cards) {
@@ -594,7 +672,7 @@ class OrderService extends TransactionBaseService {
await gcService.createTransaction({
gift_card_id: g.id,
order_id: result.id,
order_id: order.id,
amount: usage,
is_taxable: cart.region.gift_cards_taxable,
tax_rate: cart.region.gift_cards_taxable
@@ -605,34 +683,43 @@ class OrderService extends TransactionBaseService {
gcBalance = gcBalance - usage
}
for (const method of cart.shipping_methods) {
await this.shippingOptionService_
.withTransaction(manager)
.updateShippingMethod(method.id, { order_id: result.id })
}
const shippingOptionServiceTx =
this.shippingOptionService_.withTransaction(manager)
const lineItemServiceTx = this.lineItemService_.withTransaction(manager)
for (const item of cart.items) {
await lineItemServiceTx.update(item.id, { order_id: result.id })
}
for (const item of cart.items) {
await inventoryServiceTx.adjustInventory(
item.variant_id,
-item.quantity
)
}
await Promise.all(
[
cart.items.map((item) => {
return [
lineItemServiceTx.update(item.id, { order_id: order.id }),
inventoryServiceTx.adjustInventory(
item.variant_id,
-item.quantity
),
]
}),
cart.shipping_methods.map((method) => {
// TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals.
// Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating?
// We normally should only pass what is needed?
;(method.tax_lines as any) = undefined
return shippingOptionServiceTx.updateShippingMethod(method.id, {
order_id: order.id,
})
}),
].flat(Infinity)
)
await this.eventBus_
.withTransaction(manager)
.emit(OrderService.Events.PLACED, {
id: result.id,
no_notification: result.no_notification,
id: order.id,
no_notification: order.no_notification,
})
await cartServiceTx.update(cart.id, { completed_at: new Date() })
return result
return order
})
}
@@ -814,8 +901,7 @@ class OrderService extends TransactionBaseService {
config: CreateShippingMethodDto = {}
): Promise<Order> {
return await this.atomicPhase_(async (manager) => {
const order = await this.retrieve(orderId, {
select: ["subtotal"],
const order = await this.retrieveWithTotals(orderId, {
relations: [
"shipping_methods",
"shipping_methods.shipping_option",
@@ -1408,31 +1494,10 @@ class OrderService extends TransactionBaseService {
})
}
protected async decorateTotals(
protected async decorateTotalsLegacy(
order: Order,
totalsFields: string[] = []
): Promise<Order> {
if (totalsFields.some((field) => ["subtotal", "total"].includes(field))) {
const calculationContext =
await this.totalsService_.getCalculationContext(order, {
exclude_shipping: true,
})
order.items = await Promise.all(
(order.items || []).map(async (item) => {
const itemTotals = await this.totalsService_.getLineItemTotals(
item,
order,
{
include_tax: true,
calculation_context: calculationContext,
}
)
return Object.assign(item, itemTotals)
})
)
}
for (const totalField of totalsFields) {
switch (totalField) {
case "shipping_total": {
@@ -1537,6 +1602,149 @@ class OrderService extends TransactionBaseService {
return order
}
/**
* @param order
* @param totalsFieldsOrConfig
* @protected
*/
async decorateTotals(
order: Order,
totalsFieldsOrConfig?: string[] | TotalsConfig
): Promise<Order> {
if (Array.isArray(totalsFieldsOrConfig)) {
return await this.decorateTotalsLegacy(order, totalsFieldsOrConfig)
}
const manager = this.transactionManager_ ?? this.manager_
const newTotalsServiceTx = this.newTotalsService_.withTransaction(manager)
const calculationContext = await this.totalsService_.getCalculationContext(
order
)
const orderItems = [...(order.items ?? [])]
const orderShippingMethods = [...(order.shipping_methods ?? [])]
const itemsTotals = await newTotalsServiceTx.getLineItemTotals(orderItems, {
taxRate: order.tax_rate,
includeTax: true,
calculationContext,
})
const shippingTotals = await newTotalsServiceTx.getShippingMethodTotals(
orderShippingMethods,
{
taxRate: order.tax_rate,
discounts: order.discounts,
includeTax: true,
calculationContext,
}
)
order.subtotal = 0
order.discount_total = 0
order.shipping_total = 0
order.refunded_total =
Math.round(order.refunds?.reduce((acc, next) => acc + next.amount, 0)) ||
0
order.paid_total =
order.payments?.reduce((acc, next) => (acc += next.amount), 0) || 0
order.refundable_amount = order.paid_total - order.refunded_total || 0
let item_tax_total = 0
let shipping_tax_total = 0
order.items = (order.items || []).map((item) => {
const refundable = newTotalsServiceTx.getLineItemRefund(
{
...item,
quantity: item.quantity - (item.returned_quantity || 0),
},
{
calculationContext,
taxRate: order.tax_rate,
}
)
const itemWithTotals = {
...item,
...(itemsTotals[item.id] ?? {}),
refundable,
}
order.subtotal += itemWithTotals.subtotal ?? 0
order.discount_total += itemWithTotals.discount_total ?? 0
item_tax_total += itemWithTotals.tax_total ?? 0
return itemWithTotals as LineItem
})
order.shipping_methods = (order.shipping_methods || []).map(
(shippingMethod) => {
const methodWithTotals = Object.assign(
shippingMethod,
shippingTotals[shippingMethod.id] ?? {}
)
order.shipping_total += methodWithTotals.subtotal ?? 0
shipping_tax_total += methodWithTotals.tax_total ?? 0
return methodWithTotals
}
)
const giftCardTotal = await this.newTotalsService_.getGiftCardTotals(
order.subtotal - order.discount_total,
{
region: order.region,
giftCards: order.gift_cards,
giftCardTransactions: order.gift_card_transactions ?? [],
}
)
order.gift_card_total = giftCardTotal.total || 0
order.gift_card_tax_total = giftCardTotal.tax_total || 0
order.tax_total =
item_tax_total + shipping_tax_total - order.gift_card_tax_total
for (const swap of order.swaps ?? []) {
swap.additional_items = swap.additional_items.map((item) => {
item.refundable = newTotalsServiceTx.getLineItemRefund(
{
...item,
quantity: item.quantity - (item.returned_quantity || 0),
},
{
calculationContext,
taxRate: order.tax_rate,
}
)
return item
})
}
for (const claim of order.claims ?? []) {
claim.additional_items = claim.additional_items.map((item) => {
item.refundable = newTotalsServiceTx.getLineItemRefund(
{
...item,
quantity: item.quantity - (item.returned_quantity || 0),
},
{
calculationContext,
taxRate: order.tax_rate,
}
)
return item
})
}
order.total =
order.subtotal +
order.shipping_total +
order.tax_total -
(order.gift_card_total + order.discount_total)
return order
}
/**
* Handles receiving a return. This will create a
* refund to the customer. If the returned items don't match the requested
@@ -1624,6 +1832,32 @@ class OrderService extends TransactionBaseService {
return result
})
}
private getTotalsRelations(config: FindConfig<Order>): string[] {
const relationSet = new Set(config.relations)
relationSet.add("items")
relationSet.add("items.tax_lines")
relationSet.add("items.adjustments")
relationSet.add("swaps")
relationSet.add("swaps.additional_items")
relationSet.add("swaps.additional_items.tax_lines")
relationSet.add("swaps.additional_items.adjustments")
relationSet.add("claims")
relationSet.add("claims.additional_items")
relationSet.add("claims.additional_items.tax_lines")
relationSet.add("claims.additional_items.adjustments")
relationSet.add("discounts")
relationSet.add("discounts.rule")
relationSet.add("gift_cards")
relationSet.add("gift_card_transactions")
relationSet.add("refunds")
relationSet.add("shipping_methods")
relationSet.add("shipping_methods.tax_lines")
relationSet.add("region")
return Array.from(relationSet.values())
}
}
export default OrderService

View File

@@ -374,11 +374,14 @@ export default class PaymentProviderService extends TransactionBaseService {
}
}
async createPayment(
cart: Cart & { payment_session: PaymentSession }
): Promise<Payment> {
async createPayment(data: {
cart_id: string
amount: number
currency_code: string
payment_session: PaymentSession
}): Promise<Payment> {
return await this.atomicPhase_(async (transactionManager) => {
const { payment_session: paymentSession, region, total } = cart
const { payment_session: paymentSession, currency_code, amount } = data
const provider = this.retrieveProvider(paymentSession.provider_id)
const paymentData = await provider
@@ -391,10 +394,10 @@ export default class PaymentProviderService extends TransactionBaseService {
const created = paymentRepo.create({
provider_id: paymentSession.provider_id,
amount: total,
currency_code: region.currency_code,
amount,
currency_code,
data: paymentData,
cart_id: cart.id,
cart_id: data.cart_id,
})
return await paymentRepo.save(created)

View File

@@ -1,5 +1,5 @@
import { MedusaError } from "medusa-core-utils"
import { EntityManager, In } from "typeorm"
import { EntityManager } from "typeorm"
import { buildQuery, isDefined, setMetadata, validateId } from "../utils"
import { TransactionBaseService } from "../interfaces"
@@ -719,14 +719,8 @@ class SwapService extends TransactionBaseService {
const cart = await this.cartService_
.withTransaction(manager)
.retrieve(swap.cart_id, {
select: ["total"],
relations: [
"payment",
"shipping_methods",
"items",
"items.adjustments",
],
.retrieveWithTotals(swap.cart_id, {
relations: ["payment"],
})
const { payment } = cart
@@ -802,7 +796,13 @@ class SwapService extends TransactionBaseService {
swap.difference_due = total
swap.shipping_address_id = cart.shipping_address_id
swap.shipping_methods = cart.shipping_methods
// TODO: Due to cascade insert we have to remove the tax_lines that have been added by the cart decorate totals.
// Is the cascade insert really used? Also, is it really necessary to pass the entire entities when creating or updating?
// We normally should only pass what is needed?
swap.shipping_methods = cart.shipping_methods.map((method) => {
;(method.tax_lines as any) = undefined
return method
})
swap.confirmed_at = new Date()
swap.payment_status =
total === 0 ? SwapPaymentStatus.CONFIRMED : SwapPaymentStatus.AWAITING

View File

@@ -23,7 +23,7 @@ import {
TransactionBaseService,
} from "../interfaces"
import { TaxServiceRate } from "../types/tax-service"
import { TaxLinesMaps, TaxServiceRate } from "../types/tax-service"
import TaxRateService from "./tax-rate"
import EventBusService from "./event-bus"
@@ -333,6 +333,42 @@ class TaxProviderService extends TransactionBaseService {
})
}
/**
* Return a map of tax lines for line items and shipping methods
* @param items
* @param calculationContext
* @protected
*/
async getTaxLinesMap(
items: LineItem[],
calculationContext: TaxCalculationContext
): Promise<TaxLinesMaps> {
const lineItemsTaxLinesMap = {}
const shippingMethodsTaxLinesMap = {}
const taxLines = await this.getTaxLines(items, calculationContext)
taxLines.forEach((taxLine) => {
if ("item_id" in taxLine) {
const itemTaxLines = lineItemsTaxLinesMap[taxLine.item_id] ?? []
itemTaxLines.push(taxLine)
lineItemsTaxLinesMap[taxLine.item_id] = itemTaxLines
}
if ("shipping_method_id" in taxLine) {
const shippingMethodTaxLines =
shippingMethodsTaxLinesMap[taxLine.shipping_method_id] ?? []
shippingMethodTaxLines.push(taxLine)
shippingMethodsTaxLinesMap[taxLine.shipping_method_id] =
shippingMethodTaxLines
}
})
return {
lineItemsTaxLines: lineItemsTaxLinesMap,
shippingMethodsTaxLines: shippingMethodsTaxLinesMap,
}
}
/**
* Gets the tax rates configured for a shipping option. The rates are cached
* between calls.

View File

@@ -1,5 +1,6 @@
import { MockManager } from "medusa-test-utils"
import CartCompletionStrategy from "../cart-completion"
import { newTotalsServiceMock } from "../../services/__mocks__/new-totals"
const IdempotencyKeyServiceMock = {
withTransaction: function () {
@@ -57,32 +58,34 @@ const toTest = [
})
expect(cartServiceMock.createTaxLines).toHaveBeenCalledTimes(1)
expect(cartServiceMock.createTaxLines).toHaveBeenCalledWith("test-cart")
expect(cartServiceMock.createTaxLines).toHaveBeenCalledWith(
expect.objectContaining({ id: "test-cart" })
)
expect(cartServiceMock.authorizePayment).toHaveBeenCalledTimes(1)
expect(cartServiceMock.authorizePayment).toHaveBeenCalledWith(
"test-cart",
{
idempotency_key: "ikey",
cart_id: "test-cart",
idempotency_key: {
idempotency_key: "ikey",
recovery_point: "tax_lines_created",
},
}
)
expect(orderServiceMock.createFromCart).toHaveBeenCalledTimes(1)
expect(orderServiceMock.createFromCart).toHaveBeenCalledWith(
"test-cart"
expect.objectContaining({ id: "test-cart" })
)
expect(orderServiceMock.retrieve).toHaveBeenCalledTimes(1)
expect(orderServiceMock.retrieve).toHaveBeenCalledWith("test-cart", {
select: [
"subtotal",
"tax_total",
"shipping_total",
"discount_total",
"total",
],
relations: ["shipping_address", "items", "payments"],
})
expect(orderServiceMock.retrieveWithTotals).toHaveBeenCalledTimes(1)
expect(orderServiceMock.retrieveWithTotals).toHaveBeenCalledWith(
"test-cart",
{
relations: ["shipping_address", "items", "payments"],
}
)
},
},
],
@@ -187,6 +190,7 @@ describe("CartCompletionStrategy", () => {
authorizePayment: jest.fn(() => Promise.resolve(cart)),
retrieve: jest.fn(() => Promise.resolve(cart)),
retrieveWithTotals: jest.fn(() => Promise.resolve(cart)),
newTotalsService: newTotalsServiceMock,
}
const orderServiceMock = {
withTransaction: function () {
@@ -194,6 +198,8 @@ describe("CartCompletionStrategy", () => {
},
createFromCart: jest.fn(() => Promise.resolve(cart)),
retrieve: jest.fn(() => Promise.resolve({})),
retrieveWithTotals: jest.fn(() => Promise.resolve({})),
newTotalsService: newTotalsServiceMock,
}
const swapServiceMock = {
withTransaction: function () {

View File

@@ -4,7 +4,9 @@ import { EntityManager } from "typeorm"
import { IdempotencyKey, Order } from "../models"
import CartService from "../services/cart"
import IdempotencyKeyService from "../services/idempotency-key"
import OrderService from "../services/order"
import OrderService, {
ORDER_CART_ALREADY_EXISTS_ERROR,
} from "../services/order"
import SwapService from "../services/swap"
import { RequestContext } from "../types/request"
@@ -52,11 +54,6 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
): Promise<CartCompletionResponse> {
let idempotencyKey: IdempotencyKey = ikey
const idempotencyKeyService = this.idempotencyKeyService_
const cartService = this.cartService_
const orderService = this.orderService_
const swapService = this.swapService_
let inProgress = true
let err: unknown = false
@@ -65,30 +62,12 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
case "started": {
await this.manager_
.transaction("SERIALIZABLE", async (transactionManager) => {
idempotencyKey = await idempotencyKeyService
idempotencyKey = await this.idempotencyKeyService_
.withTransaction(transactionManager)
.workStage(idempotencyKey.idempotency_key, async (manager) => {
const cart = await cartService
.withTransaction(manager)
.retrieve(id)
if (cart.completed_at) {
return {
response_code: 409,
response_body: {
code: MedusaError.Codes.CART_INCOMPATIBLE_STATE,
message: "Cart has already been completed",
type: MedusaError.Types.NOT_ALLOWED,
},
}
}
await cartService.withTransaction(manager).createTaxLines(id)
return {
recovery_point: "tax_lines_created",
}
})
.workStage(
idempotencyKey.idempotency_key,
async (manager) => await this.handleStarted(id, { manager })
)
})
.catch((e) => {
inProgress = false
@@ -99,40 +78,16 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
case "tax_lines_created": {
await this.manager_
.transaction("SERIALIZABLE", async (transactionManager) => {
idempotencyKey = await idempotencyKeyService
idempotencyKey = await this.idempotencyKeyService_
.withTransaction(transactionManager)
.workStage(idempotencyKey.idempotency_key, async (manager) => {
const cart = await cartService
.withTransaction(manager)
.authorizePayment(id, {
...context,
idempotency_key: idempotencyKey.idempotency_key,
.workStage(
idempotencyKey.idempotency_key,
async (manager) =>
await this.handleTaxLineCreated(id, idempotencyKey, {
context,
manager,
})
if (cart.payment_session) {
if (
cart.payment_session.status === "requires_more" ||
cart.payment_session.status === "pending"
) {
await cartService
.withTransaction(transactionManager)
.deleteTaxLines(id)
return {
response_code: 200,
response_body: {
data: cart,
payment_status: cart.payment_session.status,
type: "cart",
},
}
}
}
return {
recovery_point: "payment_authorized",
}
})
)
})
.catch((e) => {
inProgress = false
@@ -144,139 +99,13 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
case "payment_authorized": {
await this.manager_
.transaction("SERIALIZABLE", async (transactionManager) => {
idempotencyKey = await idempotencyKeyService
idempotencyKey = await this.idempotencyKeyService_
.withTransaction(transactionManager)
.workStage(idempotencyKey.idempotency_key, async (manager) => {
const cart = await cartService
.withTransaction(manager)
.retrieveWithTotals(id, {
relations: ["payment", "payment_sessions"],
})
// If cart is part of swap, we register swap as complete
switch (cart.type) {
case "swap": {
try {
const swapId = cart.metadata?.swap_id
let swap = await swapService
.withTransaction(manager)
.registerCartCompletion(swapId as string)
swap = await swapService
.withTransaction(manager)
.retrieve(swap.id, {
relations: ["shipping_address"],
})
return {
response_code: 200,
response_body: { data: swap, type: "swap" },
}
} catch (error) {
if (
error &&
error.code ===
MedusaError.Codes.INSUFFICIENT_INVENTORY
) {
return {
response_code: 409,
response_body: {
message: error.message,
type: error.type,
code: error.code,
},
}
} else {
throw error
}
}
}
default: {
if (typeof cart.total === "undefined") {
return {
response_code: 500,
response_body: {
message: "Unexpected state",
},
}
}
if (!cart.payment && cart.total > 0) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Cart payment not authorized`
)
}
let order: Order
try {
order = await orderService
.withTransaction(manager)
.createFromCart(cart.id)
} catch (error) {
if (
error &&
error.message === "Order from cart already exists"
) {
order = await orderService
.withTransaction(manager)
.retrieveByCartId(id, {
select: [
"subtotal",
"tax_total",
"shipping_total",
"discount_total",
"total",
],
relations: [
"shipping_address",
"items",
"payments",
],
})
return {
response_code: 200,
response_body: { data: order, type: "order" },
}
} else if (
error &&
error.code ===
MedusaError.Codes.INSUFFICIENT_INVENTORY
) {
return {
response_code: 409,
response_body: {
message: error.message,
type: error.type,
code: error.code,
},
}
} else {
throw error
}
}
order = await orderService
.withTransaction(manager)
.retrieve(order.id, {
select: [
"subtotal",
"tax_total",
"shipping_total",
"discount_total",
"total",
],
relations: ["shipping_address", "items", "payments"],
})
return {
response_code: 200,
response_body: { data: order, type: "order" },
}
}
}
})
.workStage(
idempotencyKey.idempotency_key,
async (manager) =>
await this.handlePaymentAuthorized(id, { manager })
)
})
.catch((e) => {
inProgress = false
@@ -292,7 +121,7 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
default:
await this.manager_.transaction(async (transactionManager) => {
idempotencyKey = await idempotencyKeyService
idempotencyKey = await this.idempotencyKeyService_
.withTransaction(transactionManager)
.update(idempotencyKey.idempotency_key, {
recovery_point: "finished",
@@ -308,11 +137,11 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
if (idempotencyKey.recovery_point !== "started") {
await this.manager_.transaction(async (transactionManager) => {
try {
await orderService
await this.orderService_
.withTransaction(transactionManager)
.retrieveByCartId(id)
} catch (error) {
await cartService
await this.cartService_
.withTransaction(transactionManager)
.deleteTaxLines(id)
}
@@ -326,6 +155,172 @@ class CartCompletionStrategy extends AbstractCartCompletionStrategy {
response_code: idempotencyKey.response_code,
}
}
protected async handleStarted(
id: string,
{ manager }: { manager: EntityManager }
) {
const cart = await this.cartService_.withTransaction(manager).retrieve(id, {
relations: [
"customer",
"discounts",
"discounts.rule",
"gift_cards",
"items",
"items.adjustments",
"region",
"region.tax_rates",
"shipping_address",
"shipping_methods",
],
})
if (cart.completed_at) {
return {
response_code: 409,
response_body: {
code: MedusaError.Codes.CART_INCOMPATIBLE_STATE,
message: "Cart has already been completed",
type: MedusaError.Types.NOT_ALLOWED,
},
}
}
await this.cartService_.withTransaction(manager).createTaxLines(cart)
return {
recovery_point: "tax_lines_created",
}
}
protected async handleTaxLineCreated(
id: string,
idempotencyKey: IdempotencyKey,
{ context, manager }: { context: any; manager: EntityManager }
) {
const cart = await this.cartService_
.withTransaction(manager)
.authorizePayment(id, {
...context,
cart_id: id,
idempotency_key: idempotencyKey,
})
if (cart.payment_session) {
if (
cart.payment_session.status === "requires_more" ||
cart.payment_session.status === "pending"
) {
await this.cartService_.withTransaction(manager).deleteTaxLines(id)
return {
response_code: 200,
response_body: {
data: cart,
payment_status: cart.payment_session.status,
type: "cart",
},
}
}
}
return {
recovery_point: "payment_authorized",
}
}
protected async handlePaymentAuthorized(
id: string,
{ manager }: { manager: EntityManager }
) {
const orderServiceTx = this.orderService_.withTransaction(manager)
const cart = await this.cartService_
.withTransaction(manager)
.retrieveWithTotals(id, {
relations: ["region", "payment", "payment_sessions"],
})
// If cart is part of swap, we register swap as complete
if (cart.type === "swap") {
try {
const swapId = cart.metadata?.swap_id
let swap = await this.swapService_
.withTransaction(manager)
.registerCartCompletion(swapId as string)
swap = await this.swapService_
.withTransaction(manager)
.retrieve(swap.id, {
relations: ["shipping_address"],
})
return {
response_code: 200,
response_body: { data: swap, type: "swap" },
}
} catch (error) {
if (error && error.code === MedusaError.Codes.INSUFFICIENT_INVENTORY) {
return {
response_code: 409,
response_body: {
message: error.message,
type: error.type,
code: error.code,
},
}
} else {
throw error
}
}
}
if (!cart.payment && cart.total! > 0) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Cart payment not authorized`
)
}
let order: Order
try {
order = await orderServiceTx.createFromCart(cart)
} catch (error) {
if (error && error.message === ORDER_CART_ALREADY_EXISTS_ERROR) {
order = await orderServiceTx.retrieveByCartId(id, {
relations: ["shipping_address", "payments"],
})
return {
response_code: 200,
response_body: { data: order, type: "order" },
}
} else if (
error &&
error.code === MedusaError.Codes.INSUFFICIENT_INVENTORY
) {
return {
response_code: 409,
response_body: {
message: error.message,
type: error.type,
code: error.code,
},
}
} else {
throw error
}
}
order = await orderServiceTx.retrieveWithTotals(order.id, {
relations: ["shipping_address", "items", "payments"],
})
return {
response_code: 200,
response_body: { data: order, type: "order" },
}
}
}
export default CartCompletionStrategy

View File

@@ -1,3 +1,12 @@
import { LineItemTaxLine, ShippingMethodTaxLine } from "../models"
export type TaxLinesMaps = {
lineItemsTaxLines: { [lineItemId: string]: LineItemTaxLine[] }
shippingMethodsTaxLines: {
[shippingMethodId: string]: ShippingMethodTaxLine[]
}
}
/**
* The tax rate object as configured in Medusa. These may have an unspecified
* numerical rate as they may be used for lookup purposes in the tax provider

View File

@@ -1,4 +1,4 @@
import { LineItem } from "../models/line-item"
import { LineItem } from "../models"
/** The amount of a gift card allocated to a line item */
export type GiftCardAllocation = {

View File

@@ -27,6 +27,7 @@
"./dist/**/*",
"./src/**/__tests__",
"./src/**/__mocks__",
"./src/**/__fixtures__",
"node_modules"
]
}