fix: ensures no duplicate tax lines when completing cart (#1262)

* fix: ensures that no duplicate tax lines are created when completing cart
This commit is contained in:
Sebastian Rindom
2022-04-05 21:25:48 +02:00
committed by GitHub
parent 497e9b3f7d
commit 607a382b4e
9 changed files with 269 additions and 16 deletions

View File

@@ -197,6 +197,111 @@ describe("Order Taxes", () => {
expect(response.data.order.total).toEqual(2300)
})
test("completing cart with failure doesn't duplicate", async () => {
const product1 = await simpleProductFactory(
dbConnection,
{
variants: [
{
id: "test-variant",
},
],
},
100
)
const product2 = await simpleProductFactory(
dbConnection,
{
variants: [
{
id: "test-variant-2",
},
],
},
100
)
const region = await simpleRegionFactory(dbConnection, {
name: "Test region",
tax_rate: 12,
})
await simpleProductTaxRateFactory(dbConnection, {
product_id: product1.id,
rate: {
region_id: region.id,
rate: 25,
},
})
await simpleProductTaxRateFactory(dbConnection, {
product_id: product2.id,
rate: {
region_id: region.id,
rate: 20,
},
})
const cart = await simpleCartFactory(
dbConnection,
{
region: region.id,
email: "test@testson.com",
line_items: [
{
variant_id: "test-variant",
unit_price: 100,
},
{
variant_id: "test-variant-2",
unit_price: 50,
},
],
},
100
)
const api = useApi()
await api.post(`/store/carts/${cart.id}`, {
email: "test@testson.com",
})
const failedComplete = await api
.post(`/store/carts/${cart.id}/complete`)
.catch((err) => err.response)
expect(failedComplete.status).toEqual(400)
expect(failedComplete.data.message).toEqual(
"You cannot complete a cart without a payment session."
)
await api.post(`/store/carts/${cart.id}/payment-sessions`)
const response = await api.post(`/store/carts/${cart.id}/complete`)
expect(response.status).toEqual(200)
expect(response.data.type).toEqual("order")
expect(response.data.data.tax_total).toEqual(35)
expect(response.data.data.total).toEqual(185)
expect(
response.data.data.items.flatMap((li) => li.tax_lines).length
).toEqual(2)
expect(response.data.data.items[0].tax_lines).toEqual([
expect.objectContaining({
rate: 25,
}),
])
expect(response.data.data.items[1].tax_lines).toEqual([
expect.objectContaining({
rate: 20,
}),
])
})
test("completing cart creates tax lines", async () => {
const product1 = await simpleProductFactory(
dbConnection,

View File

@@ -158,12 +158,18 @@ class BaseService {
* @param {string} isolation - the isolation level to be used for the work.
* @return {any} the result of the transactional work
*/
async atomicPhase_(work, isolationOrErrorHandler, maybeErrorHandler) {
let errorHandler = maybeErrorHandler
async atomicPhase_(
work,
isolationOrErrorHandler,
maybeErrorHandlerOrDontFail
) {
let errorHandler = maybeErrorHandlerOrDontFail
let isolation = isolationOrErrorHandler
let dontFail = false
if (typeof isolationOrErrorHandler === "function") {
isolation = null
errorHandler = isolationOrErrorHandler
dontFail = !!maybeErrorHandlerOrDontFail
}
if (this.transactionManager_) {
@@ -226,8 +232,12 @@ class BaseService {
return result
} catch (error) {
if (errorHandler) {
await errorHandler(error)
const result = await errorHandler(error)
if (dontFail) {
return result
}
}
throw error
}
}

View File

@@ -0,0 +1,23 @@
import { MigrationInterface, QueryRunner } from "typeorm"
export class taxLineConstraints1648641130007 implements MigrationInterface {
name = "taxLineConstraints1648641130007"
public async up(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(
`ALTER TABLE "line_item_tax_line" ADD CONSTRAINT "UQ_3c2af51043ed7243e7d9775a2ad" UNIQUE ("item_id", "code")`
)
await queryRunner.query(
`ALTER TABLE "shipping_method_tax_line" ADD CONSTRAINT "UQ_cd147fca71e50bc954139fa3104" UNIQUE ("shipping_method_id", "code")`
)
}
public async down(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(
`ALTER TABLE "shipping_method_tax_line" DROP CONSTRAINT "UQ_cd147fca71e50bc954139fa3104"`
)
await queryRunner.query(
`ALTER TABLE "line_item_tax_line" DROP CONSTRAINT "UQ_3c2af51043ed7243e7d9775a2ad"`
)
}
}

View File

@@ -1,10 +1,11 @@
import {
Entity,
BeforeInsert,
Index,
Column,
ManyToOne,
Entity,
Index,
JoinColumn,
ManyToOne,
Unique,
} from "typeorm"
import { ulid } from "ulid"
@@ -12,6 +13,7 @@ import { TaxLine } from "./tax-line"
import { LineItem } from "./line-item"
@Entity()
@Unique(["item_id", "code"])
export class LineItemTaxLine extends TaxLine {
@Index()
@Column()

View File

@@ -1,10 +1,11 @@
import {
Entity,
BeforeInsert,
Index,
Column,
ManyToOne,
Entity,
Index,
JoinColumn,
ManyToOne,
Unique,
} from "typeorm"
import { ulid } from "ulid"
@@ -12,6 +13,7 @@ import { TaxLine } from "./tax-line"
import { ShippingMethod } from "./shipping-method"
@Entity()
@Unique(["shipping_method_id", "code"])
export class ShippingMethodTaxLine extends TaxLine {
@Index()
@Column()

View File

@@ -2,4 +2,36 @@ import { EntityRepository, Repository } from "typeorm"
import { LineItemTaxLine } from "../models/line-item-tax-line"
@EntityRepository(LineItemTaxLine)
export class LineItemTaxLineRepository extends Repository<LineItemTaxLine> {}
export class LineItemTaxLineRepository extends Repository<LineItemTaxLine> {
async upsertLines(lines: LineItemTaxLine[]): Promise<LineItemTaxLine[]> {
const insertResult = await this.createQueryBuilder()
.insert()
.values(lines)
.orUpdate({
conflict_target: ["item_id", "code"],
overwrite: ["rate", "name", "updated_at"],
})
.execute()
return insertResult.identifiers as LineItemTaxLine[]
}
async deleteForCart(cartId: string): Promise<void> {
const qb = this.createQueryBuilder("line")
.select(["line.id"])
.innerJoin("line_item", "i", "i.id = line.item_id")
.innerJoin(
"cart",
"c",
"i.cart_id = :cartId AND c.completed_at is NULL",
{ cartId }
)
const toDelete = await qb.getMany()
await this.createQueryBuilder()
.delete()
.whereInIds(toDelete.map((d) => d.id))
.execute()
}
}

View File

@@ -2,4 +2,38 @@ import { EntityRepository, Repository } from "typeorm"
import { ShippingMethodTaxLine } from "../models/shipping-method-tax-line"
@EntityRepository(ShippingMethodTaxLine)
export class ShippingMethodTaxLineRepository extends Repository<ShippingMethodTaxLine> {}
export class ShippingMethodTaxLineRepository extends Repository<ShippingMethodTaxLine> {
async upsertLines(
lines: ShippingMethodTaxLine[]
): Promise<ShippingMethodTaxLine[]> {
const insertResult = await this.createQueryBuilder()
.insert()
.values(lines)
.orUpdate({
conflict_target: ["shipping_method_id", "code"],
overwrite: ["rate", "name", "updated_at"],
})
.execute()
return insertResult.identifiers as ShippingMethodTaxLine[]
}
async deleteForCart(cartId: string): Promise<void> {
const qb = this.createQueryBuilder("line")
.select(["line.id"])
.innerJoin("shipping_method", "sm", "sm.id = line.shipping_method_id")
.innerJoin(
"cart",
"c",
"sm.cart_id = :cartId AND c.completed_at is NULL",
{ cartId }
)
const toDelete = await qb.getMany()
await this.createQueryBuilder()
.delete()
.whereInIds(toDelete.map((d) => d.id))
.execute()
}
}

View File

@@ -1213,6 +1213,13 @@ class CartService extends BaseService {
return cartRepository.save(cart)
}
if (!cart.payment_session) {
throw new MedusaError(
MedusaError.Types.NOT_ALLOWED,
"You cannot complete a cart without a payment session."
)
}
const session = await this.paymentProviderService_
.withTransaction(manager)
.authorizePayment(cart.payment_session, context)
@@ -1874,9 +1881,8 @@ class CartService extends BaseService {
})
const calculationContext = this.totalsService_.getCalculationContext(cart)
await this.taxProviderService_
.withTransaction(manager)
.createTaxLines(cart, calculationContext)
const txTaxProvider = this.taxProviderService_.withTransaction(manager)
await txTaxProvider.createTaxLines(cart, calculationContext)
return cart
})

View File

@@ -1,7 +1,7 @@
import { MedusaError } from "medusa-core-utils"
import { AwilixContainer } from "awilix"
import { BaseService } from "medusa-interfaces"
import { EntityManager } from "typeorm"
import { EntityManager, UpdateResult } from "typeorm"
import Redis from "ioredis"
import { LineItemTaxLineRepository } from "../repositories/line-item-tax-line"
@@ -15,6 +15,7 @@ import { ShippingMethod } from "../models/shipping-method"
import { Region } from "../models/region"
import { Cart } from "../models/cart"
import { isCart } from "../types/cart"
import { PostgresError } from "../utils/exception-formatter"
import {
ITaxService,
ItemTaxCalculationLine,
@@ -94,6 +95,18 @@ class TaxProviderService extends BaseService {
return provider
}
async clearTaxLines(cartId: string): Promise<void> {
const taxLineRepo = this.manager_.getCustomRepository(this.taxLineRepo_)
const shippingTaxRepo = this.manager_.getCustomRepository(
this.smTaxLineRepo_
)
await Promise.all([
taxLineRepo.deleteForCart(cartId),
shippingTaxRepo.deleteForCart(cartId),
])
}
/**
* Persists the tax lines relevant for an order to the database.
* @param cartOrLineItems - the cart or line items to create tax lines for
@@ -114,7 +127,33 @@ class TaxProviderService extends BaseService {
taxLines = await this.getTaxLines(cartOrLineItems, calculationContext)
}
return this.manager_.save(taxLines)
const itemTaxLineRepo = this.manager_.getCustomRepository(this.taxLineRepo_)
const shippingTaxLineRepo = this.manager_.getCustomRepository(
this.smTaxLineRepo_
)
const { shipping, lineItems } = taxLines.reduce<{
shipping: ShippingMethodTaxLine[]
lineItems: LineItemTaxLine[]
}>(
(acc, tl) => {
if ("item_id" in tl) {
acc.lineItems.push(tl)
} else {
acc.shipping.push(tl)
}
return acc
},
{ shipping: [], lineItems: [] }
)
return (
await Promise.all([
itemTaxLineRepo.upsertLines(lineItems),
shippingTaxLineRepo.upsertLines(shipping),
])
).flat()
}
/**