feat(tax): add tax provider support (#6492)

**What**
- Adds Tax Provider model
- Adds loader to get Tax Provider from module options
- Adds System Tax provider which forwards tax rates as is
This commit is contained in:
Sebastian Rindom
2024-02-26 20:29:26 +01:00
committed by GitHub
parent ce39b9b66e
commit 63aea44e06
13 changed files with 254 additions and 38 deletions

View File

@@ -0,0 +1,43 @@
import { moduleProviderLoader } from "@medusajs/modules-sdk"
import { LoaderOptions, ModuleProvider, ModulesSdkTypes } from "@medusajs/types"
import { Lifetime, asFunction } from "awilix"
import * as providers from "../providers"
const registrationFn = async (klass, container, pluginOptions) => {
container.register({
[`tp_${klass.identifier}`]: asFunction(
(cradle) => new klass(cradle, pluginOptions),
{ lifetime: klass.LIFE_TIME || Lifetime.SINGLETON }
),
})
container.registerAdd(
"tax_providers",
asFunction((cradle) => new klass(cradle, pluginOptions), {
lifetime: klass.LIFE_TIME || Lifetime.SINGLETON,
})
)
}
export default async ({
container,
options,
}: LoaderOptions<
(
| ModulesSdkTypes.ModuleServiceInitializeOptions
| ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions
) & { providers: ModuleProvider[] }
>): Promise<void> => {
// Local providers
for (const provider of Object.values(providers)) {
await registrationFn(provider, container, {})
}
await moduleProviderLoader({
container,
providers: options?.providers || [],
registerServiceFn: registrationFn,
})
}

View File

@@ -0,0 +1,16 @@
import { Entity, OptionalProps, PrimaryKey, Property } from "@mikro-orm/core"
const TABLE_NAME = "tax_provider"
@Entity({ tableName: TABLE_NAME })
export default class TaxProvider {
[OptionalProps]?: "is_enabled"
@PrimaryKey({ columnType: "text" })
id: string
@Property({
default: true,
columnType: "boolean",
})
is_enabled: boolean = true
}

View File

@@ -59,7 +59,7 @@ export default class TaxRateRule {
type: "text",
fieldName: "tax_rate_id",
mapToPk: true,
cascade: [Cascade.REMOVE],
onDelete: "cascade",
})
@taxRateIdIndexStatement.MikroORMIndex()
tax_rate_id: string

View File

@@ -69,7 +69,7 @@ export default class TaxRate {
type: "text",
fieldName: "tax_region_id",
mapToPk: true,
cascade: [Cascade.REMOVE],
onDelete: "cascade",
})
@taxRegionIdIndexStatement.MikroORMIndex()
tax_region_id: string

View File

