diff --git a/packages/tax/src/loaders/providers.ts b/packages/tax/src/loaders/providers.ts new file mode 100644 index 0000000000..fdeac2c944 --- /dev/null +++ b/packages/tax/src/loaders/providers.ts @@ -0,0 +1,43 @@ +import { moduleProviderLoader } from "@medusajs/modules-sdk" + +import { LoaderOptions, ModuleProvider, ModulesSdkTypes } from "@medusajs/types" +import { Lifetime, asFunction } from "awilix" + +import * as providers from "../providers" + +const registrationFn = async (klass, container, pluginOptions) => { + container.register({ + [`tp_${klass.identifier}`]: asFunction( + (cradle) => new klass(cradle, pluginOptions), + { lifetime: klass.LIFE_TIME || Lifetime.SINGLETON } + ), + }) + + container.registerAdd( + "tax_providers", + asFunction((cradle) => new klass(cradle, pluginOptions), { + lifetime: klass.LIFE_TIME || Lifetime.SINGLETON, + }) + ) +} + +export default async ({ + container, + options, +}: LoaderOptions< + ( + | ModulesSdkTypes.ModuleServiceInitializeOptions + | ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions + ) & { providers: ModuleProvider[] } +>): Promise => { + // Local providers + for (const provider of Object.values(providers)) { + await registrationFn(provider, container, {}) + } + + await moduleProviderLoader({ + container, + providers: options?.providers || [], + registerServiceFn: registrationFn, + }) +} diff --git a/packages/tax/src/models/tax-provider.ts b/packages/tax/src/models/tax-provider.ts new file mode 100644 index 0000000000..e08a3c952c --- /dev/null +++ b/packages/tax/src/models/tax-provider.ts @@ -0,0 +1,16 @@ +import { Entity, OptionalProps, PrimaryKey, Property } from "@mikro-orm/core" + +const TABLE_NAME = "tax_provider" +@Entity({ tableName: TABLE_NAME }) +export default class TaxProvider { + [OptionalProps]?: "is_enabled" + + @PrimaryKey({ columnType: "text" }) + id: string + + @Property({ + default: true, + columnType: "boolean", + }) + is_enabled: boolean = true +} diff --git a/packages/tax/src/models/tax-rate-rule.ts b/packages/tax/src/models/tax-rate-rule.ts index 3fbdbaa62e..998d7e00f0 100644 --- a/packages/tax/src/models/tax-rate-rule.ts +++ b/packages/tax/src/models/tax-rate-rule.ts @@ -59,7 +59,7 @@ export default class TaxRateRule { type: "text", fieldName: "tax_rate_id", mapToPk: true, - cascade: [Cascade.REMOVE], + onDelete: "cascade", }) @taxRateIdIndexStatement.MikroORMIndex() tax_rate_id: string diff --git a/packages/tax/src/models/tax-rate.ts b/packages/tax/src/models/tax-rate.ts index 14e66148c2..77d5203dc3 100644 --- a/packages/tax/src/models/tax-rate.ts +++ b/packages/tax/src/models/tax-rate.ts @@ -69,7 +69,7 @@ export default class TaxRate { type: "text", fieldName: "tax_region_id", mapToPk: true, - cascade: [Cascade.REMOVE], + onDelete: "cascade", }) @taxRegionIdIndexStatement.MikroORMIndex() tax_region_id: string diff --git a/packages/tax/src/models/tax-region.ts b/packages/tax/src/models/tax-region.ts index 5ed4ec00e4..6d29f2571e 100644 --- a/packages/tax/src/models/tax-region.ts +++ b/packages/tax/src/models/tax-region.ts @@ -19,6 +19,7 @@ import { Cascade, } from "@mikro-orm/core" import TaxRate from "./tax-rate" +import TaxProvider from "./tax-provider" type OptionalTaxRegionProps = DAL.SoftDeletableEntityDateColumns @@ -32,7 +33,13 @@ const countryCodeProvinceIndexStatement = createPsqlIndexStatementHelper({ unique: true, }) +const taxRegionProviderTopLevelCheckName = "CK_tax_region_provider_top_level" const taxRegionCountryTopLevelCheckName = "CK_tax_region_country_top_level" + +@Check({ + name: taxRegionProviderTopLevelCheckName, + expression: `parent_id IS NULL OR provider_id IS NULL`, +}) @Check({ name: taxRegionCountryTopLevelCheckName, expression: `parent_id IS NULL OR province_code IS NOT NULL`, @@ -46,6 +53,13 @@ export default class TaxRegion { @PrimaryKey({ columnType: "text" }) id!: string + @ManyToOne(() => TaxProvider, { + fieldName: "provider_id", + mapToPk: true, + nullable: true, + }) + provider_id: string | null = null + @Property({ columnType: "text" }) country_code: string @@ -55,7 +69,7 @@ export default class TaxRegion { @ManyToOne(() => TaxRegion, { index: "IDX_tax_region_parent_id", fieldName: "parent_id", - cascade: [Cascade.REMOVE], + onDelete: "cascade", mapToPk: true, nullable: true, }) diff --git a/packages/tax/src/module-definition.ts b/packages/tax/src/module-definition.ts index f574b99176..3f3375578a 100644 --- a/packages/tax/src/module-definition.ts +++ b/packages/tax/src/module-definition.ts @@ -5,6 +5,7 @@ import * as Models from "@models" import * as ModuleModels from "@models" import * as ModuleServices from "@services" import { TaxModuleService } from "@services" +import loadProviders from "./loaders/providers" const migrationScriptOptions = { moduleName: Modules.TAX, @@ -33,7 +34,7 @@ const connectionLoader = ModulesSdkUtils.mikroOrmConnectionLoaderFactory({ }) const service = TaxModuleService -const loaders = [containerLoader, connectionLoader] as any +const loaders = [containerLoader, connectionLoader, loadProviders] as any export const moduleDefinition: ModuleExports = { service, diff --git a/packages/tax/src/providers/index.ts b/packages/tax/src/providers/index.ts new file mode 100644 index 0000000000..c3790c4f05 --- /dev/null +++ b/packages/tax/src/providers/index.ts @@ -0,0 +1 @@ +export { default as SystemTaxProvider } from "./system" diff --git a/packages/tax/src/providers/system.ts b/packages/tax/src/providers/system.ts new file mode 100644 index 0000000000..0b67832ebb --- /dev/null +++ b/packages/tax/src/providers/system.ts @@ -0,0 +1,40 @@ +import { ITaxProvider, TaxTypes } from "@medusajs/types" + +export default class SystemTaxService implements ITaxProvider { + static identifier = "system" + + getIdentifier(): string { + return SystemTaxService.identifier + } + + async getTaxLines( + itemLines: TaxTypes.ItemTaxCalculationLine[], + shippingLines: TaxTypes.ShippingTaxCalculationLine[], + _: TaxTypes.TaxCalculationContext + ): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> { + let taxLines: (TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[] = + itemLines.flatMap((l) => { + return l.rates.map((r) => ({ + rate_id: r.id, + rate: r.rate || 0, + name: r.name, + code: r.code, + line_item_id: l.line_item.id, + })) + }) + + taxLines = taxLines.concat( + shippingLines.flatMap((l) => { + return l.rates.map((r) => ({ + rate_id: r.id, + rate: r.rate || 0, + name: r.name, + code: r.code, + shipping_line_id: l.shipping_line.id, + })) + }) + ) + + return taxLines + } +} diff --git a/packages/tax/src/services/tax-module-service.ts b/packages/tax/src/services/tax-module-service.ts index a1af013d14..d13ce53691 100644 --- a/packages/tax/src/services/tax-module-service.ts +++ b/packages/tax/src/services/tax-module-service.ts @@ -2,6 +2,7 @@ import { Context, DAL, ITaxModuleService, + ITaxProvider, InternalModuleDeclaration, ModuleJoinerConfig, ModulesSdkTypes, @@ -25,10 +26,16 @@ type InjectedDependencies = { taxRateService: ModulesSdkTypes.InternalModuleService taxRegionService: ModulesSdkTypes.InternalModuleService taxRateRuleService: ModulesSdkTypes.InternalModuleService + [key: `tp_${string}`]: ITaxProvider } const generateForModels = [TaxRegion, TaxRateRule] +type ItemWithRates = { + rates: TaxRate[] + item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO +} + export default class TaxModuleService< TTaxRate extends TaxRate = TaxRate, TTaxRegion extends TaxRegion = TaxRegion, @@ -44,6 +51,7 @@ export default class TaxModuleService< >(TaxRate, generateForModels, entityNameToLinkableKeysMap) implements ITaxModuleService { + protected readonly container_: InjectedDependencies protected baseRepository_: DAL.RepositoryService protected taxRateService_: ModulesSdkTypes.InternalModuleService protected taxRegionService_: ModulesSdkTypes.InternalModuleService @@ -61,6 +69,7 @@ export default class TaxModuleService< // @ts-ignore super(...arguments) + this.container_ = arguments[0] this.baseRepository_ = baseRepository this.taxRateService_ = taxRateService this.taxRegionService_ = taxRegionService @@ -259,11 +268,19 @@ export default class TaxModuleService< sharedContext ) + const parentRegion = regions.find((r) => r.province_code === null) + if (!parentRegion) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + "No parent region found for country" + ) + } + const toReturn = await promiseAll( items.map(async (item) => { const regionIds = regions.map((r) => r.id) const rateQuery = this.getTaxRateQueryForItem(item, regionIds) - const rates = await this.taxRateService_.list( + const candidateRates = await this.taxRateService_.list( rateQuery, { relations: ["tax_region", "rules"], @@ -271,11 +288,71 @@ export default class TaxModuleService< sharedContext ) - return await this.getTaxRatesForItem(item, rates) + const applicableRates = await this.getTaxRatesForItem( + item, + candidateRates + ) + + return { + rates: applicableRates, + item, + } }) ) - return toReturn.flat() + const taxLines = await this.getTaxLinesFromProvider( + parentRegion.provider_id, + toReturn, + calculationContext + ) + + return taxLines + } + + private async getTaxLinesFromProvider( + rawProviderId: string | null, + items: ItemWithRates[], + calculationContext: TaxTypes.TaxCalculationContext + ) { + const providerId = rawProviderId || "system" + let provider: ITaxProvider + try { + provider = this.container_[`tp_${providerId}`] as ITaxProvider + } catch (err) { + throw new MedusaError( + MedusaError.Types.NOT_FOUND, + `Failed to resolve Tax Provider with id: ${providerId}. Make sure it's installed and configured in the Tax Module's options.` + ) + } + + const [itemLines, shippingLines] = items.reduce( + (acc, line) => { + if ("shipping_option_id" in line.item) { + acc[1].push({ + shipping_line: line.item, + rates: line.rates, + }) + } else { + acc[0].push({ + line_item: line.item, + rates: line.rates, + }) + } + return acc + }, + [[], []] as [ + TaxTypes.ItemTaxCalculationLine[], + TaxTypes.ShippingTaxCalculationLine[] + ] + ) + + const itemTaxLines = await provider.getTaxLines( + itemLines, + shippingLines, + calculationContext + ) + + return itemTaxLines } private async verifyProvinceToCountryMatch( @@ -316,7 +393,7 @@ export default class TaxModuleService< private async getTaxRatesForItem( item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO, rates: TTaxRate[] - ): Promise<(TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO)[]> { + ): Promise { if (!rates.length) { return [] } @@ -324,7 +401,7 @@ export default class TaxModuleService< const prioritizedRates = this.prioritizeRates(rates, item) const rate = prioritizedRates[0] - const ratesToReturn = [this.buildRateForItem(rate, item)] + const ratesToReturn = [rate] // If the rate can be combined we need to find the rate's // parent region and add that rate too. If not we can return now. @@ -339,37 +416,12 @@ export default class TaxModuleService< ) if (parentRate) { - ratesToReturn.push(this.buildRateForItem(parentRate, item)) + ratesToReturn.push(parentRate) } return ratesToReturn } - private buildRateForItem( - rate: TTaxRate, - item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO - ): TaxTypes.ItemTaxLineDTO | TaxTypes.ShippingTaxLineDTO { - const isShipping = "shipping_option_id" in item - const toReturn = { - rate_id: rate.id, - rate: rate.rate, - code: rate.code, - name: rate.name, - } - - if (isShipping) { - return { - ...toReturn, - shipping_line_id: item.id, - } - } - - return { - ...toReturn, - line_item_id: item.id, - } - } - private getTaxRateQueryForItem( item: TaxTypes.TaxableItemDTO | TaxTypes.TaxableShippingDTO, regionIds: string[] diff --git a/packages/types/src/tax/common.ts b/packages/types/src/tax/common.ts index cad17d34b5..7610386649 100644 --- a/packages/types/src/tax/common.ts +++ b/packages/types/src/tax/common.ts @@ -33,7 +33,7 @@ export interface TaxRateDTO { /** * The ID of the user that created the Tax Rate. */ - created_by: string + created_by: string | null } export interface FilterableTaxRateProps @@ -138,7 +138,7 @@ export interface TaxCalculationContext { } interface TaxLineDTO { - rate_id: string + rate_id?: string rate: number | null code: string | null name: string diff --git a/packages/types/src/tax/index.ts b/packages/types/src/tax/index.ts index 0c73656566..a8cf1df979 100644 --- a/packages/types/src/tax/index.ts +++ b/packages/types/src/tax/index.ts @@ -1,3 +1,4 @@ export * from "./common" export * from "./mutations" export * from "./service" +export * from "./provider" diff --git a/packages/types/src/tax/provider.ts b/packages/types/src/tax/provider.ts new file mode 100644 index 0000000000..6cc7904720 --- /dev/null +++ b/packages/types/src/tax/provider.ts @@ -0,0 +1,48 @@ +import { + ItemTaxLineDTO, + ShippingTaxLineDTO, + TaxCalculationContext, + TaxRateDTO, + TaxableItemDTO, + TaxableShippingDTO, +} from "./common" + +/** + * A shipping method and the tax rates configured to apply to the + * shipping method. + */ +export type ShippingTaxCalculationLine = { + /** + * The shipping method to calculate taxes for. + */ + shipping_line: TaxableShippingDTO + /** + * The rates applicable on the shipping method. + */ + rates: TaxRateDTO[] +} + +/** + * A line item and the tax rates configured to apply to the + * product contained in the line item. + */ +export type ItemTaxCalculationLine = { + /** + * The line item to calculate taxes for. + */ + line_item: TaxableItemDTO + /** + * The rates applicable on the item. + */ + rates: TaxRateDTO[] +} + +export interface ITaxProvider { + getIdentifier(): string + + getTaxLines( + itemLines: ItemTaxCalculationLine[], + shippingLines: ShippingTaxCalculationLine[], + context: TaxCalculationContext + ): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]> +} diff --git a/packages/types/src/tax/service.ts b/packages/types/src/tax/service.ts index 6567bef0a7..5f639a4b7c 100644 --- a/packages/types/src/tax/service.ts +++ b/packages/types/src/tax/service.ts @@ -95,7 +95,7 @@ export interface ITaxModuleService extends IModuleService { ): Promise getTaxLines( - item: (TaxableItemDTO | TaxableShippingDTO)[], + items: (TaxableItemDTO | TaxableShippingDTO)[], calculationContext: TaxCalculationContext, sharedContext?: Context ): Promise<(ItemTaxLineDTO | ShippingTaxLineDTO)[]>