fix: add shipping taxes (#1759)

**What**
Adds taxes to the shipping prices when listing in admin. Allows store operators to see correct prices when processing returns.
This commit is contained in:
Sebastian Rindom
2022-07-02 12:35:16 +02:00
committed by GitHub
parent f7e300e8ce
commit fee0f88a62
11 changed files with 248 additions and 27 deletions

View File

@@ -0,0 +1,109 @@
const path = require("path")
const setupServer = require("../../../helpers/setup-server")
const { useApi } = require("../../../helpers/use-api")
const { useDb, initDb } = require("../../../helpers/use-db")
const {
simpleRegionFactory,
simpleProductFactory,
simpleShippingTaxRateFactory,
simpleShippingOptionFactory,
} = require("../../factories")
const adminSeeder = require("../../helpers/admin-seeder")
jest.setTimeout(30000)
describe("Shipping Options Totals Calculations", () => {
let medusaProcess
let dbConnection
beforeAll(async () => {
const cwd = path.resolve(path.join(__dirname, "..", ".."))
dbConnection = await initDb({ cwd })
medusaProcess = await setupServer({ cwd })
})
afterAll(async () => {
const db = useDb()
await db.shutdown()
medusaProcess.kill()
})
beforeEach(async () => {
try {
await adminSeeder(dbConnection)
} catch (err) {
console.log(err)
throw err
}
})
afterEach(async () => {
const db = useDb()
await db.teardown()
})
it("admin gets correct shipping prices", async () => {
const api = useApi()
const region = await simpleRegionFactory(dbConnection, {
tax_rate: 25,
})
const so = await simpleShippingOptionFactory(dbConnection, {
region_id: region.id,
price: 100,
})
await simpleShippingTaxRateFactory(dbConnection, {
shipping_option_id: so.id,
rate: {
region_id: region.id,
rate: 10,
},
})
const res = await api.get(`/admin/shipping-options`, {
headers: {
Authorization: `Bearer test_token`,
},
})
expect(res.data.shipping_options).toEqual([
expect.objectContaining({
id: so.id,
amount: 100,
price_incl_tax: 110,
}),
])
})
it("gets correct shipping prices", async () => {
const api = useApi()
const region = await simpleRegionFactory(dbConnection, {
tax_rate: 25,
})
const so = await simpleShippingOptionFactory(dbConnection, {
region_id: region.id,
price: 100,
})
await simpleShippingTaxRateFactory(dbConnection, {
shipping_option_id: so.id,
rate: {
region_id: region.id,
rate: 10,
},
})
const res = await api.get(`/store/shipping-options?region_id=${region.id}`)
expect(res.data.shipping_options).toEqual([
expect.objectContaining({
id: so.id,
amount: 100,
price_incl_tax: 110,
}),
])
})
})

View File

@@ -1,6 +1,7 @@
import { Transform } from "class-transformer"
import { IsBoolean, IsOptional, IsString } from "class-validator"
import { defaultFields, defaultRelations } from "."
import { PricingService } from "../../../../services"
import { validator } from "../../../../utils/validator"
import { optionalBooleanMapper } from "../../../../utils/validators/is-boolean"
@@ -50,12 +51,15 @@ export default async (req, res) => {
)
const optionService = req.scope.resolve("shippingOptionService")
const pricingService: PricingService = req.scope.resolve("pricingService")
const [data, count] = await optionService.listAndCount(validatedParams, {
select: defaultFields,
relations: defaultRelations,
})
res.status(200).json({ shipping_options: data, count })
const options = await pricingService.setShippingOptionPrices(data)
res.status(200).json({ shipping_options: options, count })
}
export class AdminGetShippingOptionsParams {

View File

@@ -1,5 +1,5 @@
import { IsBooleanString, IsOptional, IsString } from "class-validator"
import ProductService from "../../../../services/product"
import { PricingService, ProductService } from "../../../../services"
import ShippingOptionService from "../../../../services/shipping-option"
import { validator } from "../../../../utils/validator"
@@ -33,6 +33,7 @@ export default async (req, res) => {
(validated.product_ids && validated.product_ids.split(",")) || []
const regionId = validated.region_id
const productService: ProductService = req.scope.resolve("productService")
const pricingService: PricingService = req.scope.resolve("pricingService")
const shippingOptionService: ShippingOptionService = req.scope.resolve(
"shippingOptionService"
)
@@ -59,7 +60,9 @@ export default async (req, res) => {
relations: ["requirements"],
})
res.status(200).json({ shipping_options: options })
const data = await pricingService.setShippingOptionPrices(options)
res.status(200).json({ shipping_options: data })
}
export class StoreGetShippingOptionsParams {

View File

@@ -1,4 +1,4 @@
import CartService from "../../../../services/cart"
import { CartService, PricingService } from "../../../../services"
import ShippingProfileService from "../../../../services/shipping-profile"
/**
@@ -26,6 +26,7 @@ export default async (req, res) => {
const { cart_id } = req.params
const cartService: CartService = req.scope.resolve("cartService")
const pricingService: PricingService = req.scope.resolve("pricingService")
const shippingProfileService: ShippingProfileService = req.scope.resolve(
"shippingProfileService"
)
@@ -36,6 +37,9 @@ export default async (req, res) => {
})
const options = await shippingProfileService.fetchCartOptions(cart)
const data = await pricingService.setShippingOptionPrices(options, {
cart_id,
})
res.status(200).json({ shipping_options: options })
res.status(200).json({ shipping_options: data })
}

View File

@@ -254,21 +254,11 @@ export class TaxRateRepository extends Repository<TaxRate> {
return unionBy(...results, (txr) => txr.id)
}
async listByShippingOption(optionId: string, config: TaxRateListByConfig) {
async listByShippingOption(optionId: string) {
let rates = this.createQueryBuilder("txr")
.leftJoin(ShippingTaxRate, "ptr", "ptr.rate_id = txr.id")
.leftJoin(
ShippingMethod,
"sm",
"sm.shipping_option_id = ptr.shipping_option_id"
)
.where("sm.shipping_option_id = :optionId", { optionId })
.where("ptr.shipping_option_id = :optionId", { optionId })
if (typeof config.region_id !== "undefined") {
rates.andWhere("txr.region_id = :regionId", {
regionId: config.region_id,
})
}
return await rates.getMany()
}
}

View File

@@ -8,6 +8,9 @@ export const PricingServiceMock = {
setVariantPrices: jest.fn().mockImplementation((variant) => {
return Promise.resolve(variant)
}),
setShippingOptionPrices: jest.fn().mockImplementation((opts) => {
return Promise.resolve(opts)
}),
}
const mock = jest.fn().mockImplementation(() => {

View File

@@ -1,12 +1,14 @@
import { EntityManager } from "typeorm"
import { MedusaError } from "medusa-core-utils"
import { ProductVariantService, RegionService, TaxProviderService } from "."
import { Product, ProductVariant } from "../models"
import { Product, ProductVariant, ShippingOption } from "../models"
import { TaxServiceRate } from "../types/tax-service"
import {
ProductVariantPricing,
TaxedPricing,
PricingContext,
PricedProduct,
PricedShippingOption,
PricedVariant,
} from "../types/pricing"
import { TransactionBaseService } from "../interfaces"
@@ -384,6 +386,108 @@ class PricingService extends TransactionBaseService<PricingService> {
})
)
}
/**
* Gets the prices for a shipping option.
* @param shippingOption - the shipping option to get prices for
* @param context - the price selection context to use
* @return The shipping option prices
*/
async getShippingOptionPricing(
shippingOption: ShippingOption,
context: PriceSelectionContext | PricingContext
): Promise<PricedShippingOption> {
let pricingContext: PricingContext
if ("automatic_taxes" in context) {
pricingContext = context
} else {
pricingContext = await this.collectPricingContext(context)
}
let shippingOptionRates: TaxServiceRate[] = []
if (
pricingContext.automatic_taxes &&
pricingContext.price_selection.region_id
) {
shippingOptionRates =
await this.taxProviderService.getRegionRatesForShipping(
shippingOption.id,
{
id: pricingContext.price_selection.region_id,
tax_rate: pricingContext.tax_rate,
}
)
}
const price = shippingOption.amount || 0
const rate = shippingOptionRates.reduce(
(accRate: number, nextTaxRate: TaxServiceRate) => {
return accRate + (nextTaxRate.rate || 0) / 100
},
0
)
const tax = Math.round(price * rate)
const total = price + tax
return {
...shippingOption,
price_incl_tax: total,
tax_rates: shippingOptionRates,
}
}
/**
* Set additional prices on a list of shipping options.
* @param shippingOptions - list of shipping options on which to set additional prices
* @param context - the price selection context to use
* @return A list of shipping options with prices
*/
async setShippingOptionPrices(
shippingOptions: ShippingOption[],
context: Omit<PriceSelectionContext, "region_id"> = {}
): Promise<PricedShippingOption[]> {
const regions = new Set<string>()
for (const shippingOption of shippingOptions) {
regions.add(shippingOption.region_id)
}
const contexts = await Promise.all(
[...regions].map(async (regionId) => {
return {
context: await this.collectPricingContext({
...context,
region_id: regionId,
}),
region_id: regionId,
}
})
)
return await Promise.all(
shippingOptions.map(async (shippingOption) => {
const pricingContext = contexts.find(
(c) => c.region_id === shippingOption.region_id
)
if (!pricingContext) {
throw new MedusaError(
MedusaError.Types.UNEXPECTED_STATE,
"Could not find pricing context for shipping option"
)
}
const shippingOptionPricing = await this.getShippingOptionPricing(
shippingOption,
pricingContext.context
)
return {
...shippingOption,
...shippingOptionPricing,
}
})
)
}
}
export default PricingService