@@ -19,6 +19,7 @@ import {
Cascade,
} from "@mikro-orm/core"
import TaxRate from "./tax-rate"
import TaxProvider from "./tax-provider"
type OptionalTaxRegionProps = DAL.SoftDeletableEntityDateColumns
@@ -32,7 +33,13 @@ const countryCodeProvinceIndexStatement = createPsqlIndexStatementHelper({
unique: true,
})
const taxRegionProviderTopLevelCheckName = "CK_tax_region_provider_top_level"
const taxRegionCountryTopLevelCheckName = "CK_tax_region_country_top_level"
@Check({
name: taxRegionProviderTopLevelCheckName,
expression: `parent_id IS NULL OR provider_id IS NULL`,
})
@Check({
name: taxRegionCountryTopLevelCheckName,
expression: `parent_id IS NULL OR province_code IS NOT NULL`,
@@ -46,6 +53,13 @@ export default class TaxRegion {
@PrimaryKey({ columnType: "text" })
id!: string
@ManyToOne(() => TaxProvider, {
fieldName: "provider_id",
mapToPk: true,
nullable: true,
})
provider_id: string | null = null
@Property({ columnType: "text" })
country_code: string
@@ -55,7 +69,7 @@ export default class TaxRegion {
@ManyToOne(() => TaxRegion, {
index: "IDX_tax_region_parent_id",
fieldName: "parent_id",
cascade: [Cascade.REMOVE],
onDelete: "cascade",
mapToPk: true,
nullable: true,
})

View File

@@ -5,6 +5,7 @@ import * as Models from "@models"
import * as ModuleModels from "@models"
import * as ModuleServices from "@services"
import { TaxModuleService } from "@services"
import loadProviders from "./loaders/providers"
const migrationScriptOptions = {
moduleName: Modules.TAX,
@@ -33,7 +34,7 @@ const connectionLoader = ModulesSdkUtils.mikroOrmConnectionLoaderFactory({
})
const service = TaxModuleService
const loaders = [containerLoader, connectionLoader] as any
const loaders = [containerLoader, connectionLoader, loadProviders] as any
export const moduleDefinition: ModuleExports = {
service,

View File

@@ -0,0 +1 @@
export { default as SystemTaxProvider } from "./system"

View File

@@ -0,0 +1,40 @@
import { ITaxProvider, TaxTypes } from "@medusajs/types"
export default class SystemTaxService implements ITaxProvider {
static identifier = "system"
getIdentifier(): string {
return SystemTaxService.identifier
}
async getTaxLines(
itemLines: TaxTypes.ItemTaxCalculationLine[],
shippingLines: TaxTypes.ShippingTaxCalculationLine[],
_: TaxTypes.TaxCalculationContext
): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> {
let taxLines: (TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[] =
itemLines.flatMap((l) => {
return l.rates.map((r) => ({
rate_id: r.id,
rate: r.rate || 0,
name: r.name,
code: r.code,
line_item_id: l.line_item.id,
}))
})
taxLines = taxLines.concat(
shippingLines.flatMap((l) => {
return l.rates.map((r) => ({
rate_id: r.id,
rate: r.rate || 0,
name: r.name,
code: r.code,
shipping_line_id: l.shipping_line.id,
}))
})
)
return taxLines
}
}

View File

@@ -2,6 +2,7 @@ import {
Context,
DAL,
ITaxModuleService,
ITaxProvider,
InternalModuleDeclaration,
ModuleJoinerConfig,
ModulesSdkTypes,
@@ -25,10 +26,16 @@ type InjectedDependencies = {
taxRateService: ModulesSdkTypes.InternalModuleService<any>
taxRegionService: ModulesSdkTypes.InternalModuleService<any>
taxRateRuleService: ModulesSdkTypes.InternalModuleService<any>
[key: `tp_${string}`]: ITaxProvider
}
const generateForModels = [TaxRegion, TaxRateRule]
type ItemWithRates = {
rates: TaxRate[]
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO
}
export default class TaxModuleService<
TTaxRate extends TaxRate = TaxRate,
TTaxRegion extends TaxRegion = TaxRegion,
@@ -44,6 +51,7 @@ export default class TaxModuleService<
>(TaxRate, generateForModels, entityNameToLinkableKeysMap)
implements ITaxModuleService
{
protected readonly container_: InjectedDependencies
protected baseRepository_: DAL.RepositoryService
protected taxRateService_: ModulesSdkTypes.InternalModuleService<TTaxRate>
protected taxRegionService_: ModulesSdkTypes.InternalModuleService<TTaxRegion>
@@ -61,6 +69,7 @@ export default class TaxModuleService<
// @ts-ignore
super(...arguments)
this.container_ = arguments[0]
this.baseRepository_ = baseRepository
this.taxRateService_ = taxRateService
this.taxRegionService_ = taxRegionService
@@ -259,11 +268,19 @@ export default class TaxModuleService<
sharedContext
)
const parentRegion = regions.find((r) => r.province_code === null)
if (!parentRegion) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
"No parent region found for country"
)
}
const toReturn = await promiseAll(
items.map(async (item) => {
const regionIds = regions.map((r) => r.id)
const rateQuery = this.getTaxRateQueryForItem(item, regionIds)
const rates = await this.taxRateService_.list(
const candidateRates = await this.taxRateService_.list(
rateQuery,
{
relations: ["tax_region", "rules"],
@@ -271,11 +288,71 @@ export default class TaxModuleService<
sharedContext
)
return await this.getTaxRatesForItem(item, rates)
const applicableRates = await this.getTaxRatesForItem(
item,
candidateRates
)
return {
rates: applicableRates,
item,
}
})
)
return toReturn.flat()
const taxLines = await this.getTaxLinesFromProvider(
parentRegion.provider_id,
toReturn,
calculationContext
)
return taxLines
}
private async getTaxLinesFromProvider(
rawProviderId: string | null,
items: ItemWithRates[],
calculationContext: TaxTypes.TaxCalculationContext
) {
const providerId = rawProviderId || "system"
let provider: ITaxProvider
try {
provider = this.container_[`tp_${providerId}`] as ITaxProvider
} catch (err) {
throw new MedusaError(
MedusaError.Types.NOT_FOUND,
`Failed to resolve Tax Provider with id: ${providerId}. Make sure it's installed and configured in the Tax Module's options.`
)
}
const [itemLines, shippingLines] = items.reduce(
(acc, line) => {
if ("shipping_option_id" in line.item) {
acc[1].push({
shipping_line: line.item,
rates: line.rates,
})
} else {
acc[0].push({
line_item: line.item,
rates: line.rates,
})
}
return acc
},
[[], []] as [
TaxTypes.ItemTaxCalculationLine[],
TaxTypes.ShippingTaxCalculationLine[]
]
)
const itemTaxLines = await provider.getTaxLines(
itemLines,
shippingLines,
calculationContext
)
return itemTaxLines
}
private async verifyProvinceToCountryMatch(
@@ -316,7 +393,7 @@ export default class TaxModuleService<
private async getTaxRatesForItem(
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO,
rates: TTaxRate[]
): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> {
): Promise<TTaxRate[]> {
if (!rates.length) {
return []
}
@@ -324,7 +401,7 @@ export default class TaxModuleService<
const prioritizedRates = this.prioritizeRates(rates, item)
const rate = prioritizedRates[0]
const ratesToReturn = [this.buildRateForItem(rate, item)]
const ratesToReturn = [rate]
// If the rate can be combined we need to find the rate's
// parent region and add that rate too. If not we can return now.
@@ -339,37 +416,12 @@ export default class TaxModuleService<
)
if (parentRate) {
ratesToReturn.push(this.buildRateForItem(parentRate, item))
ratesToReturn.push(parentRate)
}
return ratesToReturn
}
private buildRateForItem(
rate: TTaxRate,
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO
): TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO {
const isShipping = "shipping_option_id" in item
const toReturn = {
rate_id: rate.id,
rate: rate.rate,
code: rate.code,
name: rate.name,
}
if (isShipping) {
return {
...toReturn,
shipping_line_id: item.id,
}
}
return {
...toReturn,
line_item_id: item.id,
}
}
private getTaxRateQueryForItem(
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO,
regionIds: string[]

View File

@@ -33,7 +33,7 @@ export interface TaxRateDTO {
/**
* The ID of the user that created the Tax Rate.
*/
created_by: string
created_by: string | null
}
export interface FilterableTaxRateProps
@@ -138,7 +138,7 @@ export interface TaxCalculationContext {
}
interface TaxLineDTO {
rate_id: string
rate_id?: string
rate: number | null
code: string | null
name: string

View File

@@ -1,3 +1,4 @@
export * from "./common"
export * from "./mutations"
export * from "./service"
export * from "./provider"

View File

@@ -0,0 +1,48 @@
import {
ItemTaxLineDTO,
ShippingTaxLineDTO,
TaxCalculationContext,
TaxRateDTO,
TaxableItemDTO,
TaxableShippingDTO,
} from "./common"
/**
* A shipping method and the tax rates configured to apply to the
* shipping method.
*/
export type ShippingTaxCalculationLine = {
/**
* The shipping method to calculate taxes for.
*/
shipping_line: TaxableShippingDTO
/**
* The rates applicable on the shipping method.
*/
rates: TaxRateDTO[]
}
/**
* A line item and the tax rates configured to apply to the
* product contained in the line item.
*/
export type ItemTaxCalculationLine = {
/**
* The line item to calculate taxes for.
*/
line_item: TaxableItemDTO
/**
* The rates applicable on the item.
*/
rates: TaxRateDTO[]
}
export interface ITaxProvider {
getIdentifier(): string
getTaxLines(
itemLines: ItemTaxCalculationLine[],
shippingLines: ShippingTaxCalculationLine[],
context: TaxCalculationContext
): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]>
}

View File

@@ -95,7 +95,7 @@ export interface ITaxModuleService extends IModuleService {
): Promise<TaxRateRuleDTO[]>
getTaxLines(
item: (TaxableItemDTO | TaxableShippingDTO)[],
items: (TaxableItemDTO | TaxableShippingDTO)[],
calculationContext: TaxCalculationContext,
sharedContext?: Context
): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]>