feat(tax): add tax provider support (#6492)
**What** - Adds Tax Provider model - Adds loader to get Tax Provider from module options - Adds System Tax provider which forwards tax rates as is
This commit is contained in:
43
packages/tax/src/loaders/providers.ts
Normal file
43
packages/tax/src/loaders/providers.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { moduleProviderLoader } from "@medusajs/modules-sdk"
|
||||
|
||||
import { LoaderOptions, ModuleProvider, ModulesSdkTypes } from "@medusajs/types"
|
||||
import { Lifetime, asFunction } from "awilix"
|
||||
|
||||
import * as providers from "../providers"
|
||||
|
||||
const registrationFn = async (klass, container, pluginOptions) => {
|
||||
container.register({
|
||||
[`tp_${klass.identifier}`]: asFunction(
|
||||
(cradle) => new klass(cradle, pluginOptions),
|
||||
{ lifetime: klass.LIFE_TIME || Lifetime.SINGLETON }
|
||||
),
|
||||
})
|
||||
|
||||
container.registerAdd(
|
||||
"tax_providers",
|
||||
asFunction((cradle) => new klass(cradle, pluginOptions), {
|
||||
lifetime: klass.LIFE_TIME || Lifetime.SINGLETON,
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
export default async ({
|
||||
container,
|
||||
options,
|
||||
}: LoaderOptions<
|
||||
(
|
||||
| ModulesSdkTypes.ModuleServiceInitializeOptions
|
||||
| ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions
|
||||
) & { providers: ModuleProvider[] }
|
||||
>): Promise<void> => {
|
||||
// Local providers
|
||||
for (const provider of Object.values(providers)) {
|
||||
await registrationFn(provider, container, {})
|
||||
}
|
||||
|
||||
await moduleProviderLoader({
|
||||
container,
|
||||
providers: options?.providers || [],
|
||||
registerServiceFn: registrationFn,
|
||||
})
|
||||
}
|
||||
16
packages/tax/src/models/tax-provider.ts
Normal file
16
packages/tax/src/models/tax-provider.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { Entity, OptionalProps, PrimaryKey, Property } from "@mikro-orm/core"
|
||||
|
||||
const TABLE_NAME = "tax_provider"
|
||||
@Entity({ tableName: TABLE_NAME })
|
||||
export default class TaxProvider {
|
||||
[OptionalProps]?: "is_enabled"
|
||||
|
||||
@PrimaryKey({ columnType: "text" })
|
||||
id: string
|
||||
|
||||
@Property({
|
||||
default: true,
|
||||
columnType: "boolean",
|
||||
})
|
||||
is_enabled: boolean = true
|
||||
}
|
||||
@@ -59,7 +59,7 @@ export default class TaxRateRule {
|
||||
type: "text",
|
||||
fieldName: "tax_rate_id",
|
||||
mapToPk: true,
|
||||
cascade: [Cascade.REMOVE],
|
||||
onDelete: "cascade",
|
||||
})
|
||||
@taxRateIdIndexStatement.MikroORMIndex()
|
||||
tax_rate_id: string
|
||||
|
||||
@@ -69,7 +69,7 @@ export default class TaxRate {
|
||||
type: "text",
|
||||
fieldName: "tax_region_id",
|
||||
mapToPk: true,
|
||||
cascade: [Cascade.REMOVE],
|
||||
onDelete: "cascade",
|
||||
})
|
||||
@taxRegionIdIndexStatement.MikroORMIndex()
|
||||
tax_region_id: string
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
Cascade,
|
||||
} from "@mikro-orm/core"
|
||||
import TaxRate from "./tax-rate"
|
||||
import TaxProvider from "./tax-provider"
|
||||
|
||||
type OptionalTaxRegionProps = DAL.SoftDeletableEntityDateColumns
|
||||
|
||||
@@ -32,7 +33,13 @@ const countryCodeProvinceIndexStatement = createPsqlIndexStatementHelper({
|
||||
unique: true,
|
||||
})
|
||||
|
||||
const taxRegionProviderTopLevelCheckName = "CK_tax_region_provider_top_level"
|
||||
const taxRegionCountryTopLevelCheckName = "CK_tax_region_country_top_level"
|
||||
|
||||
@Check({
|
||||
name: taxRegionProviderTopLevelCheckName,
|
||||
expression: `parent_id IS NULL OR provider_id IS NULL`,
|
||||
})
|
||||
@Check({
|
||||
name: taxRegionCountryTopLevelCheckName,
|
||||
expression: `parent_id IS NULL OR province_code IS NOT NULL`,
|
||||
@@ -46,6 +53,13 @@ export default class TaxRegion {
|
||||
@PrimaryKey({ columnType: "text" })
|
||||
id!: string
|
||||
|
||||
@ManyToOne(() => TaxProvider, {
|
||||
fieldName: "provider_id",
|
||||
mapToPk: true,
|
||||
nullable: true,
|
||||
})
|
||||
provider_id: string | null = null
|
||||
|
||||
@Property({ columnType: "text" })
|
||||
country_code: string
|
||||
|
||||
@@ -55,7 +69,7 @@ export default class TaxRegion {
|
||||
@ManyToOne(() => TaxRegion, {
|
||||
index: "IDX_tax_region_parent_id",
|
||||
fieldName: "parent_id",
|
||||
cascade: [Cascade.REMOVE],
|
||||
onDelete: "cascade",
|
||||
mapToPk: true,
|
||||
nullable: true,
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import * as Models from "@models"
|
||||
import * as ModuleModels from "@models"
|
||||
import * as ModuleServices from "@services"
|
||||
import { TaxModuleService } from "@services"
|
||||
import loadProviders from "./loaders/providers"
|
||||
|
||||
const migrationScriptOptions = {
|
||||
moduleName: Modules.TAX,
|
||||
@@ -33,7 +34,7 @@ const connectionLoader = ModulesSdkUtils.mikroOrmConnectionLoaderFactory({
|
||||
})
|
||||
|
||||
const service = TaxModuleService
|
||||
const loaders = [containerLoader, connectionLoader] as any
|
||||
const loaders = [containerLoader, connectionLoader, loadProviders] as any
|
||||
|
||||
export const moduleDefinition: ModuleExports = {
|
||||
service,
|
||||
|
||||
1
packages/tax/src/providers/index.ts
Normal file
1
packages/tax/src/providers/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { default as SystemTaxProvider } from "./system"
|
||||
40
packages/tax/src/providers/system.ts
Normal file
40
packages/tax/src/providers/system.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { ITaxProvider, TaxTypes } from "@medusajs/types"
|
||||
|
||||
export default class SystemTaxService implements ITaxProvider {
|
||||
static identifier = "system"
|
||||
|
||||
getIdentifier(): string {
|
||||
return SystemTaxService.identifier
|
||||
}
|
||||
|
||||
async getTaxLines(
|
||||
itemLines: TaxTypes.ItemTaxCalculationLine[],
|
||||
shippingLines: TaxTypes.ShippingTaxCalculationLine[],
|
||||
_: TaxTypes.TaxCalculationContext
|
||||
): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> {
|
||||
let taxLines: (TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[] =
|
||||
itemLines.flatMap((l) => {
|
||||
return l.rates.map((r) => ({
|
||||
rate_id: r.id,
|
||||
rate: r.rate || 0,
|
||||
name: r.name,
|
||||
code: r.code,
|
||||
line_item_id: l.line_item.id,
|
||||
}))
|
||||
})
|
||||
|
||||
taxLines = taxLines.concat(
|
||||
shippingLines.flatMap((l) => {
|
||||
return l.rates.map((r) => ({
|
||||
rate_id: r.id,
|
||||
rate: r.rate || 0,
|
||||
name: r.name,
|
||||
code: r.code,
|
||||
shipping_line_id: l.shipping_line.id,
|
||||
}))
|
||||
})
|
||||
)
|
||||
|
||||
return taxLines
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import {
|
||||
Context,
|
||||
DAL,
|
||||
ITaxModuleService,
|
||||
ITaxProvider,
|
||||
InternalModuleDeclaration,
|
||||
ModuleJoinerConfig,
|
||||
ModulesSdkTypes,
|
||||
@@ -25,10 +26,16 @@ type InjectedDependencies = {
|
||||
taxRateService: ModulesSdkTypes.InternalModuleService<any>
|
||||
taxRegionService: ModulesSdkTypes.InternalModuleService<any>
|
||||
taxRateRuleService: ModulesSdkTypes.InternalModuleService<any>
|
||||
[key: `tp_${string}`]: ITaxProvider
|
||||
}
|
||||
|
||||
const generateForModels = [TaxRegion, TaxRateRule]
|
||||
|
||||
type ItemWithRates = {
|
||||
rates: TaxRate[]
|
||||
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO
|
||||
}
|
||||
|
||||
export default class TaxModuleService<
|
||||
TTaxRate extends TaxRate = TaxRate,
|
||||
TTaxRegion extends TaxRegion = TaxRegion,
|
||||
@@ -44,6 +51,7 @@ export default class TaxModuleService<
|
||||
>(TaxRate, generateForModels, entityNameToLinkableKeysMap)
|
||||
implements ITaxModuleService
|
||||
{
|
||||
protected readonly container_: InjectedDependencies
|
||||
protected baseRepository_: DAL.RepositoryService
|
||||
protected taxRateService_: ModulesSdkTypes.InternalModuleService<TTaxRate>
|
||||
protected taxRegionService_: ModulesSdkTypes.InternalModuleService<TTaxRegion>
|
||||
@@ -61,6 +69,7 @@ export default class TaxModuleService<
|
||||
// @ts-ignore
|
||||
super(...arguments)
|
||||
|
||||
this.container_ = arguments[0]
|
||||
this.baseRepository_ = baseRepository
|
||||
this.taxRateService_ = taxRateService
|
||||
this.taxRegionService_ = taxRegionService
|
||||
@@ -259,11 +268,19 @@ export default class TaxModuleService<
|
||||
sharedContext
|
||||
)
|
||||
|
||||
const parentRegion = regions.find((r) => r.province_code === null)
|
||||
if (!parentRegion) {
|
||||
throw new MedusaError(
|
||||
MedusaError.Types.INVALID_DATA,
|
||||
"No parent region found for country"
|
||||
)
|
||||
}
|
||||
|
||||
const toReturn = await promiseAll(
|
||||
items.map(async (item) => {
|
||||
const regionIds = regions.map((r) => r.id)
|
||||
const rateQuery = this.getTaxRateQueryForItem(item, regionIds)
|
||||
const rates = await this.taxRateService_.list(
|
||||
const candidateRates = await this.taxRateService_.list(
|
||||
rateQuery,
|
||||
{
|
||||
relations: ["tax_region", "rules"],
|
||||
@@ -271,11 +288,71 @@ export default class TaxModuleService<
|
||||
sharedContext
|
||||
)
|
||||
|
||||
return await this.getTaxRatesForItem(item, rates)
|
||||
const applicableRates = await this.getTaxRatesForItem(
|
||||
item,
|
||||
candidateRates
|
||||
)
|
||||
|
||||
return {
|
||||
rates: applicableRates,
|
||||
item,
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return toReturn.flat()
|
||||
const taxLines = await this.getTaxLinesFromProvider(
|
||||
parentRegion.provider_id,
|
||||
toReturn,
|
||||
calculationContext
|
||||
)
|
||||
|
||||
return taxLines
|
||||
}
|
||||
|
||||
private async getTaxLinesFromProvider(
|
||||
rawProviderId: string | null,
|
||||
items: ItemWithRates[],
|
||||
calculationContext: TaxTypes.TaxCalculationContext
|
||||
) {
|
||||
const providerId = rawProviderId || "system"
|
||||
let provider: ITaxProvider
|
||||
try {
|
||||
provider = this.container_[`tp_${providerId}`] as ITaxProvider
|
||||
} catch (err) {
|
||||
throw new MedusaError(
|
||||
MedusaError.Types.NOT_FOUND,
|
||||
`Failed to resolve Tax Provider with id: ${providerId}. Make sure it's installed and configured in the Tax Module's options.`
|
||||
)
|
||||
}
|
||||
|
||||
const [itemLines, shippingLines] = items.reduce(
|
||||
(acc, line) => {
|
||||
if ("shipping_option_id" in line.item) {
|
||||
acc[1].push({
|
||||
shipping_line: line.item,
|
||||
rates: line.rates,
|
||||
})
|
||||
} else {
|
||||
acc[0].push({
|
||||
line_item: line.item,
|
||||
rates: line.rates,
|
||||
})
|
||||
}
|
||||
return acc
|
||||
},
|
||||
[[], []] as [
|
||||
TaxTypes.ItemTaxCalculationLine[],
|
||||
TaxTypes.ShippingTaxCalculationLine[]
|
||||
]
|
||||
)
|
||||
|
||||
const itemTaxLines = await provider.getTaxLines(
|
||||
itemLines,
|
||||
shippingLines,
|
||||
calculationContext
|
||||
)
|
||||
|
||||
return itemTaxLines
|
||||
}
|
||||
|
||||
private async verifyProvinceToCountryMatch(
|
||||
@@ -316,7 +393,7 @@ export default class TaxModuleService<
|
||||
private async getTaxRatesForItem(
|
||||
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO,
|
||||
rates: TTaxRate[]
|
||||
): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> {
|
||||
): Promise<TTaxRate[]> {
|
||||
if (!rates.length) {
|
||||
return []
|
||||
}
|
||||
@@ -324,7 +401,7 @@ export default class TaxModuleService<
|
||||
const prioritizedRates = this.prioritizeRates(rates, item)
|
||||
const rate = prioritizedRates[0]
|
||||
|
||||
const ratesToReturn = [this.buildRateForItem(rate, item)]
|
||||
const ratesToReturn = [rate]
|
||||
|
||||
// If the rate can be combined we need to find the rate's
|
||||
// parent region and add that rate too. If not we can return now.
|
||||
@@ -339,37 +416,12 @@ export default class TaxModuleService<
|
||||
)
|
||||
|
||||
if (parentRate) {
|
||||
ratesToReturn.push(this.buildRateForItem(parentRate, item))
|
||||
ratesToReturn.push(parentRate)
|
||||
}
|
||||
|
||||
return ratesToReturn
|
||||
}
|
||||
|
||||
private buildRateForItem(
|
||||
rate: TTaxRate,
|
||||
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO
|
||||
): TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO {
|
||||
const isShipping = "shipping_option_id" in item
|
||||
const toReturn = {
|
||||
rate_id: rate.id,
|
||||
rate: rate.rate,
|
||||
code: rate.code,
|
||||
name: rate.name,
|
||||
}
|
||||
|
||||
if (isShipping) {
|
||||
return {
|
||||
...toReturn,
|
||||
shipping_line_id: item.id,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
...toReturn,
|
||||
line_item_id: item.id,
|
||||
}
|
||||
}
|
||||
|
||||
private getTaxRateQueryForItem(
|
||||
item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO,
|
||||
regionIds: string[]
|
||||
|
||||
@@ -33,7 +33,7 @@ export interface TaxRateDTO {
|
||||
/**
|
||||
* The ID of the user that created the Tax Rate.
|
||||
*/
|
||||
created_by: string
|
||||
created_by: string | null
|
||||
}
|
||||
|
||||
export interface FilterableTaxRateProps
|
||||
@@ -138,7 +138,7 @@ export interface TaxCalculationContext {
|
||||
}
|
||||
|
||||
interface TaxLineDTO {
|
||||
rate_id: string
|
||||
rate_id?: string
|
||||
rate: number | null
|
||||
code: string | null
|
||||
name: string
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
export * from "./common"
|
||||
export * from "./mutations"
|
||||
export * from "./service"
|
||||
export * from "./provider"
|
||||
|
||||
48
packages/types/src/tax/provider.ts
Normal file
48
packages/types/src/tax/provider.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import {
|
||||
ItemTaxLineDTO,
|
||||
ShippingTaxLineDTO,
|
||||
TaxCalculationContext,
|
||||
TaxRateDTO,
|
||||
TaxableItemDTO,
|
||||
TaxableShippingDTO,
|
||||
} from "./common"
|
||||
|
||||
/**
|
||||
* A shipping method and the tax rates configured to apply to the
|
||||
* shipping method.
|
||||
*/
|
||||
export type ShippingTaxCalculationLine = {
|
||||
/**
|
||||
* The shipping method to calculate taxes for.
|
||||
*/
|
||||
shipping_line: TaxableShippingDTO
|
||||
/**
|
||||
* The rates applicable on the shipping method.
|
||||
*/
|
||||
rates: TaxRateDTO[]
|
||||
}
|
||||
|
||||
/**
|
||||
* A line item and the tax rates configured to apply to the
|
||||
* product contained in the line item.
|
||||
*/
|
||||
export type ItemTaxCalculationLine = {
|
||||
/**
|
||||
* The line item to calculate taxes for.
|
||||
*/
|
||||
line_item: TaxableItemDTO
|
||||
/**
|
||||
* The rates applicable on the item.
|
||||
*/
|
||||
rates: TaxRateDTO[]
|
||||
}
|
||||
|
||||
export interface ITaxProvider {
|
||||
getIdentifier(): string
|
||||
|
||||
getTaxLines(
|
||||
itemLines: ItemTaxCalculationLine[],
|
||||
shippingLines: ShippingTaxCalculationLine[],
|
||||
context: TaxCalculationContext
|
||||
): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]>
|
||||
}
|
||||
@@ -95,7 +95,7 @@ export interface ITaxModuleService extends IModuleService {
|
||||
): Promise<TaxRateRuleDTO[]>
|
||||
|
||||
getTaxLines(
|
||||
item: (TaxableItemDTO | TaxableShippingDTO)[],
|
||||
items: (TaxableItemDTO | TaxableShippingDTO)[],
|
||||
calculationContext: TaxCalculationContext,
|
||||
sharedContext?: Context
|
||||
): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]>
|
||||
|
||||
Reference in New Issue
Block a user