View File

@@ -420,7 +420,7 @@ class ShippingProfileService extends BaseService {
* Finds all the shipping profiles that cover the products in a cart, and
* validates all options that are available for the cart.
* @param {Cart} cart - the cart object to find shipping options for
* @return {[ShippingOption]} a list of the available shipping options
* @return {Promise<[ShippingOption]>} a list of the available shipping options
*/
async fetchCartOptions(cart) {
const profileIds = this.getProfilesInCart_(cart)

View File

@@ -340,8 +340,7 @@ class TaxProviderService extends BaseService {
let toReturn: TaxServiceRate[] = []
const optionRates = await this.taxRateService_.listByShippingOption(
optionId,
{ region_id: regionDetails.id }
optionId
)
if (optionRates.length > 0) {

View File

@@ -329,13 +329,10 @@ class TaxRateService extends BaseService {
})
}
async listByShippingOption(
shippingOptionId: string,
config: TaxRateListByConfig
): Promise<TaxRate[]> {
async listByShippingOption(shippingOptionId: string): Promise<TaxRate[]> {
return await this.atomicPhase_(async (manager: EntityManager) => {
const taxRateRepo = manager.getCustomRepository(this.taxRateRepository_)
return await taxRateRepo.listByShippingOption(shippingOptionId, config)
return await taxRateRepo.listByShippingOption(shippingOptionId)
})
}
}

View File

@@ -1,4 +1,4 @@
import { MoneyAmount, ProductVariant, Product } from "../models"
import { MoneyAmount, ProductVariant, Product, ShippingOption } from "../models"
import { TaxServiceRate } from "./tax-service"
import { PriceSelectionContext } from "../interfaces/price-selection-strategy"
@@ -23,6 +23,14 @@ export type PricingContext = {
tax_rate: number | null
}
export type ShippingOptionPricing = {
price_incl_tax: number | null
tax_rates: TaxServiceRate[] | null
}
export type PricedShippingOption = Partial<ShippingOption> &
ShippingOptionPricing
export type PricedVariant = Partial<ProductVariant> & ProductVariantPricing
export type PricedProduct = Omit<Partial<Product>, "variants"> & {