From cbb7dd978775676b094cef6b1dbb8c0c526a4905 Mon Sep 17 00:00:00 2001 From: Oliver Windall Juhl <59018053+olivermrbl@users.noreply.github.com> Date: Wed, 7 Sep 2022 10:21:29 +0200 Subject: [PATCH] feat(medusa): Tax-inclusive pricing (#2131) * add feature flag for tax inclusive pricing * update db model for TIP * add migration * set featureflag column decorators * remove unused prop * update tests to reflect feature_flags as any array * fix types * reference key from featureFlag file * use feature flag key in models * fix copy paste mistake * unify spelling * Create gorgeous-experts-guess.md * feat(medusa): create/update endpoints of currency/region/price-lists/shipping-options should allow to pass includes_tax * test(integration): continue to add some integration test * test(integration): continue to add some integration test * test(unit): Fix region service tests * fix(medusa): API unit tests flags management * feat(medusa): Minor cleanup * style(medusa): Fix typo * fix(medusa): rebase * feat(medusa): Replace old tag with the new one * feat(medusa): revert flag * feat(medusa): Cleanup * feat(medusa): feedback * feat(medusa): Rename currency retrieve method * test(medudsa): fix unit tests * chore(medusa): fix oas * feat(medusa): ShippingMethod should include tax setting from parent option (#2021) * feat(medusa): Shipping method should includes tax from parent options * feat(medusa): Condition the includes tax flag to the availability of the feature and add some other tests * test(integration): Move cart/order ff test in separate files * fix: snapshots folder * fix(integration): snapshots * Create calm-baboons-sit.md * test(integration): file naming Co-authored-by: Carlos R. L. Rodrigues * Feat/tax inclusive pricing extend price selection strategy (#2087) * initial changes to price selection strategy including unit tests * typing for tax calculation * update types and remove region and currency from prices results * fix casing * include tax calculation in priceselectionstrategy * integration tests for tax inclusive pricing price calculations * fix build * include tax inclusive considerations when calculating tax fields for variants * include only "includes_tax" fields from currency and region joins * test to see errors in pipelines * conditionally join featureflagged fields * add "includes_tax" to price list factory * add tests for tax inclusive price list prices and currency prices * fix unit tests * refactor pricing array checks to expect arraycontaining * undo error handler * Feat/tax inclusive pricing flag on generated lineitems (#2108) * include tax inclusive pricing flag on generated lineitems * initial addition of tax inclusivity for lineitem service * add generate test to ensure that includes_tax is set when returned from price selection strategy * add integration test for generating lineitem including tax * add test for negative tax inclusion * add tests for mixed pricing * add negative test for setting tax exclusivity * restructure the setting of includes_tax on lineitems * fix: update cwd to be correct in cart test * feat(medusa): Line item totals calculations (#2123) * feat(medusa): Update totals and tax calculation way to calculate the totals * feat(medusa): remove region feetching from decorate total * feat(medusa): cleanup * test(medusa): fix tax calculation tests * comment * test(integration): cleanup * test(integration): cleanup * fix(medusa): return service missing await * feat(medusa): cleanup * feat(medusa): cleanup * test(integration): fix data * feat(medusa): improve tax calculation readability * test(medusa): improve tax calculation structure case Co-authored-by: Sebastian Rindom * Feat(medusa): tax inclusive pricing in shipping method tax (#2125) * initial implementation and test * include tax inclusive calculations for getting shipping options * remove inaccurate comment * remove console log * refactor how prices and taxes are set for shipping methods * fix integration tests * remove verbose flag * fix integration tests * remove console log * format util * use util in price service and tax strategy * fix faulty integration test * undo tax calculation strategy changes in favor or Carlos' pr * undo changes to tax calculation strategy tests * round tax amount * feat(medusa): cleanup calculate tax amount utils and its usage (#2136) * feat(medusa): Refund line totals calculation (#2139) Rely on the update of the following pr https://github.com/medusajs/medusa/pull/2136 **WIP Missing integration tests** **What** Update the totals calculation on the refund line to include the notion of tax inclusive **Test** - Update and add new tests around the refund Fixes CORE-482 * feat(medusa): Tax inclusive discounts calculations (#2137) **What** - Calculate line adjustments correctly taking into account the tax inclusivity - fix totals getLineItemTotals by adjusting the sub total with the original tax amount instead of the tax amount when the unit price includes the taxes **Tests** - The tests create a cart with a percentage discount of 15%, the cart includes 2 items mixing the tax inclusive and validate the items on the result cart as well as the totals on each item. I ve based my calculation validation based on what we have done + some articles around discount apply on price without taxes to validate the output., FIXES CORE-477 * Chore: shipping methods tax inclusive total (#2130) * chore: calculate tax inclusive shipping methods * chore: additional tests and check undefined tax_rate (#2157) * chore: additional tests and check undefined tax_rate * fix: naming + correct price type check * fix: remove price_includes_tax from type * fix: remove price_includes_tax from type Co-authored-by: Philip Korsholm Co-authored-by: adrien2p Co-authored-by: Carlos R. L. Rodrigues Co-authored-by: Philip Korsholm <88927411+pKorsholm@users.noreply.github.com> Co-authored-by: Sebastian Rindom Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com> --- .changeset/calm-baboons-sit.md | 5 + .changeset/gorgeous-experts-guess.md | 5 + .../admin/__snapshots__/currency.js.snap | 163 ++ .../admin/__snapshots__/store.js.snap | 7 +- .../api/__tests__/admin/currency.js | 89 ++ .../{ => order}/__snapshots__/order.js.snap | 0 .../admin/order/ff-tax-inclusive-pricing.js | 101 ++ .../api/__tests__/admin/{ => order}/order.js | 18 +- .../api/__tests__/admin/price-list.js | 214 ++- .../api/__tests__/admin/region.js | 104 ++ .../api/__tests__/admin/shipping-options.js | 128 +- .../api/__tests__/admin/store.js | 27 +- .../api/__tests__/admin/swaps.js | 4 +- .../__tests__/line-item-adjustments/index.js | 17 +- .../price-selection/tax-inclusive-prices.js | 1383 +++++++++++++++++ .../returns/ff-tax-inclusive-pricing.js | 276 ++++ .../{ => cart}/__snapshots__/cart.js.snap | 0 .../api/__tests__/store/{ => cart}/cart.js | 24 +- .../store/cart/ff-tax-inclusive-pricing.js | 538 +++++++ .../taxes/orders/ff-tax-inclusive-pricing.js | 106 ++ .../__tests__/taxes/{ => orders}/orders.js | 10 +- integration-tests/api/factories/index.ts | 1 + .../simple-custom-shipping-option-factory.ts | 36 + .../simple-discount-condition-factory.ts | 4 +- .../api/factories/simple-discount-factory.ts | 13 +- .../api/factories/simple-line-item-factory.ts | 2 + .../factories/simple-price-list-factory.ts | 2 + .../api/factories/simple-region-factory.ts | 4 + .../simple-shipping-method-factory.ts | 2 + .../simple-shipping-option-factory.ts | 20 +- .../currencies/__tests__/list-currencies.ts | 52 + .../currencies/__tests__/update-currency.ts | 48 + .../src/api/routes/admin/currencies/index.ts | 37 + .../admin/currencies/list-currencies.ts | 76 + .../admin/currencies/update-currency.ts | 52 + packages/medusa/src/api/routes/admin/index.js | 10 +- .../admin/price-lists/create-price-list.ts | 32 +- .../src/api/routes/admin/price-lists/index.ts | 8 +- .../admin/price-lists/update-price-list.ts | 28 +- .../api/routes/admin/regions/create-region.ts | 17 +- .../src/api/routes/admin/regions/index.ts | 8 +- .../api/routes/admin/regions/update-region.ts | 15 +- .../__tests__/add-product-batch.ts | 3 +- .../__tests__/create-sales-channel.ts | 3 +- .../__tests__/delete-products-batch.ts | 3 +- .../__tests__/delete-sales-channel.ts | 3 +- .../__tests__/get-sales-channel.ts | 3 +- .../__tests__/list-sales-channels.js | 3 +- .../__tests__/update-sales-channel.ts | 3 +- .../create-shipping-option.ts | 14 +- .../routes/admin/shipping-options/index.ts | 8 +- .../update-shipping-option.ts | 13 +- packages/medusa/src/helpers/test-request.js | 9 +- .../interfaces/price-selection-strategy.ts | 4 + .../medusa/src/loaders/feature-flags/index.ts | 8 +- .../feature-flags/tax-inclusive-pricing.ts | 10 + .../1659501357661-tax_inclusive_pricing.ts | 46 + packages/medusa/src/models/currency.ts | 8 + .../src/models/custom-shipping-option.ts | 3 + packages/medusa/src/models/line-item.ts | 8 + packages/medusa/src/models/money-amount.ts | 4 +- packages/medusa/src/models/price-list.ts | 9 + packages/medusa/src/models/region.ts | 8 + packages/medusa/src/models/shipping-method.ts | 8 + packages/medusa/src/models/shipping-option.ts | 8 + packages/medusa/src/repositories/currency.ts | 2 +- .../medusa/src/repositories/money-amount.ts | 42 +- .../medusa/src/services/__mocks__/cart.js | 19 + .../medusa/src/services/__mocks__/currency.js | 40 + .../medusa/src/services/__tests__/currency.ts | 60 + .../medusa/src/services/__tests__/discount.js | 31 +- .../src/services/__tests__/line-item.js | 778 ++++++---- .../src/services/__tests__/price-list.js | 19 +- .../medusa/src/services/__tests__/region.ts | 16 +- .../src/services/__tests__/shipping-option.js | 73 + .../medusa/src/services/__tests__/totals.js | 421 ++++- packages/medusa/src/services/cart.ts | 38 +- packages/medusa/src/services/claim.ts | 4 +- packages/medusa/src/services/currency.ts | 132 ++ packages/medusa/src/services/discount.ts | 28 +- packages/medusa/src/services/index.ts | 1 + .../src/services/line-item-adjustment.ts | 33 +- packages/medusa/src/services/line-item.ts | 39 +- packages/medusa/src/services/order.ts | 114 +- packages/medusa/src/services/price-list.ts | 50 +- packages/medusa/src/services/pricing.ts | 95 +- packages/medusa/src/services/region.ts | 46 +- packages/medusa/src/services/return.ts | 4 +- .../medusa/src/services/shipping-option.ts | 51 +- packages/medusa/src/services/store.ts | 4 +- packages/medusa/src/services/totals.ts | 353 +++-- .../strategies/__tests__/price-selection.js | 758 +++++++-- .../strategies/__tests__/tax-calculation.js | 365 +++-- .../medusa/src/strategies/price-selection.ts | 143 +- .../medusa/src/strategies/tax-calculation.ts | 76 +- packages/medusa/src/types/currency.ts | 3 + packages/medusa/src/types/price-list.ts | 9 +- packages/medusa/src/types/pricing.ts | 7 +- packages/medusa/src/types/region.ts | 5 +- packages/medusa/src/types/shipping-options.ts | 2 + .../__tests__/calculate-price-tax-amount.ts | 45 + .../src/utils/calculate-price-tax-amount.ts | 23 + .../src/utils/feature-flag-decorators.ts | 21 +- packages/medusa/src/utils/flag-router.ts | 2 +- packages/medusa/src/utils/index.ts | 1 + 105 files changed, 6788 insertions(+), 1040 deletions(-) create mode 100644 .changeset/calm-baboons-sit.md create mode 100644 .changeset/gorgeous-experts-guess.md create mode 100644 integration-tests/api/__tests__/admin/__snapshots__/currency.js.snap create mode 100644 integration-tests/api/__tests__/admin/currency.js rename integration-tests/api/__tests__/admin/{ => order}/__snapshots__/order.js.snap (100%) create mode 100644 integration-tests/api/__tests__/admin/order/ff-tax-inclusive-pricing.js rename integration-tests/api/__tests__/admin/{ => order}/order.js (99%) create mode 100644 integration-tests/api/__tests__/price-selection/tax-inclusive-prices.js create mode 100644 integration-tests/api/__tests__/returns/ff-tax-inclusive-pricing.js rename integration-tests/api/__tests__/store/{ => cart}/__snapshots__/cart.js.snap (100%) rename integration-tests/api/__tests__/store/{ => cart}/cart.js (98%) create mode 100644 integration-tests/api/__tests__/store/cart/ff-tax-inclusive-pricing.js create mode 100644 integration-tests/api/__tests__/taxes/orders/ff-tax-inclusive-pricing.js rename integration-tests/api/__tests__/taxes/{ => orders}/orders.js (96%) create mode 100644 integration-tests/api/factories/simple-custom-shipping-option-factory.ts create mode 100644 packages/medusa/src/api/routes/admin/currencies/__tests__/list-currencies.ts create mode 100644 packages/medusa/src/api/routes/admin/currencies/__tests__/update-currency.ts create mode 100644 packages/medusa/src/api/routes/admin/currencies/index.ts create mode 100644 packages/medusa/src/api/routes/admin/currencies/list-currencies.ts create mode 100644 packages/medusa/src/api/routes/admin/currencies/update-currency.ts create mode 100644 packages/medusa/src/loaders/feature-flags/tax-inclusive-pricing.ts create mode 100644 packages/medusa/src/migrations/1659501357661-tax_inclusive_pricing.ts create mode 100644 packages/medusa/src/services/__mocks__/currency.js create mode 100644 packages/medusa/src/services/__tests__/currency.ts create mode 100644 packages/medusa/src/services/currency.ts create mode 100644 packages/medusa/src/types/currency.ts create mode 100644 packages/medusa/src/utils/__tests__/calculate-price-tax-amount.ts create mode 100644 packages/medusa/src/utils/calculate-price-tax-amount.ts diff --git a/.changeset/calm-baboons-sit.md b/.changeset/calm-baboons-sit.md new file mode 100644 index 0000000000..cb5c3977fd --- /dev/null +++ b/.changeset/calm-baboons-sit.md @@ -0,0 +1,5 @@ +--- +"@medusajs/medusa": patch +--- + +Pass down the includes_tax to the shipping method from the shipping option diff --git a/.changeset/gorgeous-experts-guess.md b/.changeset/gorgeous-experts-guess.md new file mode 100644 index 0000000000..0cbcd53027 --- /dev/null +++ b/.changeset/gorgeous-experts-guess.md @@ -0,0 +1,5 @@ +--- +"@medusajs/medusa": patch +--- + +Extend models Currency, Region, PriceList, ShippingOption, LineItem, ShippingMethod with tax inclusive flag diff --git a/integration-tests/api/__tests__/admin/__snapshots__/currency.js.snap b/integration-tests/api/__tests__/admin/__snapshots__/currency.js.snap new file mode 100644 index 0000000000..a559d18dd2 --- /dev/null +++ b/integration-tests/api/__tests__/admin/__snapshots__/currency.js.snap @@ -0,0 +1,163 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`/admin/currencies GET /admin/currencies should retrieve the currencies 1`] = ` +Object { + "count": 120, + "currencies": Array [ + Object { + "code": "aed", + "includes_tax": false, + "name": "United Arab Emirates Dirham", + "symbol": "AED", + "symbol_native": "د.إ.‏", + }, + Object { + "code": "afn", + "includes_tax": false, + "name": "Afghan Afghani", + "symbol": "Af", + "symbol_native": "؋", + }, + Object { + "code": "all", + "includes_tax": false, + "name": "Albanian Lek", + "symbol": "ALL", + "symbol_native": "Lek", + }, + Object { + "code": "amd", + "includes_tax": false, + "name": "Armenian Dram", + "symbol": "AMD", + "symbol_native": "դր.", + }, + Object { + "code": "ars", + "includes_tax": false, + "name": "Argentine Peso", + "symbol": "AR$", + "symbol_native": "$", + }, + Object { + "code": "aud", + "includes_tax": false, + "name": "Australian Dollar", + "symbol": "AU$", + "symbol_native": "$", + }, + Object { + "code": "azn", + "includes_tax": false, + "name": "Azerbaijani Manat", + "symbol": "man.", + "symbol_native": "ман.", + }, + Object { + "code": "bam", + "includes_tax": false, + "name": "Bosnia-Herzegovina Convertible Mark", + "symbol": "KM", + "symbol_native": "KM", + }, + Object { + "code": "bdt", + "includes_tax": false, + "name": "Bangladeshi Taka", + "symbol": "Tk", + "symbol_native": "৳", + }, + Object { + "code": "bgn", + "includes_tax": false, + "name": "Bulgarian Lev", + "symbol": "BGN", + "symbol_native": "лв.", + }, + Object { + "code": "bhd", + "includes_tax": false, + "name": "Bahraini Dinar", + "symbol": "BD", + "symbol_native": "د.ب.‏", + }, + Object { + "code": "bif", + "includes_tax": false, + "name": "Burundian Franc", + "symbol": "FBu", + "symbol_native": "FBu", + }, + Object { + "code": "bnd", + "includes_tax": false, + "name": "Brunei Dollar", + "symbol": "BN$", + "symbol_native": "$", + }, + Object { + "code": "bob", + "includes_tax": false, + "name": "Bolivian Boliviano", + "symbol": "Bs", + "symbol_native": "Bs", + }, + Object { + "code": "brl", + "includes_tax": false, + "name": "Brazilian Real", + "symbol": "R$", + "symbol_native": "R$", + }, + Object { + "code": "bwp", + "includes_tax": false, + "name": "Botswanan Pula", + "symbol": "BWP", + "symbol_native": "P", + }, + Object { + "code": "byn", + "includes_tax": false, + "name": "Belarusian Ruble", + "symbol": "Br", + "symbol_native": "руб.", + }, + Object { + "code": "bzd", + "includes_tax": false, + "name": "Belize Dollar", + "symbol": "BZ$", + "symbol_native": "$", + }, + Object { + "code": "cad", + "includes_tax": false, + "name": "Canadian Dollar", + "symbol": "CA$", + "symbol_native": "$", + }, + Object { + "code": "cdf", + "includes_tax": false, + "name": "Congolese Franc", + "symbol": "CDF", + "symbol_native": "FrCD", + }, + ], + "limit": 20, + "offset": 0, +} +`; + +exports[`/admin/currencies POST /admin/currencies/:code should update currency includes_tax 1`] = ` +Object { + "currency": Object { + "code": "aed", + "includes_tax": true, + "name": "United Arab Emirates Dirham", + "symbol": "AED", + "symbol_native": "د.إ.‏", + }, +} +`; diff --git a/integration-tests/api/__tests__/admin/__snapshots__/store.js.snap b/integration-tests/api/__tests__/admin/__snapshots__/store.js.snap index 5155809a0c..9ff29ca7db 100644 --- a/integration-tests/api/__tests__/admin/__snapshots__/store.js.snap +++ b/integration-tests/api/__tests__/admin/__snapshots__/store.js.snap @@ -121,12 +121,7 @@ Object { "symbol_native": "$", }, "default_currency_code": "usd", - "feature_flags": Array [ - Object { - "key": "sales_channels", - "value": false, - }, - ], + "feature_flags": Any, "fulfillment_providers": Array [ Object { "id": "test-ful", diff --git a/integration-tests/api/__tests__/admin/currency.js b/integration-tests/api/__tests__/admin/currency.js new file mode 100644 index 0000000000..5587b9d6a9 --- /dev/null +++ b/integration-tests/api/__tests__/admin/currency.js @@ -0,0 +1,89 @@ +const path = require("path") +const startServerWithEnvironment = + require("../../../helpers/start-server-with-environment").default +const { useApi } = require("../../../helpers/use-api") +const { useDb } = require("../../../helpers/use-db") +const adminSeeder = require("../../helpers/admin-seeder"); + +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} + +jest.setTimeout(30000) +describe("/admin/currencies", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("GET /admin/currencies", function () { + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + } catch (e) { + console.error(e) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("should retrieve the currencies", async () => { + const api = useApi() + const response = await api.get( + `/admin/currencies?order=code`, + adminReqConfig + ) + + expect(response.data).toMatchSnapshot() + }) + }); + + describe("POST /admin/currencies/:code", function () { + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + } catch (e) { + console.error(e) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("should update currency includes_tax", async () => { + const api = useApi() + const response = await api.post( + `/admin/currencies/aed`, + { + includes_tax: true + }, + adminReqConfig + ) + + expect(response.data).toMatchSnapshot() + }) + }); +}) diff --git a/integration-tests/api/__tests__/admin/__snapshots__/order.js.snap b/integration-tests/api/__tests__/admin/order/__snapshots__/order.js.snap similarity index 100% rename from integration-tests/api/__tests__/admin/__snapshots__/order.js.snap rename to integration-tests/api/__tests__/admin/order/__snapshots__/order.js.snap diff --git a/integration-tests/api/__tests__/admin/order/ff-tax-inclusive-pricing.js b/integration-tests/api/__tests__/admin/order/ff-tax-inclusive-pricing.js new file mode 100644 index 0000000000..94612d5d2f --- /dev/null +++ b/integration-tests/api/__tests__/admin/order/ff-tax-inclusive-pricing.js @@ -0,0 +1,101 @@ +const path = require("path") + +const startServerWithEnvironment = + require("../../../../helpers/start-server-with-environment").default +const { useApi } = require("../../../../helpers/use-api") +const { useDb } = require("../../../../helpers/use-db") + +const adminSeeder = require("../../../helpers/admin-seeder") + +const { + simpleRegionFactory, + simpleShippingOptionFactory, + simpleOrderFactory +} = require("../../../factories"); + +jest.setTimeout(30000) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /admin/orders", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /admin/orders/:id/shipping-methods", () => { + let includesTaxShippingOption + let order + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + const shippingAddress = { + id: "test-shipping-address", + first_name: "lebron", + country_code: "us", + } + const region = await simpleRegionFactory(dbConnection, { + id: "test-region" + }) + order = await simpleOrderFactory(dbConnection, { + id: "test-order", + region: region.id, + shipping_address: shippingAddress, + currency_code: "usd", + }) + includesTaxShippingOption = await simpleShippingOptionFactory(dbConnection, { + includes_tax: true, + region_id: region.id + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async() => { + const db = useDb() + return await db.teardown() + }) + + it("should add a normal shipping method to the order", async () => { + const api = useApi() + + const orderWithShippingMethodRes = await api.post( + `/admin/orders/${order.id}/shipping-methods`, + { + option_id: includesTaxShippingOption.id, + price: 10, + }, + { + headers: { + Authorization: "Bearer test_token", + }, + } + ) + + expect(orderWithShippingMethodRes.status).toEqual(200) + expect(orderWithShippingMethodRes.data.order.shipping_methods) + .toEqual(expect.arrayContaining([ + expect.objectContaining({ + shipping_option_id: includesTaxShippingOption.id, + includes_tax: true, + }) + ])) + }) + }) +}) diff --git a/integration-tests/api/__tests__/admin/order.js b/integration-tests/api/__tests__/admin/order/order.js similarity index 99% rename from integration-tests/api/__tests__/admin/order.js rename to integration-tests/api/__tests__/admin/order/order.js index 52ae78a2f5..cd1b48dd5d 100644 --- a/integration-tests/api/__tests__/admin/order.js +++ b/integration-tests/api/__tests__/admin/order/order.js @@ -7,21 +7,21 @@ const { ShippingMethod, } = require("@medusajs/medusa") -const setupServer = require("../../../helpers/setup-server") -const { useApi } = require("../../../helpers/use-api") -const { initDb, useDb } = require("../../../helpers/use-db") +const setupServer = require("../../../../helpers/setup-server") +const { useApi } = require("../../../../helpers/use-api") +const { initDb, useDb } = require("../../../../helpers/use-db") -const orderSeeder = require("../../helpers/order-seeder") -const swapSeeder = require("../../helpers/swap-seeder") -const adminSeeder = require("../../helpers/admin-seeder") -const claimSeeder = require("../../helpers/claim-seeder") +const orderSeeder = require("../../../helpers/order-seeder") +const swapSeeder = require("../../../helpers/swap-seeder") +const adminSeeder = require("../../../helpers/admin-seeder") +const claimSeeder = require("../../../helpers/claim-seeder") const { expectPostCallToReturn, expectAllPostCallsToReturn, callGet, partial, -} = require("../../helpers/call-helpers") +} = require("../../../helpers/call-helpers") jest.setTimeout(30000) @@ -30,7 +30,7 @@ describe("/admin/orders", () => { let dbConnection beforeAll(async () => { - const cwd = path.resolve(path.join(__dirname, "..", "..")) + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) dbConnection = await initDb({ cwd }) medusaProcess = await setupServer({ cwd }) }) diff --git a/integration-tests/api/__tests__/admin/price-list.js b/integration-tests/api/__tests__/admin/price-list.js index e065e6222e..e9b953c917 100644 --- a/integration-tests/api/__tests__/admin/price-list.js +++ b/integration-tests/api/__tests__/admin/price-list.js @@ -1,7 +1,7 @@ -const { PriceList, CustomerGroup } = require("@medusajs/medusa") const path = require("path") const setupServer = require("../../../helpers/setup-server") +const startServerWithEnvironment = require("../../../helpers/start-server-with-environment").default const { useApi } = require("../../../helpers/use-api") const { useDb, initDb } = require("../../../helpers/use-db") @@ -9,14 +9,17 @@ const { simpleProductFactory, simplePriceListFactory, } = require("../../factories") -const { - simpleCustomerGroupFactory, -} = require("../../factories/simple-customer-group-factory") const adminSeeder = require("../../helpers/admin-seeder") const customerSeeder = require("../../helpers/customer-seeder") const priceListSeeder = require("../../helpers/price-list-seeder") const productSeeder = require("../../helpers/product-seeder") +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} + jest.setTimeout(30000) describe("/admin/price-lists", () => { @@ -1141,54 +1144,52 @@ describe("/admin/price-lists", () => { expect(response.status).toEqual(200) expect(response.data.count).toEqual(2) expect(response.data.products).toHaveLength(2) - expect(response.data.products).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - id: "test-prod-1", - variants: [ - expect.objectContaining({ - id: "test-variant-1", - prices: [ - expect.objectContaining({ currency_code: "usd", amount: 100 }), - expect.objectContaining({ - currency_code: "usd", - amount: 150, - price_list_id: "test-list", - }), - ], - }), - expect.objectContaining({ - id: "test-variant-2", - prices: [ - expect.objectContaining({ currency_code: "usd", amount: 100 }), - ], - }), - ], - }), - expect.objectContaining({ - id: "test-prod-2", - variants: [ - expect.objectContaining({ - id: "test-variant-3", - prices: [ - expect.objectContaining({ currency_code: "usd", amount: 100 }), - ], - }), - expect.objectContaining({ - id: "test-variant-4", - prices: [ - expect.objectContaining({ currency_code: "usd", amount: 100 }), - expect.objectContaining({ - currency_code: "usd", - amount: 150, - price_list_id: "test-list", - }), - ], - }), - ], - }), - ]) - ) + expect(response.data.products).toEqual([ + expect.objectContaining({ + id: "test-prod-1", + variants: expect.arrayContaining([ + expect.objectContaining({ + id: "test-variant-1", + prices: expect.arrayContaining([ + expect.objectContaining({ currency_code: "usd", amount: 100 }), + expect.objectContaining({ + currency_code: "usd", + amount: 150, + price_list_id: "test-list", + }), + ],) + }), + expect.objectContaining({ + id: "test-variant-2", + prices: expect.arrayContaining([ + expect.objectContaining({ currency_code: "usd", amount: 100 }), + ]), + }), + ]), + }), + expect.objectContaining({ + id: "test-prod-2", + variants: expect.arrayContaining([ + expect.objectContaining({ + id: "test-variant-3", + prices: expect.arrayContaining([ + expect.objectContaining({ currency_code: "usd", amount: 100 }), + ]), + }), + expect.objectContaining({ + id: "test-variant-4", + prices: expect.arrayContaining([ + expect.objectContaining({ currency_code: "usd", amount: 100 }), + expect.objectContaining({ + currency_code: "usd", + amount: 150, + price_list_id: "test-list", + }), + ]), + }), + ]), + }), + ]) }) it("lists only product 2", async () => { @@ -1387,3 +1388,112 @@ describe("/admin/price-lists", () => { }) }) }) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /admin/price-lists", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /admin/price-list", () => { + const priceListIncludesTaxId = "price-list-1-includes-tax" + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await customerSeeder(dbConnection) + await productSeeder(dbConnection) + await simplePriceListFactory(dbConnection, { + id: priceListIncludesTaxId, + }) + } catch (err) { + console.log(err) + throw err + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("should creates a price list that includes tax", async () => { + const api = useApi() + + const payload = { + name: "VIP Summer sale", + description: "Summer sale for VIP customers. 25% off selected items.", + type: "sale", + status: "active", + starts_at: "2022-07-01T00:00:00.000Z", + ends_at: "2022-07-31T00:00:00.000Z", + customer_groups: [ + { + id: "customer-group-1", + }, + ], + prices: [ + { + amount: 85, + currency_code: "usd", + variant_id: "test-variant", + }, + ], + includes_tax: true, + } + + const response = await api + .post("/admin/price-lists", payload, adminReqConfig) + .catch((err) => { + console.warn(err.response.data) + }) + + expect(response.status).toEqual(200) + expect(response.data.price_list).toEqual( + expect.objectContaining({ + id: expect.any(String), + includes_tax: true, + }) + ) + }) + + it("should update a price list that include_tax", async () => { + const api = useApi() + + let response = await api + .get(`/admin/price-lists/${priceListIncludesTaxId}`, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.data.price_list.includes_tax).toBe(false) + + response = await api + .post( + `/admin/price-lists/${priceListIncludesTaxId}`, + { includes_tax: true, }, + adminReqConfig + ).catch((err) => { + console.log(err) + }) + + expect(response.data.price_list.includes_tax).toBe(true) + }) + }) +}) diff --git a/integration-tests/api/__tests__/admin/region.js b/integration-tests/api/__tests__/admin/region.js index f6fd86e5bc..a700381fb0 100644 --- a/integration-tests/api/__tests__/admin/region.js +++ b/integration-tests/api/__tests__/admin/region.js @@ -2,9 +2,17 @@ const path = require("path") const { Region } = require("@medusajs/medusa") const setupServer = require("../../../helpers/setup-server") +const startServerWithEnvironment = require("../../../helpers/start-server-with-environment").default const { useApi } = require("../../../helpers/use-api") const { initDb, useDb } = require("../../../helpers/use-db") const adminSeeder = require("../../helpers/admin-seeder") +const { simpleRegionFactory } = require("../../factories"); + +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} jest.setTimeout(30000) @@ -286,3 +294,99 @@ describe("/admin/regions", () => { }) }) }) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /admin/regions", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /admin/regions/:id", () => { + const region1TaxInclusiveId = "region-1-tax-inclusive" + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await simpleRegionFactory(dbConnection, { + id: region1TaxInclusiveId, + countries: ["fr"], + }) + } catch (err) { + console.log(err) + throw err + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("should allow to create a region that includes tax", async function () { + const api = useApi() + + const payload = { + name: "region-including-taxes", + currency_code: "usd", + tax_rate: 0, + payment_providers: ["test-pay"], + fulfillment_providers: ["test-ful"], + countries: ["us"], + includes_tax: true, + } + + let response = await api + .post(`/admin/regions`, payload, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.data.region).toEqual( + expect.objectContaining({ + id: expect.any(String), + includes_tax: true, + name: "region-including-taxes", + }) + ) + }); + + it("should allow to update a region that includes tax", async function () { + const api = useApi() + let response = await api + .get(`/admin/regions/${region1TaxInclusiveId}`, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.data.region.includes_tax).toBe(false) + + response = await api.post( + `/admin/regions/${region1TaxInclusiveId}`, + { + includes_tax: true, + }, + adminReqConfig, + ).catch((err) => { + console.log(err) + }) + + expect(response.data.region.includes_tax).toBe(true) + }); + }) +}) \ No newline at end of file diff --git a/integration-tests/api/__tests__/admin/shipping-options.js b/integration-tests/api/__tests__/admin/shipping-options.js index 9872dc8582..92eaf630f6 100644 --- a/integration-tests/api/__tests__/admin/shipping-options.js +++ b/integration-tests/api/__tests__/admin/shipping-options.js @@ -1,16 +1,21 @@ const path = require("path") const { - Region, ShippingProfile, - ShippingOption, - ShippingOptionRequirement, } = require("@medusajs/medusa") const setupServer = require("../../../helpers/setup-server") +const startServerWithEnvironment = require("../../../helpers/start-server-with-environment").default const { useApi } = require("../../../helpers/use-api") const { initDb, useDb } = require("../../../helpers/use-db") const adminSeeder = require("../../helpers/admin-seeder") const shippingOptionSeeder = require("../../helpers/shipping-option-seeder") +const { simpleShippingOptionFactory, simpleRegionFactory } = require("../../factories") + +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} jest.setTimeout(30000) @@ -460,3 +465,120 @@ describe("/admin/shipping-options", () => { }) }) }) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /admin/shipping-options", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /admin/shipping-options", () => { + const shippingOptionIncludesTaxId = "shipping-option-1-includes-tax" + let region + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + region = await simpleRegionFactory(dbConnection, { + id: "region", + countries: ["fr"], + }) + await simpleShippingOptionFactory(dbConnection, { + id: shippingOptionIncludesTaxId, + region_id: region.id, + }) + } catch (err) { + console.log(err) + throw err + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("should creates a shipping option that includes tax", async () => { + const api = useApi() + + const defaultProfile = await dbConnection.manager.findOne(ShippingProfile, { + type: "default", + }) + + const payload = { + name: "Test option", + amount: 100, + price_type: "flat_rate", + region_id: region.id, + provider_id: "test-ful", + data: {}, + profile_id: defaultProfile.id, + includes_tax: true, + } + + const response = await api + .post("/admin/shipping-options", payload, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.status).toEqual(200) + expect(response.data.shipping_option).toEqual( + expect.objectContaining({ + id: expect.any(String), + includes_tax: true, + }) + ) + }) + + it("should update a shipping option that include_tax", async () => { + const api = useApi() + + let response = await api + .get(`/admin/shipping-options/${shippingOptionIncludesTaxId}`, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.data.shipping_option.includes_tax).toBe(false) + + const payload = { + requirements: [ + { + type: "min_subtotal", + amount: 1, + }, + { + type: "max_subtotal", + amount: 2, + }, + ], + includes_tax: true, + } + + response = await api + .post(`/admin/shipping-options/${shippingOptionIncludesTaxId}`, payload, adminReqConfig) + .catch((err) => { + console.log(err) + }) + + expect(response.data.shipping_option.includes_tax).toBe(true) + }) + }) +}) diff --git a/integration-tests/api/__tests__/admin/store.js b/integration-tests/api/__tests__/admin/store.js index 65097c68d6..2b062d33d6 100644 --- a/integration-tests/api/__tests__/admin/store.js +++ b/integration-tests/api/__tests__/admin/store.js @@ -52,12 +52,7 @@ describe("/admin/store", () => { code: "usd", }, ], - feature_flags: [ - { - key: "sales_channels", - value: false, - }, - ], + feature_flags: expect.any(Array), default_currency_code: "usd", created_at: expect.any(String), updated_at: expect.any(String), @@ -133,15 +128,17 @@ describe("/admin/store", () => { it("successfully updates default currency code", async () => { const api = useApi() - const response = await api.post( - "/admin/store", - { - default_currency_code: "dkk", - }, - { - headers: { Authorization: "Bearer test_token " }, - } - ) + const response = await api + .post( + "/admin/store", + { + default_currency_code: "dkk", + }, + { + headers: { Authorization: "Bearer test_token " }, + } + ) + .catch((err) => console.log(err)) expect(response.status).toEqual(200) expect(response.data.store).toMatchSnapshot({ diff --git a/integration-tests/api/__tests__/admin/swaps.js b/integration-tests/api/__tests__/admin/swaps.js index 2fa385b782..be1a8ec594 100644 --- a/integration-tests/api/__tests__/admin/swaps.js +++ b/integration-tests/api/__tests__/admin/swaps.js @@ -250,12 +250,10 @@ describe("/admin/swaps", () => { data: {}, }) await api.post("/store/carts/cart-test/payment-sessions") - const TEST = await api.post("/store/carts/cart-test/payment-session", { + await api.post("/store/carts/cart-test/payment-session", { provider_id: "test-pay", }) - console.log("Testing, ", TEST.data.cart.items[0]) - // ********* COMPLETE CART ********* const completedOrder = await api.post("/store/carts/cart-test/complete") diff --git a/integration-tests/api/__tests__/line-item-adjustments/index.js b/integration-tests/api/__tests__/line-item-adjustments/index.js index c48cffbef5..1a65d84316 100644 --- a/integration-tests/api/__tests__/line-item-adjustments/index.js +++ b/integration-tests/api/__tests__/line-item-adjustments/index.js @@ -33,9 +33,9 @@ describe("Line Item Adjustments", () => { }) describe("Tests database constraints", () => { - let cart, - discount, - lineItemId = "line-test" + let cart + let discount + const lineItemId = "line-test" beforeEach(async () => { await cartSeeder(dbConnection) discount = await simpleDiscountFactory(dbConnection, { @@ -113,7 +113,7 @@ describe("Line Item Adjustments", () => { }) } - expect(createLineItemWithAdjustment()).resolves.toEqual( + await expect(createLineItemWithAdjustment()).resolves.toEqual( expect.anything() ) }) @@ -131,7 +131,7 @@ describe("Line Item Adjustments", () => { }) } - expect(createAdjustmentNullDiscount()).resolves.toEqual( + await expect(createAdjustmentNullDiscount()).resolves.toEqual( expect.anything() ) }) @@ -155,7 +155,8 @@ describe("Line Item Adjustments", () => { discount_id: null, }) } - expect(createAdjustmentsNullDiscount()).resolves.toEqual( + + await expect(createAdjustmentsNullDiscount()).resolves.toEqual( expect.anything() ) }) @@ -184,7 +185,7 @@ describe("Line Item Adjustments", () => { }) } - expect(createAdjustment()).resolves.toEqual(expect.anything()) + await expect(createAdjustment()).resolves.toEqual(expect.anything()) }) }) @@ -199,7 +200,7 @@ describe("Line Item Adjustments", () => { discount_id: discount.id, }) - expect(createDuplicateAdjustment()).rejects.toEqual( + await expect(createDuplicateAdjustment()).rejects.toEqual( expect.objectContaining({ code: "23505" }) ) }) diff --git a/integration-tests/api/__tests__/price-selection/tax-inclusive-prices.js b/integration-tests/api/__tests__/price-selection/tax-inclusive-prices.js new file mode 100644 index 0000000000..7c83bf7f9c --- /dev/null +++ b/integration-tests/api/__tests__/price-selection/tax-inclusive-prices.js @@ -0,0 +1,1383 @@ +const { Currency, Region } = require("@medusajs/medusa") +const path = require("path") + +const startServerWithEnvironment = + require("../../../helpers/start-server-with-environment").default +const { useApi } = require("../../../helpers/use-api") +const { useDb } = require("../../../helpers/use-db") +const { + simpleProductFactory, + simpleRegionFactory, + simplePriceListFactory, + simpleProductTaxRateFactory, + simpleShippingOptionFactory, + simpleShippingTaxRateFactory, +} = require("../../factories") + +const adminSeeder = require("../../helpers/admin-seeder") +const promotionsSeeder = require("../../helpers/price-selection-seeder") + +jest.setTimeout(30000) + +describe("tax inclusive prices", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, conn] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = conn // await initDb({ cwd }) + medusaProcess = process // await setupServer({ cwd }) + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("Create line item with tax inclusive pricing", () => { + let region + let productId + + beforeEach(async () => { + region = await simpleRegionFactory(dbConnection, { + includes_tax: true, + }) + + const product = await simpleProductFactory(dbConnection, { + variants: [{ id: "var_1", prices: [{ currency: "usd", amount: 100 }] }], + }) + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + prices: [ + { + variant_id: "var_1", + amount: 110, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("creates a line item with tax inclusive pricing when variant is tax inclusive", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${region.id}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const response = await api.post("/store/carts", { + region_id: region.id, + }) + + const lineItemResp = await api.post( + `/store/carts/${response.data.cart.id}/line-items`, + { + variant_id: variant.id, + quantity: 2, + } + ) + + expect(lineItemResp.data.cart.items).toEqual([ + expect.objectContaining({ includes_tax: true }), + ]) + }) + + it("creates a line item without tax inclusive pricing when variant is not tax inclusive", async () => { + await dbConnection.manager.save(Region, { + ...region, + includes_tax: false, + }) + + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${region.id}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const response = await api.post("/store/carts", { + region_id: region.id, + }) + + const lineItemResp = await api.post( + `/store/carts/${response.data.cart.id}/line-items`, + { + variant_id: variant.id, + quantity: 2, + } + ) + + expect(lineItemResp.data.cart.items).toEqual([ + expect.objectContaining({ includes_tax: false }), + ]) + }) + + it("creates a line item without tax inclusive pricing with mixed variant pricing", async () => { + await dbConnection.manager.save(Region, { + ...region, + includes_tax: false, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + includes_tax: true, + prices: [ + { + variant_id: "var_1", + amount: 130, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${region.id}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const response = await api.post("/store/carts", { + region_id: region.id, + }) + + const lineItemResp = await api.post( + `/store/carts/${response.data.cart.id}/line-items`, + { + variant_id: variant.id, + quantity: 2, + } + ) + + expect(lineItemResp.data.cart.items).toEqual([ + expect.objectContaining({ includes_tax: false }), + ]) + }) + + it("creates a line item with tax inclusive pricing with mixed variant pricing", async () => { + await dbConnection.manager.save(Region, { + ...region, + includes_tax: false, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + includes_tax: true, + prices: [ + { + variant_id: "var_1", + amount: 110, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${region.id}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const response = await api.post("/store/carts", { + region_id: region.id, + }) + + const lineItemResp = await api.post( + `/store/carts/${response.data.cart.id}/line-items`, + { + variant_id: variant.id, + quantity: 2, + } + ) + + expect(lineItemResp.data.cart.items).toEqual([ + expect.objectContaining({ includes_tax: true }), + ]) + }) + }) + + describe("region tax inclusive", () => { + describe("getting product with mixed prices preferring tax inclusive prices", () => { + let regionId + let productId + + beforeEach(async () => { + const region = await simpleRegionFactory(dbConnection, { + includes_tax: true, + }) + + const product = await simpleProductFactory(dbConnection, { + variants: [ + { id: "var_1", prices: [{ currency: "usd", amount: 100 }] }, + ], + }) + + regionId = region.id + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + prices: [ + { + variant_id: "var_1", + amount: 110, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("test", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${regionId}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 100, + calculated_price: 110, + calculated_price_type: "sale", + original_price_includes_tax: false, + calculated_price_includes_tax: true, + calculated_price_incl_tax: 110, + calculated_tax: 22, + original_price_incl_tax: 125, + original_tax: 25, + }) + ) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + amount: 100, + currency_code: "usd", + price_list_id: null, + }), + expect.objectContaining({ + amount: 110, + currency_code: "usd", + price_list_id: expect.any(String), + }), + ]) + ) + }) + }) + + describe("getting product with mixed prices preferring tax exclusive prices", () => { + let regionId + let productId + + beforeEach(async () => { + const region = await simpleRegionFactory(dbConnection, { + includes_tax: true, + }) + + const product = await simpleProductFactory(dbConnection, { + variants: [ + { id: "var_1", prices: [{ currency: "usd", amount: 100 }] }, + ], + }) + + regionId = region.id + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + prices: [ + { + variant_id: "var_1", + amount: 130, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("test", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${regionId}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 100, + calculated_price: 100, + calculated_price_type: "default", + original_price_includes_tax: false, + calculated_price_includes_tax: false, + original_tax: 25, + calculated_tax: 25, + original_price_incl_tax: 125, + calculated_price_incl_tax: 125, + }) + ) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + amount: 100, + currency_code: "usd", + price_list_id: null, + }), + expect.objectContaining({ + amount: 130, + currency_code: "usd", + price_list_id: expect.any(String), + }), + ]) + ) + }) + }) + }) + + describe("currency tax inclusive", () => { + describe("getting product with mixed prices preferring tax inclusive prices", () => { + let regionId + let productId + + beforeEach(async () => { + const manager = dbConnection.manager + + const currency = await manager.findOne(Currency, { + where: { code: "usd" }, + }) + + currency.includes_tax = true + + await manager.save(currency) + + const region = await simpleRegionFactory(dbConnection, {}) + + const product = await simpleProductFactory(dbConnection, { + variants: [ + { id: "var_1", prices: [{ currency: "usd", amount: 110 }] }, + ], + }) + + regionId = region.id + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + prices: [ + { + variant_id: "var_1", + amount: 100, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + + const currency = await dbConnection.manager.findOne(Currency, { + where: { code: "usd" }, + }) + + currency.includes_tax = false + + await dbConnection.manager.save(currency) + + await db.teardown() + }) + + it("test", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${regionId}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 110, + calculated_price: 100, + calculated_price_type: "sale", + original_price_includes_tax: true, + calculated_price_includes_tax: true, + calculated_price_incl_tax: 100, + calculated_tax: 20, + original_price_incl_tax: 110, + original_tax: 22, + }) + ) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + amount: 110, + currency_code: "usd", + price_list_id: null, + }), + expect.objectContaining({ + amount: 100, + currency_code: "usd", + price_list_id: expect.any(String), + }), + ]) + ) + }) + }) + }) + + describe("pricelist tax inclusive", () => { + describe("getting product with mixed prices preferring tax inclusive prices", () => { + let regionId + let productId + + beforeEach(async () => { + const region = await simpleRegionFactory(dbConnection, {}) + + const product = await simpleProductFactory(dbConnection, { + variants: [ + { id: "var_1", prices: [{ currency: "usd", amount: 100 }] }, + ], + }) + + regionId = region.id + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + type: "sale", + includes_tax: true, + prices: [ + { + variant_id: "var_1", + amount: 110, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("test", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${regionId}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 100, + calculated_price: 110, + calculated_price_type: "sale", + original_price_includes_tax: false, + calculated_price_includes_tax: true, + calculated_price_incl_tax: 110, + calculated_tax: 22, + original_price_incl_tax: 125, + original_tax: 25, + }) + ) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + amount: 100, + currency_code: "usd", + price_list_id: null, + }), + expect.objectContaining({ + amount: 110, + currency_code: "usd", + price_list_id: expect.any(String), + }), + ]) + ) + }) + }) + + describe("getting product with mixed prices preferring tax exclusive prices", () => { + let regionId + let productId + + beforeEach(async () => { + const region = await simpleRegionFactory(dbConnection, {}) + + const product = await simpleProductFactory(dbConnection, { + variants: [ + { id: "var_1", prices: [{ currency: "usd", amount: 100 }] }, + ], + }) + + regionId = region.id + productId = product.id + + await simpleProductTaxRateFactory(dbConnection, { + product_id: product.id, + rate: { + region_id: region.id, + rate: 25, + }, + }) + + await simplePriceListFactory(dbConnection, { + status: "active", + includes_tax: true, + prices: [ + { + variant_id: "var_1", + amount: 130, + currency_code: "usd", + region_id: region.id, + }, + ], + }) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("test", async () => { + const api = useApi() + const res = await api + .get(`/store/products/${productId}?region_id=${regionId}`) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 100, + calculated_price: 100, + calculated_price_type: "default", + original_price_includes_tax: false, + calculated_price_includes_tax: false, + original_tax: 25, + calculated_tax: 25, + original_price_incl_tax: 125, + calculated_price_incl_tax: 125, + }) + ) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + amount: 100, + currency_code: "usd", + price_list_id: null, + }), + expect.objectContaining({ + amount: 130, + currency_code: "usd", + price_list_id: expect.any(String), + }), + ]) + ) + }) + }) + }) + + describe("tax inclusive shipping options", () => { + beforeAll(async () => { + await adminSeeder(dbConnection) + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("admin gets correct shipping prices", async () => { + const api = useApi() + + const region = await simpleRegionFactory(dbConnection, { + tax_rate: 25, + }) + const so = await simpleShippingOptionFactory(dbConnection, { + region_id: region.id, + price: 100, + includes_tax: true, + }) + await simpleShippingTaxRateFactory(dbConnection, { + shipping_option_id: so.id, + rate: { + region_id: region.id, + rate: 10, + }, + }) + + const res = await api.get(`/admin/shipping-options`, { + headers: { + Authorization: `Bearer test_token`, + }, + }) + + expect(res.data.shipping_options).toHaveLength(1) + expect(res.data.shipping_options).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: so.id, + amount: 100, + price_incl_tax: 100, + includes_tax: true, + tax_amount: 9, + }), + ]) + ) + }) + + it("gets correct shipping prices", async () => { + const api = useApi() + + const region = await simpleRegionFactory(dbConnection, { + tax_rate: 25, + }) + const so = await simpleShippingOptionFactory(dbConnection, { + region_id: region.id, + includes_tax: true, + price: 100, + }) + await simpleShippingTaxRateFactory(dbConnection, { + shipping_option_id: so.id, + rate: { + region_id: region.id, + rate: 10, + }, + }) + + const res = await api.get( + `/store/shipping-options?region_id=${region.id}` + ) + + expect(res.data.shipping_options).toHaveLength(1) + expect(res.data.shipping_options).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: so.id, + amount: 100, + price_incl_tax: 100, + includes_tax: true, + tax_amount: 9, + }), + ]) + ) + }) + }) + + describe("Money amount", () => { + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await promotionsSeeder(dbConnection) + } catch (err) { + console.log(err) + throw err + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("calculated_price contains lowest price", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product?region_id=test-region") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const lowestPrice = variant.prices.reduce( + (prev, curr) => (curr.amount < prev ? curr.amount : prev), + Infinity + ) + + expect(variant.calculated_price).toEqual(lowestPrice) + + expect(variant).toEqual( + expect.objectContaining({ original_price: 120, calculated_price: 110 }) + ) + }) + + it("returns no money amounts belonging to customer groups without login", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product?cart_id=test-cart") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.prices.length).toEqual(2) + variant.prices.forEach((price) => { + if (price.price_list) { + expect(price.price_list.customer_groups).toEqual(undefined) + } else { + expect(price.price_list).toEqual(null) + } + }) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1", + region_id: "test-region", + currency_code: "usd", + amount: 120, + }), + expect.objectContaining({ + id: "test-price3", + region_id: "test-region", + currency_code: "usd", + price_list_id: "pl", + amount: 110, + }), + ]) + ) + }) + + it("sets default price as original price", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product?cart_id=test-cart") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + }) + + it("gets prices for currency if no region prices exist", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product?cart_id=test-cart-2") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + expect(variant.prices.length).toEqual(2) + variant.prices.forEach((price) => { + if (price.price_list) { + expect(price.price_list.customer_groups).toEqual(undefined) + } else { + expect(price.price_list).toEqual(null) + } + }) + variant.prices.forEach((price) => { + expect(price.region_id).toEqual("test-region") + expect(price.currency_code).toEqual("usd") + }) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1", + region_id: "test-region", + currency_code: "usd", + amount: 120, + }), + expect.objectContaining({ + id: "test-price3", + region_id: "test-region", + currency_code: "usd", + price_list_id: "pl", + amount: 110, + }), + ]) + ) + }) + + it("gets prices for cart region for multi region product", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product-multi-region?cart_id=test-cart-1") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ original_price: 130, calculated_price: 110 }) + ) + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + + expect(variant.prices.length).toEqual(2) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1-region-2", + region_id: "test-region-2", + currency_code: "dkk", + amount: 130, + }), + expect.objectContaining({ + id: "test-price3-region-2", + region_id: "test-region-2", + currency_code: "dkk", + price_list_id: "pl", + amount: 110, + }), + ]) + ) + }) + + it("gets prices for multi region product", async () => { + const api = useApi() + const res = await api + .get( + "/store/products/test-product-multi-region?region_id=test-region-2" + ) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ original_price: 130, calculated_price: 110 }) + ) + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + + expect(variant.prices.length).toEqual(2) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1-region-2", + region_id: "test-region-2", + currency_code: "dkk", + amount: 130, + }), + expect.objectContaining({ + id: "test-price3-region-2", + region_id: "test-region-2", + currency_code: "dkk", + price_list_id: "pl", + amount: 110, + }), + ]) + ) + }) + + it("gets prices for multi currency product", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product-multi-region?currency_code=dkk") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant).toEqual( + expect.objectContaining({ original_price: 130, calculated_price: 110 }) + ) + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + + expect(variant.prices.length).toEqual(2) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1-region-2", + region_id: "test-region-2", + currency_code: "dkk", + amount: 130, + }), + expect.objectContaining({ + id: "test-price3-region-2", + region_id: "test-region-2", + currency_code: "dkk", + price_list_id: "pl", + amount: 110, + }), + ]) + ) + }) + + it("gets moneyamounts only with valid date interval", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product-sale?cart_id=test-cart") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const date = new Date() + + expect(variant.prices.length).toEqual(2) + variant.prices.forEach((price) => { + if (price.starts_at) { + expect(new Date(price.starts_at).getTime()).toBeLessThan( + date.getTime() + ) + } + if (price.ends_at) { + expect(new Date(price.ends_at).getTime()).toBeGreaterThan( + date.getTime() + ) + } + }) + }) + + it("gets moneyamounts with valid date intervals and finds lowest price with overlapping intervals", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product-sale-overlap?cart_id=test-cart") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const date = new Date() + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 150, + calculated_price: 120, + }) + ) + expect(variant.prices.length).toEqual(3) + variant.prices.forEach((price) => { + if (price.starts_at) { + expect(new Date(price.starts_at).getTime()).toBeLessThan( + date.getTime() + ) + } + if (price.ends_at) { + expect(new Date(price.ends_at).getTime()).toBeGreaterThan( + date.getTime() + ) + } + }) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price-sale-overlap-1", + region_id: "test-region", + currency_code: "usd", + amount: 140, + price_list_id: "pl_current_1", + }), + expect.objectContaining({ + id: "test-price1-sale-overlap", + region_id: "test-region", + currency_code: "usd", + amount: 120, + price_list_id: "pl_current", + }), + expect.objectContaining({ + id: "test-price2-sale-overlap-default", + region_id: "test-region", + currency_code: "usd", + amount: 150, + }), + ]) + ) + }) + + it("gets all prices with varying quantity limits with no quantity", async () => { + const api = useApi() + const res = await api + .get("/store/products/test-product-quantity?cart_id=test-cart") + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.prices.length).toEqual(5) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price-quantity", + region_id: "test-region", + currency_code: "usd", + amount: 100, + price_list_id: "pl", + min_quantity: 10, + max_quantity: 100, + }), + expect.objectContaining({ + id: "test-price1-quantity", + region_id: "test-region", + currency_code: "usd", + amount: 120, + price_list_id: "pl", + min_quantity: 101, + max_quantity: 1000, + }), + expect.objectContaining({ + id: "test-price2-quantity", + region_id: "test-region", + currency_code: "usd", + amount: 130, + price_list_id: "pl", + max_quantity: 9, + }), + expect.objectContaining({ + id: "test-price3-quantity-now", + region_id: "test-region", + currency_code: "usd", + amount: 140, + price_list_id: "pl_current", + min_quantity: 101, + max_quantity: 1000, + }), + expect.objectContaining({ + id: "test-price3-quantity-default", + region_id: "test-region", + currency_code: "usd", + amount: 150, + }), + ]) + ) + + expect(variant.calculated_price).toEqual(130) + expect(variant.original_price).toEqual(150) + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + }) + + it("fetches product with groups in money amounts with login", async () => { + const api = useApi() + + // customer with customer-group 5 + const authResponse = await api.post("/store/auth", { + email: "test5@email.com", + password: "test", + }) + + const [authCookie] = authResponse.headers["set-cookie"][0].split(";") + + const res = await api + .get("/store/products/test-product?cart_id=test-cart", { + headers: { + Cookie: authCookie, + }, + }) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.prices.length).toEqual(3) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1", + region_id: "test-region", + currency_code: "usd", + amount: 120, + }), + expect.objectContaining({ + id: "test-price3", + region_id: "test-region", + currency_code: "usd", + price_list_id: "pl", + amount: 110, + }), + expect.objectContaining({ + id: "test-price", + region_id: "test-region", + currency_code: "usd", + amount: 100, + price_list: expect.objectContaining({}), + }), + ]) + ) + }) + + it("fetches product with groups and quantities in money amounts with login", async () => { + const api = useApi() + + // customer with customer-group 5 + const authResponse = await api.post("/store/auth", { + email: "test5@email.com", + password: "test", + }) + + const [authCookie] = authResponse.headers["set-cookie"][0].split(";") + + const res = await api + .get( + "/store/products/test-product-quantity-customer?cart_id=test-cart", + { + headers: { + Cookie: authCookie, + }, + } + ) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + expect(variant.prices.length).toEqual(6) + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price-quantity-customer", + region_id: "test-region", + currency_code: "usd", + amount: 100, + min_quantity: 10, + max_quantity: 100, + }), + expect.objectContaining({ + id: "test-price1-quantity-customer", + region_id: "test-region", + currency_code: "usd", + amount: 120, + min_quantity: 101, + max_quantity: 1000, + }), + expect.objectContaining({ + id: "test-price2-quantity-customer", + region_id: "test-region", + currency_code: "usd", + amount: 130, + max_quantity: 9, + }), + expect.objectContaining({ + id: "test-price2-quantity-customer-group", + region_id: "test-region", + currency_code: "usd", + amount: 100, + max_quantity: 9, + price_list: expect.objectContaining({}), + }), + expect.objectContaining({ + id: "test-price3-quantity-customer-now", + region_id: "test-region", + currency_code: "usd", + amount: 140, + min_quantity: 101, + max_quantity: 1000, + }), + expect.objectContaining({ + id: "test-price3-quantity-customer-default", + region_id: "test-region", + currency_code: "usd", + amount: 150, + price_list_id: null, + }), + ]) + ) + + expect(variant.calculated_price).toEqual(100) + expect(variant.original_price).toEqual(150) + expect(variant.original_price).toEqual( + variant.prices.find((p) => p.price_list_id === null).amount + ) + }) + + it("gets moneyamounts only with valid date interval for customer", async () => { + const api = useApi() + + // customer with customer-group 5 + const authResponse = await api.post("/store/auth", { + email: "test5@email.com", + password: "test", + }) + + const [authCookie] = authResponse.headers["set-cookie"][0].split(";") + + const res = await api + .get("/store/products/test-product-sale-customer?cart_id=test-cart", { + headers: { + Cookie: authCookie, + }, + }) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const date = new Date() + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 150, + calculated_price: 100, + }) + ) + + expect(variant.prices.length).toEqual(2) + variant.prices.forEach((price) => { + if (price.starts_at) { + expect(new Date(price.starts_at).getTime()).toBeLessThan( + date.getTime() + ) + } + if (price.ends_at) { + expect(new Date(price.ends_at).getTime()).toBeGreaterThan( + date.getTime() + ) + } + }) + }) + + it("gets moneyamounts only with valid date interval for customer regardless of quantity limits", async () => { + const api = useApi() + + // customer with customer-group 5 + const authResponse = await api.post("/store/auth", { + email: "test5@email.com", + password: "test", + }) + + const [authCookie] = authResponse.headers["set-cookie"][0].split(";") + + const res = await api + .get( + "/store/products/test-product-sale-customer-quantity?cart_id=test-cart", + { + headers: { + Cookie: authCookie, + }, + } + ) + .catch((error) => console.log(error)) + + const variant = res.data.product.variants[0] + + const date = new Date() + + expect(variant).toEqual( + expect.objectContaining({ + original_price: 150, + calculated_price: 100, + }) + ) + + expect(variant.prices.length).toEqual(3) + variant.prices.forEach((price) => { + if (price.starts_at) { + expect(new Date(price.starts_at).getTime()).toBeLessThan( + date.getTime() + ) + } + if (price.ends_at) { + expect(new Date(price.ends_at).getTime()).toBeGreaterThan( + date.getTime() + ) + } + }) + + expect(variant.prices).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + id: "test-price1-sale-customer-quantity-groups", + region_id: "test-region", + currency_code: "usd", + amount: 100, + max_quantity: 99, + price_list: expect.objectContaining({}), + }), + expect.objectContaining({ + id: "test-price2-sale-customer-quantity-default", + region_id: "test-region", + currency_code: "usd", + amount: 150, + }), + expect.objectContaining({ + id: "test-price1-sale-customer-quantity", + region_id: "test-region", + currency_code: "usd", + amount: 110, + max_quantity: 99, + }), + ]) + ) + }) + }) +}) diff --git a/integration-tests/api/__tests__/returns/ff-tax-inclusive-pricing.js b/integration-tests/api/__tests__/returns/ff-tax-inclusive-pricing.js new file mode 100644 index 0000000000..cdc9a1fb23 --- /dev/null +++ b/integration-tests/api/__tests__/returns/ff-tax-inclusive-pricing.js @@ -0,0 +1,276 @@ +const path = require("path") + +const startServerWithEnvironment = + require("../../../helpers/start-server-with-environment").default +const { useApi } = require("../../../helpers/use-api") +const { useDb } = require("../../../helpers/use-db") + +const { simpleProductFactory, simpleOrderFactory } = require("../../factories") +const adminSeeder = require("../../helpers/admin-seeder") + +const createReturnableOrder = async (dbConnection, options) => { + await simpleProductFactory( + dbConnection, + { + id: "test-product", + variants: [{ id: "test-variant" }], + }, + 100 + ) + + let discounts = [] + + if (options?.discount) { + discounts = [ + { + code: "TESTCODE", + }, + ] + } + + let unitPrice = options.includes_tax ? 1200 : 1000 + if (options.oldTaxes) { + unitPrice = options.includes_tax ? 1125 : 1000 + } + + return await simpleOrderFactory(dbConnection, { + email: "test@testson.com", + tax_rate: options?.oldTaxes ? undefined : null, + region: { + id: "test-region", + name: "Test region", + tax_rate: 12.5, + }, + discounts, + line_items: [ + { + id: "test-item", + variant_id: "test-variant", + quantity: 2, + fulfilled_quantity: options?.shipped ? 2 : undefined, + shipped_quantity: options?.shipped ? 2 : undefined, + unit_price: unitPrice, + includes_tax: options?.includes_tax, + tax_lines: [ + { + name: "default", + code: "default", + rate: 20, + }, + ], + }, + ], + }) +} + +jest.setTimeout(30000) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /store/carts", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("creates a return with the old tax system and tax inclusive line", async () => { + await adminSeeder(dbConnection) + + const order = await createReturnableOrder(dbConnection, { + oldTaxes: true, + includes_tax: true, + }) + const api = useApi() + + const response = await api.post( + `/admin/orders/${order.id}/return`, + { + items: [ + { + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }, + ], + }, + { + headers: { + authorization: "Bearer test_token", + }, + } + ) + + expect(response.status).toEqual(200) + + /* + * Region has default tax rate 12.5 therefore refund amount should be + * 1000 * 1.125 = 1125 + */ + expect(response.data.order.returns[0].refund_amount).toEqual(1125) + expect(response.data.order.returns[0].items).toHaveLength(1) + expect(response.data.order.returns[0].items).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }), + ]) + ) + }) + + it("creates a return with the old tax system and tax exclusive line", async () => { + await adminSeeder(dbConnection) + + const order = await createReturnableOrder(dbConnection, { + oldTaxes: true, + includes_tax: false, + }) + const api = useApi() + + const response = await api.post( + `/admin/orders/${order.id}/return`, + { + items: [ + { + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }, + ], + }, + { + headers: { + authorization: "Bearer test_token", + }, + } + ) + + expect(response.status).toEqual(200) + + /* + * Region has default tax rate 12.5 therefore refund amount should be + * 1000 * 1.125 = 1125 + */ + expect(response.data.order.returns[0].refund_amount).toEqual(1125) + expect(response.data.order.returns[0].items).toHaveLength(1) + expect(response.data.order.returns[0].items).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }), + ]) + ) + }) + + it("creates a return with tax inclusive line", async () => { + await adminSeeder(dbConnection) + + const order = await createReturnableOrder(dbConnection, { + includes_tax: true, + }) + const api = useApi() + + const response = await api.post( + `/admin/orders/${order.id}/return`, + { + items: [ + { + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }, + ], + }, + { + headers: { + authorization: "Bearer test_token", + }, + } + ) + + expect(response.status).toEqual(200) + + /* + * Region has a tax rate of 20% therefore refund amount should be + * 1000 * 1.2 = 1200 + */ + expect(response.data.order.returns[0].refund_amount).toEqual(1200) + expect(response.data.order.returns[0].items).toHaveLength(1) + expect(response.data.order.returns[0].items).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }), + ]) + ) + }) + + it("creates a return with tax exclusive line", async () => { + await adminSeeder(dbConnection) + + const order = await createReturnableOrder(dbConnection, { + includes_tax: false, + }) + const api = useApi() + + const response = await api.post( + `/admin/orders/${order.id}/return`, + { + items: [ + { + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }, + ], + }, + { + headers: { + authorization: "Bearer test_token", + }, + } + ) + + expect(response.status).toEqual(200) + + /* + * Region has a tax rate of 20% therefore refund amount should be + * 1000 * 1.2 = 1200 + */ + expect(response.data.order.returns[0].refund_amount).toEqual(1200) + expect(response.data.order.returns[0].items).toHaveLength(1) + expect(response.data.order.returns[0].items).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + item_id: "test-item", + quantity: 1, + note: "TOO SMALL", + }), + ]) + ) + }) +}) diff --git a/integration-tests/api/__tests__/store/__snapshots__/cart.js.snap b/integration-tests/api/__tests__/store/cart/__snapshots__/cart.js.snap similarity index 100% rename from integration-tests/api/__tests__/store/__snapshots__/cart.js.snap rename to integration-tests/api/__tests__/store/cart/__snapshots__/cart.js.snap diff --git a/integration-tests/api/__tests__/store/cart.js b/integration-tests/api/__tests__/store/cart/cart.js similarity index 98% rename from integration-tests/api/__tests__/store/cart.js rename to integration-tests/api/__tests__/store/cart/cart.js index 4d33520978..09696de85f 100644 --- a/integration-tests/api/__tests__/store/cart.js +++ b/integration-tests/api/__tests__/store/cart/cart.js @@ -9,29 +9,29 @@ const { MoneyAmount, } = require("@medusajs/medusa") -const setupServer = require("../../../helpers/setup-server") -const { useApi } = require("../../../helpers/use-api") -const { initDb, useDb } = require("../../../helpers/use-db") +const setupServer = require("../../../../helpers/setup-server") +const { useApi } = require("../../../../helpers/use-api") +const { initDb, useDb } = require("../../../../helpers/use-db") -const cartSeeder = require("../../helpers/cart-seeder") -const productSeeder = require("../../helpers/product-seeder") -const swapSeeder = require("../../helpers/swap-seeder") +const cartSeeder = require("../../../helpers/cart-seeder") +const productSeeder = require("../../../helpers/product-seeder") +const swapSeeder = require("../../../helpers/swap-seeder") const { simpleCartFactory, simpleRegionFactory, simpleProductFactory, simpleShippingOptionFactory, simpleLineItemFactory, -} = require("../../factories") +} = require("../../../factories") const { simpleDiscountFactory, -} = require("../../factories/simple-discount-factory") +} = require("../../../factories/simple-discount-factory") const { simpleCustomerFactory, -} = require("../../factories/simple-customer-factory") +} = require("../../../factories/simple-customer-factory") const { simpleCustomerGroupFactory, -} = require("../../factories/simple-customer-group-factory") +} = require("../../../factories/simple-customer-group-factory") jest.setTimeout(30000) @@ -45,7 +45,7 @@ describe("/store/carts", () => { } beforeAll(async () => { - const cwd = path.resolve(path.join(__dirname, "..", "..")) + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) dbConnection = await initDb({ cwd }) medusaProcess = await setupServer({ cwd, verbose: false }) }) @@ -1813,7 +1813,7 @@ describe("/store/carts", () => { type: "swap", }) - const cartWithCustomSo = await manager.save(_cart) + await manager.save(_cart) await manager.insert(CustomShippingOption, { id: "another-cso-test", diff --git a/integration-tests/api/__tests__/store/cart/ff-tax-inclusive-pricing.js b/integration-tests/api/__tests__/store/cart/ff-tax-inclusive-pricing.js new file mode 100644 index 0000000000..ee47f29a6f --- /dev/null +++ b/integration-tests/api/__tests__/store/cart/ff-tax-inclusive-pricing.js @@ -0,0 +1,538 @@ +const path = require("path") + +const startServerWithEnvironment = + require("../../../../helpers/start-server-with-environment").default +const { useApi } = require("../../../../helpers/use-api") +const { useDb } = require("../../../../helpers/use-db") + +const { + simpleCartFactory, + simpleRegionFactory, + simpleShippingOptionFactory, + simpleCustomShippingOptionFactory, + simpleProductFactory, + simplePriceListFactory, + simpleDiscountFactory, +} = require("../../../factories") +const { IdMap } = require("medusa-test-utils") + +jest.setTimeout(30000) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] /store/carts", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /store/carts/:id/shipping-methods", () => { + let includesTaxShippingOption + let cart + let customSoCart + + beforeEach(async () => { + try { + const shippingAddress = { + id: "test-shipping-address", + first_name: "lebron", + country_code: "us", + } + const region = await simpleRegionFactory(dbConnection, { + id: "test-region", + }) + cart = await simpleCartFactory(dbConnection, { + id: "test-cart", + email: "some-customer1@email.com", + region: region.id, + shipping_address: shippingAddress, + currency_code: "usd", + }) + customSoCart = await simpleCartFactory(dbConnection, { + id: "test-cart-with-cso", + email: "some-customer2@email.com", + region: region.id, + shipping_address: shippingAddress, + currency_code: "usd", + }) + includesTaxShippingOption = await simpleShippingOptionFactory( + dbConnection, + { + includes_tax: true, + region_id: region.id, + } + ) + await simpleCustomShippingOptionFactory(dbConnection, { + id: "another-cso-test", + cart_id: customSoCart.id, + shipping_option_id: includesTaxShippingOption.id, + price: 5, + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("should add a normal shipping method to the cart", async () => { + const api = useApi() + + const cartWithShippingMethodRes = await api.post( + `/store/carts/${cart.id}/shipping-methods`, + { + option_id: includesTaxShippingOption.id, + }, + { withCredentials: true } + ) + + expect(cartWithShippingMethodRes.status).toEqual(200) + expect(cartWithShippingMethodRes.data.cart.shipping_methods).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + shipping_option_id: includesTaxShippingOption.id, + includes_tax: true, + }), + ]) + ) + }) + + it("should add a custom shipping method to the cart", async () => { + const api = useApi() + + const cartWithCustomShippingMethodRes = await api + .post( + `/store/carts/${customSoCart.id}/shipping-methods`, + { + option_id: includesTaxShippingOption.id, + }, + { withCredentials: true } + ) + .catch((err) => err.response) + + expect(cartWithCustomShippingMethodRes.status).toEqual(200) + expect( + cartWithCustomShippingMethodRes.data.cart.shipping_methods + ).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + shipping_option_id: includesTaxShippingOption.id, + includes_tax: true, + price: 5, + }), + ]) + ) + }) + }) + + describe("POST /store/carts/:id", () => { + const variantId1 = IdMap.getId("test-variant-1") + const variantId2 = IdMap.getId("test-variant-2") + const productId1 = IdMap.getId("test-product-1") + const productId2 = IdMap.getId("test-product-2") + const regionId = IdMap.getId("test-region") + const regionData = { + id: regionId, + includes_tax: false, + currency_code: "usd", + countries: ["us"], + tax_rate: 20, + name: "region test", + } + const buildProductData = (productId, variantId) => { + return { + id: productId, + variants: [ + { + id: variantId, + prices: [], + }, + ], + } + } + const buildPriceListData = (variantId, price, includesTax) => { + return { + status: "active", + type: "sale", + prices: [ + { + variant_id: variantId, + amount: price, + currency_code: "usd", + region_id: regionId, + }, + ], + includes_tax: includesTax, + } + } + const customnerPayload = { + email: "adrien@test.dk", + password: "adrientest", + first_name: "adrien", + last_name: "adrien", + } + const createCartPayload = { + region_id: regionId, + items: [ + { + variant_id: variantId1, + quantity: 1, + }, + { + variant_id: variantId2, + quantity: 1, + }, + ], + } + + describe("with a cart with full tax exclusive variant pricing", () => { + beforeEach(async () => { + await simpleRegionFactory(dbConnection, regionData) + await simpleProductFactory( + dbConnection, + buildProductData(productId1, variantId1) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId1, 100, false) + ) + await simpleProductFactory( + dbConnection, + buildProductData(productId2, variantId2) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId2, 100, false) + ) + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("should calculates correct payment totals on cart completion", async () => { + const api = useApi() + + const customerRes = await api.post( + "/store/customers", + customnerPayload, + { withCredentials: true } + ) + + const createCartRes = await api.post("/store/carts", createCartPayload) + + const cart = createCartRes.data.cart + + await api.post(`/store/carts/${cart.id}`, { + customer_id: customerRes.data.customer.id, + }) + + await api.post(`/store/carts/${cart.id}/payment-sessions`) + + const createdOrder = await api.post( + `/store/carts/${cart.id}/complete-cart` + ) + + expect(createdOrder.data.type).toEqual("order") + expect(createdOrder.data.data.discount_total).toEqual(0) + expect(createdOrder.data.data.subtotal).toEqual(200) + expect(createdOrder.data.data.total).toEqual(240) + + expect(createdOrder.status).toEqual(200) + }) + }) + + describe("with a cart with full tax inclusive variant pricing", () => { + beforeEach(async () => { + await simpleRegionFactory(dbConnection, regionData) + await simpleProductFactory( + dbConnection, + buildProductData(productId1, variantId1) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId1, 120, true) + ) + await simpleProductFactory( + dbConnection, + buildProductData(productId2, variantId2) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId2, 120, true) + ) + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("should calculates correct payment totals on cart completion", async () => { + const api = useApi() + + const customerRes = await api.post( + "/store/customers", + customnerPayload, + { withCredentials: true } + ) + + const createCartRes = await api.post("/store/carts", createCartPayload) + + const cart = createCartRes.data.cart + + await api.post(`/store/carts/${cart.id}`, { + customer_id: customerRes.data.customer.id, + }) + + await api.post(`/store/carts/${cart.id}/payment-sessions`) + + const createdOrder = await api.post( + `/store/carts/${cart.id}/complete-cart` + ) + + expect(createdOrder.data.type).toEqual("order") + expect(createdOrder.data.data.discount_total).toEqual(0) + expect(createdOrder.data.data.subtotal).toEqual(200) + expect(createdOrder.data.data.total).toEqual(240) + + expect(createdOrder.status).toEqual(200) + }) + }) + + describe("with a cart mixing tax inclusive and exclusive variant pricing", () => { + beforeEach(async () => { + await simpleRegionFactory(dbConnection, regionData) + await simpleProductFactory( + dbConnection, + buildProductData(productId1, variantId1) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId1, 120, true) + ) + await simpleProductFactory( + dbConnection, + buildProductData(productId2, variantId2) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId2, 100, false) + ) + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("should calculates correct payment totals on cart completion", async () => { + const api = useApi() + + const customerRes = await api.post( + "/store/customers", + customnerPayload, + { withCredentials: true } + ) + + const createCartRes = await api.post("/store/carts", createCartPayload) + + const cart = createCartRes.data.cart + + await api.post(`/store/carts/${cart.id}`, { + customer_id: customerRes.data.customer.id, + }) + + await api.post(`/store/carts/${cart.id}/payment-sessions`) + + const createdOrder = await api.post( + `/store/carts/${cart.id}/complete-cart` + ) + + expect(createdOrder.data.type).toEqual("order") + expect(createdOrder.data.data.discount_total).toEqual(0) + expect(createdOrder.data.data.subtotal).toEqual(200) + expect(createdOrder.data.data.total).toEqual(240) + + expect(createdOrder.status).toEqual(200) + }) + }) + }) + + describe("POST /store/carts/:id/line-items", () => { + const cartIdWithItemPercentageDiscount = + "test-cart-w-item-percentage-discount" + const percentage15discountId = IdMap.getId("percentage15discountId") + const variantId1 = IdMap.getId("test-variant-1") + const variantId2 = IdMap.getId("test-variant-2") + const productId1 = IdMap.getId("test-product-1") + const productId2 = IdMap.getId("test-product-2") + const regionId = IdMap.getId("test-region") + const regionData = { + id: regionId, + includes_tax: false, + currency_code: "usd", + countries: ["us"], + tax_rate: 20, + name: "region test", + } + const buildProductData = (productId, variantId) => { + return { + id: productId, + variants: [ + { + id: variantId, + prices: [], + }, + ], + } + } + const buildPriceListData = (variantId, price, includesTax) => { + return { + status: "active", + type: "sale", + prices: [ + { + variant_id: variantId, + amount: price, + currency_code: "usd", + region_id: regionId, + }, + ], + includes_tax: includesTax, + } + } + + describe("with a cart mixing tax inclusive and exclusive variant pricing", () => { + beforeEach(async () => { + const region = await simpleRegionFactory(dbConnection, regionData) + await simpleCartFactory(dbConnection, { + id: cartIdWithItemPercentageDiscount, + region, + }) + await simpleProductFactory( + dbConnection, + buildProductData(productId1, variantId1) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId1, 120, true) + ) + await simpleProductFactory( + dbConnection, + buildProductData(productId2, variantId2) + ) + await simplePriceListFactory( + dbConnection, + buildPriceListData(variantId2, 100, false) + ) + + const tenDaysAgo = ((today) => + new Date(today.setDate(today.getDate() - 10)))(new Date()) + const tenDaysFromToday = ((today) => + new Date(today.setDate(today.getDate() + 10)))(new Date()) + await simpleDiscountFactory(dbConnection, { + id: percentage15discountId, + code: percentage15discountId, + regions: [regionId], + rule: { + type: "percentage", + value: "15", + allocation: "item", + }, + starts_at: tenDaysAgo, + ends_at: tenDaysFromToday, + }) + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + it("calculates correct item totals for percentage discount with mix of tax inclusive/exclusive items", async () => { + const api = useApi() + + await api.post(`/store/carts/${cartIdWithItemPercentageDiscount}`, { + region_id: regionId, + discounts: [{ code: percentage15discountId }], + }) + + await api.post( + `/store/carts/${cartIdWithItemPercentageDiscount}/line-items`, + { + variant_id: variantId1, + quantity: 2, + }, + { withCredentials: true } + ) + const response = await api.post( + `/store/carts/${cartIdWithItemPercentageDiscount}/line-items`, + { + variant_id: variantId2, + quantity: 2, + }, + { withCredentials: true } + ) + + const expectedItemTotals = { + subtotal: 200, + gift_card_total: 0, + discount_total: 30, + total: 204, + original_total: 240, + original_tax_total: 40, + tax_total: 34, + } + + const expectedAdjustment = { + amount: 30, + discount_id: percentage15discountId, + description: "discount", + } + + expect(response.data.cart.items).toHaveLength(2) + expect(response.data.cart.items).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + includes_tax: true, + cart_id: cartIdWithItemPercentageDiscount, + unit_price: 120, + variant_id: variantId1, + quantity: 2, + adjustments: [expect.objectContaining(expectedAdjustment)], + ...expectedItemTotals, + }), + expect.objectContaining({ + includes_tax: false, + cart_id: cartIdWithItemPercentageDiscount, + unit_price: 100, + variant_id: variantId2, + quantity: 2, + adjustments: [expect.objectContaining(expectedAdjustment)], + ...expectedItemTotals, + }), + ]) + ) + }) + }) + }) +}) diff --git a/integration-tests/api/__tests__/taxes/orders/ff-tax-inclusive-pricing.js b/integration-tests/api/__tests__/taxes/orders/ff-tax-inclusive-pricing.js new file mode 100644 index 0000000000..addddc5177 --- /dev/null +++ b/integration-tests/api/__tests__/taxes/orders/ff-tax-inclusive-pricing.js @@ -0,0 +1,106 @@ +const path = require("path") + +const { useApi } = require("../../../../helpers/use-api") +const { useDb } = require("../../../../helpers/use-db") + +const startServerWithEnvironment = + require("../../../../helpers/start-server-with-environment").default + +const { + simpleOrderFactory, + simpleProductFactory, +} = require("../../../factories") + +jest.setTimeout(30000) + +describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING]: Order Taxes", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_TAX_INCLUSIVE_PRICING: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + medusaProcess.kill() + }) + + afterEach(async () => { + const db = useDb() + return await db.teardown() + }) + + test("calculates taxes w. tax inclusive shipping method price", async () => { + await simpleProductFactory( + dbConnection, + { + id: "test-product", + variants: [ + { + id: "test-variant", + }, + ], + }, + 100 + ) + + const order = await simpleOrderFactory( + dbConnection, + { + email: "test@testson.com", + tax_rate: null, + region: { + id: "test-region", + name: "Test region", + tax_rate: null, + }, + shipping_methods: [ + { + price: 110, + includes_tax: true, + shipping_option: { + region_id: "test-region", + }, + tax_lines: [ + { + rate: 10, + name: "default", + code: "default", + }, + ], + }, + ], + line_items: [ + { + variant_id: "test-variant", + unit_price: 1000, + tax_lines: [ + { + rate: 20, + name: "default", + code: "default", + }, + ], + }, + ], + }, + 100 + ) + + const api = useApi() + + const response = await api.get(`/store/orders/${order.id}`) + expect(response.status).toEqual(200) + expect(response.data.order.tax_total).toEqual(210) + expect(response.data.order.total).toEqual(1310) + }) +}) diff --git a/integration-tests/api/__tests__/taxes/orders.js b/integration-tests/api/__tests__/taxes/orders/orders.js similarity index 96% rename from integration-tests/api/__tests__/taxes/orders.js rename to integration-tests/api/__tests__/taxes/orders/orders.js index 580667dd00..27964365c7 100644 --- a/integration-tests/api/__tests__/taxes/orders.js +++ b/integration-tests/api/__tests__/taxes/orders/orders.js @@ -1,8 +1,8 @@ const path = require("path") -const setupServer = require("../../../helpers/setup-server") -const { useApi } = require("../../../helpers/use-api") -const { initDb, useDb } = require("../../../helpers/use-db") +const setupServer = require("../../../../helpers/setup-server") +const { useApi } = require("../../../../helpers/use-api") +const { initDb, useDb } = require("../../../../helpers/use-db") const { simpleOrderFactory, @@ -10,7 +10,7 @@ const { simpleCartFactory, simpleProductFactory, simpleProductTaxRateFactory, -} = require("../../factories") +} = require("../../../factories") jest.setTimeout(30000) @@ -24,7 +24,7 @@ describe("Order Taxes", () => { } beforeAll(async () => { - const cwd = path.resolve(path.join(__dirname, "..", "..")) + const cwd = path.resolve(path.join(__dirname, "..", "..", "..")) dbConnection = await initDb({ cwd }) medusaProcess = await setupServer({ cwd }) }) diff --git a/integration-tests/api/factories/index.ts b/integration-tests/api/factories/index.ts index 192646a0a0..16a6830152 100644 --- a/integration-tests/api/factories/index.ts +++ b/integration-tests/api/factories/index.ts @@ -17,3 +17,4 @@ export * from "./simple-product-type-tax-rate-factory" export * from "./simple-price-list-factory" export * from "./simple-batch-job-factory" export * from "./simple-sales-channel-factory" +export * from "./simple-custom-shipping-option-factory" diff --git a/integration-tests/api/factories/simple-custom-shipping-option-factory.ts b/integration-tests/api/factories/simple-custom-shipping-option-factory.ts new file mode 100644 index 0000000000..f3404bbfe0 --- /dev/null +++ b/integration-tests/api/factories/simple-custom-shipping-option-factory.ts @@ -0,0 +1,36 @@ +import { Connection } from "typeorm" +import faker from "faker" +import { + CustomShippingOption, +} from "@medusajs/medusa" + +export type CustomShippingOptionFactoryData = { + id?: string + cart_id: string + shipping_option_id: string + price?: number + metadata?: Record +} + +export const simpleCustomShippingOptionFactory = async ( + connection: Connection, + data: CustomShippingOptionFactoryData, + seed?: number +): Promise => { + if (typeof seed !== "undefined") { + faker.seed(seed) + } + + const manager = connection.manager + + const customShippingOptionData = { + id: data.id ?? `custon-simple-so-${Math.random() * 1000}`, + price: typeof data.price !== "undefined" ? data.price : 500, + cart_id: data.cart_id, + shipping_option_id: data.shipping_option_id, + metadata: data.metadata ?? {} + } + + const created = manager.create(CustomShippingOption, customShippingOptionData) + return await manager.save(created) +} diff --git a/integration-tests/api/factories/simple-discount-condition-factory.ts b/integration-tests/api/factories/simple-discount-condition-factory.ts index 9735f8a2b3..eba3af4a27 100644 --- a/integration-tests/api/factories/simple-discount-condition-factory.ts +++ b/integration-tests/api/factories/simple-discount-condition-factory.ts @@ -12,7 +12,7 @@ import { DiscountConditionJoinTableForeignKey } from "@medusajs/medusa/dist/repo import faker from "faker" import { Connection } from "typeorm" -export type DiscuntConditionFactoryData = { +export type DiscountConditionFactoryData = { id?: string rule_id: string type: DiscountConditionType @@ -67,7 +67,7 @@ const getJoinTableResourceIdentifiers = (type: string) => { export const simpleDiscountConditionFactory = async ( connection: Connection, - data: DiscuntConditionFactoryData, + data: DiscountConditionFactoryData, seed?: number ): Promise => { if (typeof seed !== "undefined") { diff --git a/integration-tests/api/factories/simple-discount-factory.ts b/integration-tests/api/factories/simple-discount-factory.ts index d8066c6b45..24cf20ff33 100644 --- a/integration-tests/api/factories/simple-discount-factory.ts +++ b/integration-tests/api/factories/simple-discount-factory.ts @@ -7,7 +7,7 @@ import { import faker from "faker" import { Connection } from "typeorm" import { - DiscuntConditionFactoryData, + DiscountConditionFactoryData, simpleDiscountConditionFactory, } from "./simple-discount-condition-factory" @@ -15,7 +15,7 @@ export type DiscountRuleFactoryData = { type?: DiscountRuleType value?: number allocation?: AllocationType - conditions: DiscuntConditionFactoryData[] + conditions: DiscountConditionFactoryData[] } export type DiscountFactoryData = { @@ -24,6 +24,8 @@ export type DiscountFactoryData = { is_dynamic?: boolean rule?: DiscountRuleFactoryData regions?: string[] + starts_at?: Date + ends_at?: Date } export const simpleDiscountFactory = async ( @@ -37,7 +39,7 @@ export const simpleDiscountFactory = async ( const manager = connection.manager - const ruleData = data.rule ?? {} + const ruleData = data.rule ?? ({} as DiscountRuleFactoryData) const ruleToSave = manager.create(DiscountRule, { type: ruleData.type ?? DiscountRuleType.PERCENTAGE, value: ruleData.value ?? 10, @@ -63,8 +65,9 @@ export const simpleDiscountFactory = async ( rule_id: dRule.id, code: data.code ?? "TESTCODE", regions: data.regions?.map((r) => ({ id: r })) || [], + starts_at: data.starts_at, + ends_at: data.ends_at, }) - const discount = await manager.save(toSave) - return discount + return await manager.save(toSave) } diff --git a/integration-tests/api/factories/simple-line-item-factory.ts b/integration-tests/api/factories/simple-line-item-factory.ts index d270b40043..98d8ccedf2 100644 --- a/integration-tests/api/factories/simple-line-item-factory.ts +++ b/integration-tests/api/factories/simple-line-item-factory.ts @@ -30,6 +30,7 @@ export type LineItemFactoryData = { returned_quantity?: boolean tax_lines?: TaxLineFactoryData[] adjustments: LineItemAdjustmentFactoryData[] + includes_tax?: boolean } export const simpleLineItemFactory = async ( @@ -70,6 +71,7 @@ export const simpleLineItemFactory = async ( shipped_quantity: data.shipped_quantity || null, returned_quantity: data.returned_quantity || null, adjustments: data.adjustments, + includes_tax: data.includes_tax, }) const line = await manager.save(toSave) diff --git a/integration-tests/api/factories/simple-price-list-factory.ts b/integration-tests/api/factories/simple-price-list-factory.ts index f8d5c9d9f0..b122919aa6 100644 --- a/integration-tests/api/factories/simple-price-list-factory.ts +++ b/integration-tests/api/factories/simple-price-list-factory.ts @@ -26,6 +26,7 @@ export type PriceListFactoryData = { ends_at?: Date customer_groups?: string[] prices?: ProductListPrice[] + includes_tax?: boolean } export const simplePriceListFactory = async ( @@ -59,6 +60,7 @@ export const simplePriceListFactory = async ( starts_at: data.starts_at || null, ends_at: data.ends_at || null, customer_groups: customerGroups, + includes_tax: data.includes_tax, } const toSave = manager.create(PriceList, toCreate) diff --git a/integration-tests/api/factories/simple-region-factory.ts b/integration-tests/api/factories/simple-region-factory.ts index 185d02cfb8..6a473b681e 100644 --- a/integration-tests/api/factories/simple-region-factory.ts +++ b/integration-tests/api/factories/simple-region-factory.ts @@ -10,6 +10,8 @@ export type RegionFactoryData = { countries?: string[] automatic_taxes?: boolean gift_cards_taxable?: boolean + fulfillment_providers?: { id: string }[] + includes_tax?: boolean } export const simpleRegionFactory = async ( @@ -30,7 +32,9 @@ export const simpleRegionFactory = async ( currency_code: data.currency_code || "usd", tax_rate: data.tax_rate || 0, payment_providers: [{ id: "test-pay" }], + fulfillment_providers: data.fulfillment_providers ?? [{ id: "test-ful" }], gift_cards_taxable: data.gift_cards_taxable ?? true, + includes_tax: data.includes_tax, automatic_taxes: typeof data.automatic_taxes !== "undefined" ? data.automatic_taxes : true, }) diff --git a/integration-tests/api/factories/simple-shipping-method-factory.ts b/integration-tests/api/factories/simple-shipping-method-factory.ts index 3fde22e63d..ef219f45ab 100644 --- a/integration-tests/api/factories/simple-shipping-method-factory.ts +++ b/integration-tests/api/factories/simple-shipping-method-factory.ts @@ -15,6 +15,7 @@ export type ShippingMethodFactoryData = { price?: number shipping_option: string | ShippingOptionFactoryData tax_lines?: ShippingMethodTaxLine[] + includes_tax?: boolean } export const simpleShippingMethodFactory = async ( @@ -47,6 +48,7 @@ export const simpleShippingMethodFactory = async ( shipping_option_id: shippingOptionId, data: data.data || {}, price: typeof data.price !== "undefined" ? data.price : 500, + includes_tax: data.includes_tax, }) const shippingMethod = await manager.save(toSave) diff --git a/integration-tests/api/factories/simple-shipping-option-factory.ts b/integration-tests/api/factories/simple-shipping-option-factory.ts index a85d683e4b..e53bc1a0d5 100644 --- a/integration-tests/api/factories/simple-shipping-option-factory.ts +++ b/integration-tests/api/factories/simple-shipping-option-factory.ts @@ -8,12 +8,14 @@ import faker from "faker" import { Connection } from "typeorm" export type ShippingOptionFactoryData = { + id?: string name?: string region_id: string is_return?: boolean is_giftcard?: boolean price?: number price_type?: ShippingOptionPriceType + includes_tax?: boolean data?: object } @@ -35,8 +37,8 @@ export const simpleShippingOptionFactory = async ( type: ShippingProfileType.GIFT_CARD, }) - const created = manager.create(ShippingOption, { - id: `simple-so-${Math.random() * 1000}`, + const shippingOptionData = { + id: data.id ?? `simple-so-${Math.random() * 1000}`, name: data.name || "Test Method", is_return: data.is_return ?? false, region_id: data.region_id, @@ -45,7 +47,15 @@ export const simpleShippingOptionFactory = async ( price_type: data.price_type ?? ShippingOptionPriceType.FLAT_RATE, data: data.data ?? {}, amount: typeof data.price !== "undefined" ? data.price : 500, - }) - const option = await manager.save(created) - return option + } + + // This is purposefully managed out of the original object for the purpose of separating the data linked to a feature flag + // MEDUSA_FF_TAX_INCLUSIVE_PRICING + const { includes_tax } = data + if (typeof includes_tax !== "undefined") { + shippingOptionData["includes_tax"] = includes_tax + } + + const created = manager.create(ShippingOption, shippingOptionData) + return await manager.save(created) } diff --git a/packages/medusa/src/api/routes/admin/currencies/__tests__/list-currencies.ts b/packages/medusa/src/api/routes/admin/currencies/__tests__/list-currencies.ts new file mode 100644 index 0000000000..637594c38a --- /dev/null +++ b/packages/medusa/src/api/routes/admin/currencies/__tests__/list-currencies.ts @@ -0,0 +1,52 @@ +import { IdMap } from "medusa-test-utils" +import { request } from "../../../../../helpers/test-request" +import { currency, CurrencyServiceMock } from "../../../../../services/__mocks__/currency"; +import TaxInclusivePricingFeatureFlag from "../../../../../loaders/feature-flags/tax-inclusive-pricing"; + +describe("GET /admin/currencies/", () => { + describe("successfully list the currency", () => { + let subject + + beforeAll(async () => { + subject = await request( + "GET", + `/admin/currencies`, + { + adminSession: { + jwt: { + userId: IdMap.getId("admin_user"), + }, + }, + flags: [TaxInclusivePricingFeatureFlag], + } + ) + }) + + afterAll(() => { + jest.clearAllMocks() + }) + + it("calls the listAndCount method from the currency service", () => { + expect(CurrencyServiceMock.listAndCount).toHaveBeenCalledTimes(1) + expect(CurrencyServiceMock.listAndCount).toHaveBeenCalledWith( + {}, + { + order: {}, + select: undefined, + relations: [], + skip: 0, + take: 20 + } + ) + }) + + it("returns the expected currencies", () => { + expect(subject.body).toEqual({ + currencies: [currency], + offset: 0, + limit: 20, + count: 1, + }) + }) + }) +}) diff --git a/packages/medusa/src/api/routes/admin/currencies/__tests__/update-currency.ts b/packages/medusa/src/api/routes/admin/currencies/__tests__/update-currency.ts new file mode 100644 index 0000000000..cd8217ca50 --- /dev/null +++ b/packages/medusa/src/api/routes/admin/currencies/__tests__/update-currency.ts @@ -0,0 +1,48 @@ +import { IdMap } from "medusa-test-utils" +import { request } from "../../../../../helpers/test-request" +import { currency, CurrencyServiceMock } from "../../../../../services/__mocks__/currency"; +import TaxInclusivePricingFeatureFlag from "../../../../../loaders/feature-flags/tax-inclusive-pricing"; + +describe("POST /admin/currencies/:code", () => { + let subject + const code = IdMap.getId("currency-1") + + beforeAll(async () => { + subject = await request( + "POST", + `/admin/currencies/${code}`, + { + payload: { + includes_tax: true, + }, + adminSession: { + jwt: { + userId: IdMap.getId("admin_user"), + }, + }, + flags: [TaxInclusivePricingFeatureFlag], + } + ) + }) + + it("returns 200", () => { + expect(subject.status).toEqual(200) + }) + + it("returns updated currency", () => { + expect(subject.body.currency).toEqual({ + ...currency, + includes_tax: true, + }) + }) + + it("calls service update", () => { + expect(CurrencyServiceMock.update).toHaveBeenCalledTimes(1) + expect(CurrencyServiceMock.update).toHaveBeenCalledWith( + code, + { + includes_tax: true, + } + ) + }) +}) diff --git a/packages/medusa/src/api/routes/admin/currencies/index.ts b/packages/medusa/src/api/routes/admin/currencies/index.ts new file mode 100644 index 0000000000..305cbe508b --- /dev/null +++ b/packages/medusa/src/api/routes/admin/currencies/index.ts @@ -0,0 +1,37 @@ +import { Router } from "express" +import middlewares, { + transformBody, + transformQuery, +} from "../../../middlewares" +import { AdminGetCurrenciesParams } from "./list-currencies" +import { AdminPostCurrenciesCurrencyReq } from "./update-currency" +import { isFeatureFlagEnabled } from "../../../middlewares/feature-flag-enabled" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" + +export default (app) => { + const route = Router() + app.use( + "/currencies", + isFeatureFlagEnabled(TaxInclusivePricingFeatureFlag.key), + route + ) + + route.get( + "/", + transformQuery(AdminGetCurrenciesParams, { + isList: true, + }), + middlewares.wrap(require("./list-currencies").default) + ) + + route.post( + "/:code", + transformBody(AdminPostCurrenciesCurrencyReq), + middlewares.wrap(require("./update-currency").default) + ) + + return app +} + +export * from "./list-currencies" +export * from "./update-currency" diff --git a/packages/medusa/src/api/routes/admin/currencies/list-currencies.ts b/packages/medusa/src/api/routes/admin/currencies/list-currencies.ts new file mode 100644 index 0000000000..4e92ed311f --- /dev/null +++ b/packages/medusa/src/api/routes/admin/currencies/list-currencies.ts @@ -0,0 +1,76 @@ +import { IsBoolean, IsOptional, IsString } from "class-validator" +import { Currency } from "../../../../models" +import { CurrencyService } from "../../../../services" +import { ExtendedRequest } from "../../../../types/global" +import { FindConfig, FindPaginationParams } from "../../../../types/common" + +/** + * @oas [get] /currencies + * operationId: "GetCurrencies" + * summary: "List Currency" + * description: "Retrieves a list of Currency" + * x-authenticated: true + * parameters: + * - (query) code {string} Code of the currency to search for. + * - (query) includes_tax {boolean} Search for tax inclusive currencies. + * - (query) order {string} to retrieve products in. + * - (query) offset {string} How many products to skip in the result. + * - (query) limit {string} Limit the number of products returned. + * tags: + * - Currency + * responses: + * 200: + * description: OK + * content: + * application/json: + * schema: + * properties: + * count: + * description: The number of Currency. + * type: integer + * offset: + * description: The offset of the Currency query. + * type: integer + * limit: + * description: The limit of the currency query. + * type: integer + * currencies: + * type: array + * items: + * $ref: "#/components/schemas/currency" + */ +export default async (req: ExtendedRequest, res) => { + const currencyService: CurrencyService = req.scope.resolve("currencyService") + + const { skip, take } = req.listConfig + + req.listConfig.select = undefined + if (req.listConfig.order && req.listConfig.order["created_at"]) { + delete req.listConfig.order["created_at"] + } + const [currencies, count] = await currencyService.listAndCount( + req.filterableFields, + req.listConfig + ) + + res.json({ + currencies, + count, + offset: skip, + limit: take, + }) +} + +export class AdminGetCurrenciesParams extends FindPaginationParams { + @IsString() + @IsOptional() + code?: string + + @IsBoolean() + @IsOptional() + includes_tax?: boolean + + @IsString() + @IsOptional() + order?: string +} diff --git a/packages/medusa/src/api/routes/admin/currencies/update-currency.ts b/packages/medusa/src/api/routes/admin/currencies/update-currency.ts new file mode 100644 index 0000000000..c91640977f --- /dev/null +++ b/packages/medusa/src/api/routes/admin/currencies/update-currency.ts @@ -0,0 +1,52 @@ +import { IsBoolean, IsOptional } from "class-validator" +import { Currency } from "../../../../models" +import { ExtendedRequest } from "../../../../types/global" +import { CurrencyService } from "../../../../services" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" + +/** + * @oas [post] /currencies/:code + * operationId: "PostCurrenciesCurrency" + * summary: "Update a Currency" + * description: "Update a Currency" + * x-authenticated: true + * parameters: + * - (path) code=* {string} The code of the Currency. + * requestBody: + * content: + * application/json: + * schema: + * properties: + * includes_tax: + * type: boolean + * description: [EXPERIMENTAL] Tax included in prices of currency. + * tags: + * - Currency + * responses: + * 200: + * description: OK + * content: + * application/json: + * schema: + * properties: + * currency: + * $ref: "#/components/schemas/currency" + */ +export default async (req: ExtendedRequest, res) => { + const code = req.params.code as string + const data = req.validatedBody as AdminPostCurrenciesCurrencyReq + const currencyService: CurrencyService = req.scope.resolve("currencyService") + + const currency = await currencyService.update(code, data) + + res.json({ currency }) +} + +export class AdminPostCurrenciesCurrencyReq { + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean +} diff --git a/packages/medusa/src/api/routes/admin/index.js b/packages/medusa/src/api/routes/admin/index.js index 171679d144..98a4b49bee 100644 --- a/packages/medusa/src/api/routes/admin/index.js +++ b/packages/medusa/src/api/routes/admin/index.js @@ -3,7 +3,9 @@ import { Router } from "express" import middlewares from "../../middlewares" import appRoutes from "./apps" import authRoutes from "./auth" +import batchRoutes from "./batch" import collectionRoutes from "./collections" +import currencyRoutes from "./currencies" import customerGroupRoutes from "./customer-groups" import customerRoutes from "./customers" import discountRoutes from "./discounts" @@ -14,7 +16,6 @@ import noteRoutes from "./notes" import notificationRoutes from "./notifications" import orderRoutes from "./orders" import priceListRoutes from "./price-lists" -import batchRoutes from "./batch" import productTagRoutes from "./product-tags" import productTypesRoutes from "./product-types" import productRoutes from "./products" @@ -70,6 +71,7 @@ export default (app, container, config) => { collectionRoutes(route) customerGroupRoutes(route) customerRoutes(route) + currencyRoutes(route) discountRoutes(route) draftOrderRoutes(route) giftCardRoutes(route) @@ -77,15 +79,15 @@ export default (app, container, config) => { noteRoutes(route) notificationRoutes(route) orderRoutes(route, featureFlagRouter) - priceListRoutes(route) + priceListRoutes(route, featureFlagRouter) productRoutes(route, featureFlagRouter) productTagRoutes(route) productTypesRoutes(route) - regionRoutes(route) + regionRoutes(route, featureFlagRouter) returnReasonRoutes(route) returnRoutes(route) salesChannelRoutes(route) - shippingOptionRoutes(route) + shippingOptionRoutes(route, featureFlagRouter) shippingProfileRoutes(route) storeRoutes(route) swapRoutes(route) diff --git a/packages/medusa/src/api/routes/admin/price-lists/create-price-list.ts b/packages/medusa/src/api/routes/admin/price-lists/create-price-list.ts index 94525eda00..ae19a9bffd 100644 --- a/packages/medusa/src/api/routes/admin/price-lists/create-price-list.ts +++ b/packages/medusa/src/api/routes/admin/price-lists/create-price-list.ts @@ -1,21 +1,24 @@ +import { + IsArray, + IsBoolean, + IsEnum, + IsOptional, + IsString, + ValidateNested, +} from "class-validator" import { AdminPriceListPricesCreateReq, CreatePriceListInput, PriceListStatus, PriceListType, } from "../../../../types/price-list" -import { - IsArray, - IsEnum, - IsOptional, - IsString, - ValidateNested, -} from "class-validator" -import { EntityManager } from "typeorm" -import PriceListService from "../../../../services/price-list" -import { Request } from "express" import { Type } from "class-transformer" +import { Request } from "express" +import { EntityManager } from "typeorm" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" +import PriceListService from "../../../../services/price-list" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" /** * @oas [post] /price-lists @@ -98,6 +101,9 @@ import { Type } from "class-transformer" * id: * description: The ID of a customer group * type: string + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of price list" + * type: boolean * x-codeSamples: * - lang: JavaScript * label: JS Client @@ -215,4 +221,10 @@ export class AdminPostPriceListsPriceListReq { @Type(() => CustomerGroup) @ValidateNested({ each: true }) customer_groups?: CustomerGroup[] + + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean } diff --git a/packages/medusa/src/api/routes/admin/price-lists/index.ts b/packages/medusa/src/api/routes/admin/price-lists/index.ts index a77bdc0647..7d4437f939 100644 --- a/packages/medusa/src/api/routes/admin/price-lists/index.ts +++ b/packages/medusa/src/api/routes/admin/price-lists/index.ts @@ -14,12 +14,18 @@ import { defaultAdminProductRelations, } from "../products" import { AdminPostPriceListsPriceListReq } from "./create-price-list" +import { FlagRouter } from "../../../../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" const route = Router() -export default (app) => { +export default (app, featureFlagRouter: FlagRouter) => { app.use("/price-lists", route) + if (featureFlagRouter.isFeatureEnabled(TaxInclusivePricingFeatureFlag.key)) { + defaultAdminPriceListFields.push("includes_tax") + } + route.get("/:id", middlewares.wrap(require("./get-price-list").default)) route.get( diff --git a/packages/medusa/src/api/routes/admin/price-lists/update-price-list.ts b/packages/medusa/src/api/routes/admin/price-lists/update-price-list.ts index 9088be5edd..512143b121 100644 --- a/packages/medusa/src/api/routes/admin/price-lists/update-price-list.ts +++ b/packages/medusa/src/api/routes/admin/price-lists/update-price-list.ts @@ -1,22 +1,25 @@ -import { - AdminPriceListPricesUpdateReq, - PriceListStatus, - PriceListType, -} from "../../../../types/price-list" import { IsArray, + IsBoolean, IsEnum, IsOptional, IsString, ValidateNested, } from "class-validator" import { defaultAdminPriceListFields, defaultAdminPriceListRelations } from "." +import { + AdminPriceListPricesUpdateReq, + PriceListStatus, + PriceListType, +} from "../../../../types/price-list" -import { PriceList } from "../../../.." -import PriceListService from "../../../../services/price-list" import { Type } from "class-transformer" -import { validator } from "../../../../utils/validator" import { EntityManager } from "typeorm" +import { PriceList } from "../../../.." +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" +import PriceListService from "../../../../services/price-list" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import { validator } from "../../../../utils/validator" /** * @oas [post] /price-lists/{id} @@ -99,6 +102,9 @@ import { EntityManager } from "typeorm" * id: * description: The ID of a customer group * type: string + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of price list" + * type: boolean * x-codeSamples: * - lang: JavaScript * label: JS Client @@ -213,4 +219,10 @@ export class AdminPostPriceListsPriceListPriceListReq { @Type(() => CustomerGroup) @ValidateNested({ each: true }) customer_groups?: CustomerGroup[] + + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean } diff --git a/packages/medusa/src/api/routes/admin/regions/create-region.ts b/packages/medusa/src/api/routes/admin/regions/create-region.ts index e62737c172..40e15108c5 100644 --- a/packages/medusa/src/api/routes/admin/regions/create-region.ts +++ b/packages/medusa/src/api/routes/admin/regions/create-region.ts @@ -1,16 +1,18 @@ import { IsArray, + IsBoolean, IsNumber, IsObject, IsOptional, IsString, } from "class-validator" import { EntityManager } from "typeorm" - -import { validator } from "../../../../utils/validator" +import { defaultAdminRegionFields, defaultAdminRegionRelations } from "." import { Region } from "../../../.." +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" import RegionService from "../../../../services/region" -import { defaultAdminRegionRelations, defaultAdminRegionFields } from "." +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import { validator } from "../../../../utils/validator" /** * @oas [post] /regions @@ -61,6 +63,9 @@ import { defaultAdminRegionRelations, defaultAdminRegionFields } from "." * type: array * items: * type: string + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of region" + * type: boolean * x-codeSamples: * - lang: JavaScript * label: JS Client @@ -179,6 +184,12 @@ export class AdminPostRegionsReq { @IsString({ each: true }) countries: string[] + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean + @IsObject() @IsOptional() metadata?: Record diff --git a/packages/medusa/src/api/routes/admin/regions/index.ts b/packages/medusa/src/api/routes/admin/regions/index.ts index d8abb85286..ff6dd80ca0 100644 --- a/packages/medusa/src/api/routes/admin/regions/index.ts +++ b/packages/medusa/src/api/routes/admin/regions/index.ts @@ -3,12 +3,18 @@ import { Region } from "../../../.." import { DeleteResponse, PaginatedResponse } from "../../../../types/common" import middlewares from "../../../middlewares" import "reflect-metadata" +import { FlagRouter } from "../../../../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" const route = Router() -export default (app) => { +export default (app, featureFlagRouter: FlagRouter) => { app.use("/regions", route) + if (featureFlagRouter.isFeatureEnabled(TaxInclusivePricingFeatureFlag.key)) { + defaultAdminRegionFields.push("includes_tax") + } + route.get("/", middlewares.wrap(require("./list-regions").default)) route.get("/:region_id", middlewares.wrap(require("./get-region").default)) diff --git a/packages/medusa/src/api/routes/admin/regions/update-region.ts b/packages/medusa/src/api/routes/admin/regions/update-region.ts index 5a2189a678..a92c757580 100644 --- a/packages/medusa/src/api/routes/admin/regions/update-region.ts +++ b/packages/medusa/src/api/routes/admin/regions/update-region.ts @@ -8,9 +8,11 @@ import { } from "class-validator" import { EntityManager } from "typeorm" -import { validator } from "../../../../utils/validator" +import { defaultAdminRegionFields, defaultAdminRegionRelations } from "." +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" import RegionService from "../../../../services/region" -import { defaultAdminRegionRelations, defaultAdminRegionFields } from "." +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import { validator } from "../../../../utils/validator" /** * @oas [post] /regions/{id} @@ -49,6 +51,9 @@ import { defaultAdminRegionRelations, defaultAdminRegionFields } from "." * tax_rate: * description: "The tax rate to use on Orders in the Region." * type: number + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of region" + * type: boolean * payment_providers: * description: "A list of Payment Provider IDs that should be enabled for the Region" * type: array @@ -178,6 +183,12 @@ export class AdminPostRegionsRegionReq { @IsOptional() countries?: string[] + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean + @IsObject() @IsOptional() metadata?: Record diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/add-product-batch.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/add-product-batch.ts index 457f54efc5..5fe904e0cc 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/add-product-batch.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/add-product-batch.ts @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("POST /admin/sales-channels/:id/products/batch", () => { describe("add product to a sales channel", () => { @@ -21,7 +22,7 @@ describe("POST /admin/sales-channels/:id/products/batch", () => { payload: { product_ids: [{ id: "sales_channel_1_product_1" }], }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], } ) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/create-sales-channel.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/create-sales-channel.ts index 12c033d814..071d9d978f 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/create-sales-channel.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/create-sales-channel.ts @@ -2,6 +2,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("POST /admin/sales-channels", () => { describe("successfully get a sales channel", () => { @@ -18,7 +19,7 @@ describe("POST /admin/sales-channels", () => { name: "sales channel 1 name", description: "sales channel 1 description", }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], }) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-products-batch.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-products-batch.ts index 62501e4aea..6da1d5760d 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-products-batch.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-products-batch.ts @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("DELETE /admin/sales-channels/:id/products/batch", () => { describe("remove product from a sales channel", () => { @@ -19,7 +20,7 @@ describe("DELETE /admin/sales-channels/:id/products/batch", () => { payload: { product_ids: [{ id: IdMap.getId("sales_channel_1_product_1") }] }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], } ) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-sales-channel.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-sales-channel.ts index a1a6ffa86f..d941d01204 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-sales-channel.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/delete-sales-channel.ts @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("DELETE /admin/sales-channels/:id", () => { describe("successfully delete a sales channel", () => { @@ -16,7 +17,7 @@ describe("DELETE /admin/sales-channels/:id", () => { userId: IdMap.getId("admin_user"), }, }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], } ) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/get-sales-channel.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/get-sales-channel.ts index 04b4f1142e..0df7ebe301 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/get-sales-channel.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/get-sales-channel.ts @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("GET /admin/sales-channels/:id", () => { describe("successfully get a sales channel", () => { @@ -16,7 +17,7 @@ describe("GET /admin/sales-channels/:id", () => { userId: IdMap.getId("admin_user"), }, }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], } ) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/list-sales-channels.js b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/list-sales-channels.js index 4b5fdcb589..835e1309a0 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/list-sales-channels.js +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/list-sales-channels.js @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("GET /admin/sales-channels/", () => { describe("successfully list the sales channel", () => { @@ -16,7 +17,7 @@ describe("GET /admin/sales-channels/", () => { userId: IdMap.getId("admin_user"), }, }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], } ) }) diff --git a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/update-sales-channel.ts b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/update-sales-channel.ts index 2125f8e5ba..d723beb692 100644 --- a/packages/medusa/src/api/routes/admin/sales-channels/__tests__/update-sales-channel.ts +++ b/packages/medusa/src/api/routes/admin/sales-channels/__tests__/update-sales-channel.ts @@ -1,6 +1,7 @@ import { IdMap } from "medusa-test-utils" import { request } from "../../../../../helpers/test-request" import { SalesChannelServiceMock } from "../../../../../services/__mocks__/sales-channel" +import SalesChannelFeatureFlag from "../../../../../loaders/feature-flags/sales-channels"; describe("POST /admin/regions/:region_id/countries", () => { describe("successful creation", () => { @@ -18,7 +19,7 @@ describe("POST /admin/regions/:region_id/countries", () => { userId: IdMap.getId("admin_user"), }, }, - flags: ["sales_channels"], + flags: [SalesChannelFeatureFlag], }) }) diff --git a/packages/medusa/src/api/routes/admin/shipping-options/create-shipping-option.ts b/packages/medusa/src/api/routes/admin/shipping-options/create-shipping-option.ts index ee617d8800..0f92290114 100644 --- a/packages/medusa/src/api/routes/admin/shipping-options/create-shipping-option.ts +++ b/packages/medusa/src/api/routes/admin/shipping-options/create-shipping-option.ts @@ -1,5 +1,6 @@ import { IsArray, + IsBoolean, IsNumber, IsObject, IsOptional, @@ -9,8 +10,10 @@ import { import { defaultFields, defaultRelations } from "." import { Type } from "class-transformer" -import { validator } from "../../../../utils/validator" import { EntityManager } from "typeorm" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import { validator } from "../../../../utils/validator" /** * @oas [post] /shipping-options @@ -81,6 +84,9 @@ import { EntityManager } from "typeorm" * metadata: * description: An optional set of key-value pairs with additional information. * type: object + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of shipping option" + * type: boolean * x-codeSamples: * - lang: JavaScript * label: JS Client @@ -214,4 +220,10 @@ export class AdminPostShippingOptionsReq { @IsObject() @IsOptional() metadata?: object + + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean } diff --git a/packages/medusa/src/api/routes/admin/shipping-options/index.ts b/packages/medusa/src/api/routes/admin/shipping-options/index.ts index 58055c2e6d..4bb16ae444 100644 --- a/packages/medusa/src/api/routes/admin/shipping-options/index.ts +++ b/packages/medusa/src/api/routes/admin/shipping-options/index.ts @@ -2,12 +2,18 @@ import { Router } from "express" import { ShippingOption } from "../../../.." import { PaginatedResponse, DeleteResponse } from "../../../../types/common" import middlewares from "../../../middlewares" +import { FlagRouter } from "../../../../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" const route = Router() -export default (app) => { +export default (app, featureFlagRouter: FlagRouter) => { app.use("/shipping-options", route) + if (featureFlagRouter.isFeatureEnabled(TaxInclusivePricingFeatureFlag.key)) { + defaultFields.push("includes_tax") + } + route.get("/", middlewares.wrap(require("./list-shipping-options").default)) route.post("/", middlewares.wrap(require("./create-shipping-option").default)) diff --git a/packages/medusa/src/api/routes/admin/shipping-options/update-shipping-option.ts b/packages/medusa/src/api/routes/admin/shipping-options/update-shipping-option.ts index 2c7f235d11..763393c569 100644 --- a/packages/medusa/src/api/routes/admin/shipping-options/update-shipping-option.ts +++ b/packages/medusa/src/api/routes/admin/shipping-options/update-shipping-option.ts @@ -10,8 +10,10 @@ import { import { defaultFields, defaultRelations } from "." import { Type } from "class-transformer" -import { validator } from "../../../../utils/validator" import { EntityManager } from "typeorm" +import TaxInclusivePricingFeatureFlag from "../../../../loaders/feature-flags/tax-inclusive-pricing" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators" +import { validator } from "../../../../utils/validator" /** * @oas [post] /shipping-options/{id} @@ -60,6 +62,9 @@ import { EntityManager } from "typeorm" * amount: * description: The amount to compare with. * type: integer + * includes_tax: + * description: "[EXPERIMENTAL] Tax included in prices of shipping option" + * type: boolean * x-codeSamples: * - lang: JavaScript * label: JS Client @@ -174,4 +179,10 @@ export class AdminPostShippingOptionsOptionReq { @IsObject() @IsOptional() metadata?: object + + @FeatureFlagDecorators(TaxInclusivePricingFeatureFlag.key, [ + IsOptional(), + IsBoolean(), + ]) + includes_tax?: boolean } diff --git a/packages/medusa/src/helpers/test-request.js b/packages/medusa/src/helpers/test-request.js index f2ede4eb30..53ed229c6d 100644 --- a/packages/medusa/src/helpers/test-request.js +++ b/packages/medusa/src/helpers/test-request.js @@ -7,9 +7,10 @@ import supertest from "supertest" import querystring from "querystring" import apiLoader from "../loaders/api" import passportLoader from "../loaders/passport" -import featureFlagLoader from "../loaders/feature-flags" +import featureFlagLoader, { featureFlagRouter } from "../loaders/feature-flags" import servicesLoader from "../loaders/services" import strategiesLoader from "../loaders/strategies" +import logger from "../loaders/logger"; const adminSessionOpts = { cookieName: "session", @@ -36,8 +37,6 @@ const testApp = express() const container = createContainer() -const featureFlagRouter = featureFlagLoader(config) - container.register("featureFlagRouter", asValue(featureFlagRouter)) container.register("configModule", asValue(config)) container.register({ @@ -60,6 +59,7 @@ testApp.use((req, res, next) => { next() }) +featureFlagLoader(config) servicesLoader({ container, configModule: config }) strategiesLoader({ container, configModule: config }) passportLoader({ app: testApp, container, configModule: config }) @@ -77,7 +77,7 @@ export async function request(method, url, opts = {}) { const { payload, query, headers = {}, flags = [] } = opts flags.forEach((flag) => { - featureFlagRouter.setFlag(flag, true) + featureFlagRouter.setFlag(flag.key, true) }) const queryParams = query && querystring.stringify(query) @@ -148,6 +148,5 @@ export async function request(method, url, opts = {}) { // c[clientSessionOpts.cookieName] && // sessions.util.decode(clientSessionOpts, c[clientSessionOpts.cookieName]) // .content - return res } diff --git a/packages/medusa/src/interfaces/price-selection-strategy.ts b/packages/medusa/src/interfaces/price-selection-strategy.ts index 9abf2ce4ff..deec8b350b 100644 --- a/packages/medusa/src/interfaces/price-selection-strategy.ts +++ b/packages/medusa/src/interfaces/price-selection-strategy.ts @@ -1,6 +1,7 @@ import { EntityManager } from "typeorm" import { MoneyAmount } from ".." import { PriceListType } from "../types/price-list" +import { TaxServiceRate } from "../types/tax-service" export interface IPriceSelectionStrategy { /** @@ -55,6 +56,7 @@ export type PriceSelectionContext = { region_id?: string currency_code?: string include_discount_prices?: boolean + tax_rates?: TaxServiceRate[] } enum DefaultPriceType { @@ -67,7 +69,9 @@ export const PriceType = { ...DefaultPriceType, ...PriceListType } export type PriceSelectionResult = { originalPrice: number | null + originalPriceIncludesTax?: boolean | null calculatedPrice: number | null + calculatedPriceIncludesTax?: boolean | null calculatedPriceType?: PriceType prices: MoneyAmount[] // prices is an array of all possible price for the input customer and region prices } diff --git a/packages/medusa/src/loaders/feature-flags/index.ts b/packages/medusa/src/loaders/feature-flags/index.ts index 60874b43c4..1f56d26ab2 100644 --- a/packages/medusa/src/loaders/feature-flags/index.ts +++ b/packages/medusa/src/loaders/feature-flags/index.ts @@ -14,6 +14,8 @@ const isTruthy = (val: string | boolean | undefined): boolean => { return !!val } +export const featureFlagRouter = new FlagRouter({}) + export default ( configModule: { featureFlags?: Record } = {}, logger?: Logger, @@ -60,5 +62,9 @@ export default ( } } - return new FlagRouter(flagConfig) + for (const flag of Object.keys(flagConfig)) { + featureFlagRouter.setFlag(flag, flagConfig[flag]) + } + + return featureFlagRouter } diff --git a/packages/medusa/src/loaders/feature-flags/tax-inclusive-pricing.ts b/packages/medusa/src/loaders/feature-flags/tax-inclusive-pricing.ts new file mode 100644 index 0000000000..512f3f5cf1 --- /dev/null +++ b/packages/medusa/src/loaders/feature-flags/tax-inclusive-pricing.ts @@ -0,0 +1,10 @@ +import { FlagSettings } from "../../types/feature-flags" + +const TaxInclusivePricingFeatureFlag: FlagSettings = { + key: "tax_inclusive_pricing", + default_val: false, + env_key: "MEDUSA_FF_TAX_INCLUSIVE_PRICING", + description: "[WIP] Enable tax inclusive pricing", +} + +export default TaxInclusivePricingFeatureFlag diff --git a/packages/medusa/src/migrations/1659501357661-tax_inclusive_pricing.ts b/packages/medusa/src/migrations/1659501357661-tax_inclusive_pricing.ts new file mode 100644 index 0000000000..78867e719e --- /dev/null +++ b/packages/medusa/src/migrations/1659501357661-tax_inclusive_pricing.ts @@ -0,0 +1,46 @@ +import { MigrationInterface, QueryRunner } from "typeorm" +import TaxInclusivePricingFlag from "../loaders/feature-flags/tax-inclusive-pricing" + +export const featureFlag = TaxInclusivePricingFlag.key + +export class test1659501357661 implements MigrationInterface { + name = "test1659501357661" + + public async up(queryRunner: QueryRunner): Promise { + await queryRunner.query( + `ALTER TABLE "currency" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + await queryRunner.query( + `ALTER TABLE "region" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + await queryRunner.query( + `ALTER TABLE "shipping_option" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + await queryRunner.query( + `ALTER TABLE "price_list" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + await queryRunner.query( + `ALTER TABLE "shipping_method" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + await queryRunner.query( + `ALTER TABLE "line_item" ADD "includes_tax" boolean NOT NULL DEFAULT false` + ) + } + + public async down(queryRunner: QueryRunner): Promise { + await queryRunner.query( + `ALTER TABLE "line_item" DROP COLUMN "includes_tax"` + ) + await queryRunner.query( + `ALTER TABLE "shipping_method" DROP COLUMN "includes_tax"` + ) + await queryRunner.query( + `ALTER TABLE "price_list" DROP COLUMN "includes_tax"` + ) + await queryRunner.query( + `ALTER TABLE "shipping_option" DROP COLUMN "includes_tax"` + ) + await queryRunner.query(`ALTER TABLE "region" DROP COLUMN "includes_tax"`) + await queryRunner.query(`ALTER TABLE "currency" DROP COLUMN "includes_tax"`) + } +} diff --git a/packages/medusa/src/models/currency.ts b/packages/medusa/src/models/currency.ts index 66add3bd5d..2fdc51b7a3 100644 --- a/packages/medusa/src/models/currency.ts +++ b/packages/medusa/src/models/currency.ts @@ -1,4 +1,6 @@ import { Column, Entity, PrimaryColumn } from "typeorm" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" @Entity() export class Currency { @@ -13,6 +15,9 @@ export class Currency { @Column() name: string + + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax?: boolean } /** @@ -45,4 +50,7 @@ export class Currency { * description: "The written name of the currency" * type: string * example: US Dollar + * includes_tax: + * description: "[EXPERIMENTAL] Does the currency prices include tax" + * type: boolean */ diff --git a/packages/medusa/src/models/custom-shipping-option.ts b/packages/medusa/src/models/custom-shipping-option.ts index eba1bf6432..dd779c2338 100644 --- a/packages/medusa/src/models/custom-shipping-option.ts +++ b/packages/medusa/src/models/custom-shipping-option.ts @@ -92,4 +92,7 @@ export class CustomShippingOption extends SoftDeletableEntity { * type: object * description: An optional key-value map with additional details * example: {car: "white"} + * includes_tax: + * description: "[EXPERIMENTAL] Indicates if the custom shipping option price include tax" + * type: boolean */ diff --git a/packages/medusa/src/models/line-item.ts b/packages/medusa/src/models/line-item.ts index 4d94ee7471..f1ab06a229 100644 --- a/packages/medusa/src/models/line-item.ts +++ b/packages/medusa/src/models/line-item.ts @@ -19,6 +19,8 @@ import { Order } from "./order" import { ProductVariant } from "./product-variant" import { Swap } from "./swap" import { generateEntityId } from "../utils/generate-entity-id" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" @Check(`"fulfilled_quantity" <= "quantity"`) @Check(`"shipped_quantity" <= "fulfilled_quantity"`) @@ -116,6 +118,9 @@ export class LineItem extends BaseEntity { @DbAwareColumn({ type: "jsonb", nullable: true }) metadata: Record + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax: boolean + refundable?: number | null subtotal?: number | null tax_total?: number | null @@ -275,6 +280,9 @@ export class LineItem extends BaseEntity { * type: integer * description: The total of the gift card of the line item * example: 0 + * includes_tax: + * description: "[EXPERIMENTAL] Indicates if the line item unit_price include tax" + * type: boolean * created_at: * type: string * description: "The date with timezone at which the resource was created." diff --git a/packages/medusa/src/models/money-amount.ts b/packages/medusa/src/models/money-amount.ts index 646ec92328..b0884a9669 100644 --- a/packages/medusa/src/models/money-amount.ts +++ b/packages/medusa/src/models/money-amount.ts @@ -21,7 +21,7 @@ export class MoneyAmount extends SoftDeletableEntity { @ManyToOne(() => Currency) @JoinColumn({ name: "currency_code", referencedColumnName: "code" }) - currency: Currency + currency?: Currency @Column({ type: "int" }) amount: number @@ -58,7 +58,7 @@ export class MoneyAmount extends SoftDeletableEntity { @ManyToOne(() => Region) @JoinColumn({ name: "region_id" }) - region: Region + region?: Region @BeforeInsert() private beforeInsert(): undefined | void { diff --git a/packages/medusa/src/models/price-list.ts b/packages/medusa/src/models/price-list.ts index f3dd85f402..85754723c6 100644 --- a/packages/medusa/src/models/price-list.ts +++ b/packages/medusa/src/models/price-list.ts @@ -13,6 +13,8 @@ import { CustomerGroup } from "./customer-group" import { MoneyAmount } from "./money-amount" import { SoftDeletableEntity } from "../interfaces/models/soft-deletable-entity" import { generateEntityId } from "../utils/generate-entity-id" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" @Entity() export class PriceList extends SoftDeletableEntity { @@ -58,6 +60,9 @@ export class PriceList extends SoftDeletableEntity { }) prices: MoneyAmount[] + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax: boolean + @BeforeInsert() private beforeInsert(): undefined | void { this.id = generateEntityId(this.id, "pl") @@ -118,6 +123,10 @@ export class PriceList extends SoftDeletableEntity { * type: array * items: * $ref: "#/components/schemas/money_amount" + * $ref: "#/components/schemas/customer_group" + * includes_tax: + * description: "[EXPERIMENTAL] Does the price list prices include tax" + * type: boolean * created_at: * type: string * description: "The date with timezone at which the resource was created." diff --git a/packages/medusa/src/models/region.ts b/packages/medusa/src/models/region.ts index af65466600..12aa581714 100644 --- a/packages/medusa/src/models/region.ts +++ b/packages/medusa/src/models/region.ts @@ -18,6 +18,8 @@ import { SoftDeletableEntity } from "../interfaces/models/soft-deletable-entity" import { TaxProvider } from "./tax-provider" import { TaxRate } from "./tax-rate" import { generateEntityId } from "../utils/generate-entity-id" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" @Entity() export class Region extends SoftDeletableEntity { @@ -93,6 +95,9 @@ export class Region extends SoftDeletableEntity { @DbAwareColumn({ type: "jsonb", nullable: true }) metadata: Record + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax: boolean + @BeforeInsert() private beforeInsert(): void { this.id = generateEntityId(this.id, "reg") @@ -170,6 +175,9 @@ export class Region extends SoftDeletableEntity { * type: array * items: * $ref: "#/components/schemas/fulfillment_provider" + * includes_tax: + * description: "[EXPERIMENTAL] Does the prices for the region include tax" + * type: boolean * created_at: * type: string * description: "The date with timezone at which the resource was created." diff --git a/packages/medusa/src/models/shipping-method.ts b/packages/medusa/src/models/shipping-method.ts index f395d03a81..173a340cbc 100644 --- a/packages/medusa/src/models/shipping-method.ts +++ b/packages/medusa/src/models/shipping-method.ts @@ -20,6 +20,8 @@ import { ShippingMethodTaxLine } from "./shipping-method-tax-line" import { ShippingOption } from "./shipping-option" import { Swap } from "./swap" import { generateEntityId } from "../utils/generate-entity-id" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" @Check( `"claim_order_id" IS NOT NULL OR "order_id" IS NOT NULL OR "cart_id" IS NOT NULL OR "swap_id" IS NOT NULL OR "return_id" IS NOT NULL` @@ -89,6 +91,9 @@ export class ShippingMethod { @DbAwareColumn({ type: "jsonb" }) data: Record + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax: boolean + @BeforeInsert() private beforeInsert(): void { this.id = generateEntityId(this.id, "sm") @@ -163,4 +168,7 @@ export class ShippingMethod { * description: "Additional data that the Fulfillment Provider needs to fulfill the shipment. This is used in combination with the Shipping Options data, and may contain information such as a drop point id." * type: object * example: {} + * includes_tax: + * description: "[EXPERIMENTAL] Indicates if the shipping method price include tax" + * type: boolean */ diff --git a/packages/medusa/src/models/shipping-option.ts b/packages/medusa/src/models/shipping-option.ts index dd5f8c563e..4003b7c07b 100644 --- a/packages/medusa/src/models/shipping-option.ts +++ b/packages/medusa/src/models/shipping-option.ts @@ -16,6 +16,8 @@ import { ShippingOptionRequirement } from "./shipping-option-requirement" import { ShippingProfile } from "./shipping-profile" import { SoftDeletableEntity } from "../interfaces/models/soft-deletable-entity" import { generateEntityId } from "../utils/generate-entity-id" +import { FeatureFlagColumn } from "../utils/feature-flag-decorators" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" export enum ShippingOptionPriceType { FLAT_RATE = "flat_rate", @@ -75,6 +77,9 @@ export class ShippingOption extends SoftDeletableEntity { @DbAwareColumn({ type: "jsonb", nullable: true }) metadata: Record + @FeatureFlagColumn(TaxInclusivePricingFeatureFlag.key, { default: false }) + includes_tax: boolean + @BeforeInsert() private beforeInsert(): void { this.id = generateEntityId(this.id, "so") @@ -146,6 +151,9 @@ export class ShippingOption extends SoftDeletableEntity { * description: "The data needed for the Fulfillment Provider to identify the Shipping Option." * type: object * example: {} + * includes_tax: + * description: "[EXPERIMENTAL] Does the shipping option price include tax" + * type: boolean * created_at: * type: string * description: "The date with timezone at which the resource was created." diff --git a/packages/medusa/src/repositories/currency.ts b/packages/medusa/src/repositories/currency.ts index bf30c3e00b..7b7f4ab71f 100644 --- a/packages/medusa/src/repositories/currency.ts +++ b/packages/medusa/src/repositories/currency.ts @@ -1,5 +1,5 @@ import { EntityRepository, Repository } from "typeorm" -import { Currency } from "../models/currency" +import { Currency } from "../models" @EntityRepository(Currency) export class CurrencyRepository extends Repository { } diff --git a/packages/medusa/src/repositories/money-amount.ts b/packages/medusa/src/repositories/money-amount.ts index 0a841e10e5..87a67af446 100644 --- a/packages/medusa/src/repositories/money-amount.ts +++ b/packages/medusa/src/repositories/money-amount.ts @@ -1,5 +1,4 @@ import partition from "lodash/partition" -import { MedusaError } from "medusa-core-utils" import { Brackets, EntityRepository, @@ -7,6 +6,7 @@ import { IsNull, Not, Repository, + WhereExpressionBuilder, } from "typeorm" import { MoneyAmount } from "../models/money-amount" import { @@ -126,7 +126,7 @@ export class MoneyAmountRepository extends Repository { .leftJoinAndSelect("ma.price_list", "price_list") .where("ma.variant_id = :variant_id", { variant_id }) - const getAndWhere = (subQb) => { + const getAndWhere = (subQb): WhereExpressionBuilder => { const andWhere = subQb.where("ma.price_list_id = :price_list_id", { price_list_id, }) @@ -146,7 +146,8 @@ export class MoneyAmountRepository extends Repository { region_id?: string, currency_code?: string, customer_id?: string, - include_discount_prices?: boolean + include_discount_prices?: boolean, + include_tax_inclusive_pricing = false ): Promise<[MoneyAmount[], number]> { const date = new Date() @@ -154,12 +155,9 @@ export class MoneyAmountRepository extends Repository { .leftJoinAndSelect("ma.price_list", "price_list") .where({ variant_id: variant_id }) .andWhere("(ma.price_list_id is null or price_list.status = 'active')") - .andWhere( - "(price_list.ends_at is null OR price_list.ends_at > :date)", - { - date: date.toUTCString(), - } - ) + .andWhere("(price_list.ends_at is null OR price_list.ends_at > :date)", { + date: date.toUTCString(), + }) .andWhere( "(price_list.starts_at is null OR price_list.starts_at < :date)", { @@ -167,6 +165,11 @@ export class MoneyAmountRepository extends Repository { } ) + if (include_tax_inclusive_pricing) { + qb.leftJoin("ma.currency", "currency") + .leftJoin("ma.region", "region") + .addSelect(["currency.includes_tax", "region.includes_tax"]) + } if (region_id || currency_code) { qb.andWhere( new Brackets((qb) => @@ -181,14 +184,21 @@ export class MoneyAmountRepository extends Repository { if (customer_id) { qb.leftJoin("price_list.customer_groups", "cgroup") - .leftJoin("customer_group_customers", "cgc", "cgc.customer_group_id = cgroup.id") - .andWhere("(cgc.customer_group_id is null OR cgc.customer_id = :customer_id)", { - customer_id, - }) + .leftJoin( + "customer_group_customers", + "cgc", + "cgc.customer_group_id = cgroup.id" + ) + .andWhere( + "(cgc.customer_group_id is null OR cgc.customer_id = :customer_id)", + { + customer_id, + } + ) } else { - qb - .leftJoin("price_list.customer_groups", "cgroup") - .andWhere("cgroup.id is null") + qb.leftJoin("price_list.customer_groups", "cgroup").andWhere( + "cgroup.id is null" + ) } return await qb.getManyAndCount() } diff --git a/packages/medusa/src/services/__mocks__/cart.js b/packages/medusa/src/services/__mocks__/cart.js index 987aff7abf..afd2f7537a 100644 --- a/packages/medusa/src/services/__mocks__/cart.js +++ b/packages/medusa/src/services/__mocks__/cart.js @@ -28,6 +28,22 @@ export const carts = { total: 1000, region_id: IdMap.getId("testRegion"), }, + testCartTaxInclusive: { + id: IdMap.getId("test-cart"), + items: [], + payment: { + data: "some-data", + }, + payment_session: { + status: "authorized", + }, + total: 1000, + region_id: IdMap.getId("testRegion"), + shipping_options: [{ + id: IdMap.getId("tax-inclusive-option"), + includes_tax: true + }], + }, testSwapCart: { id: IdMap.getId("test-swap"), items: [], @@ -263,6 +279,9 @@ export const CartServiceMock = { if (cartId === IdMap.getId("test-cart2")) { return Promise.resolve(carts.testCart) } + if (cartId === IdMap.getId("tax-inclusive-option")) { + return Promise.resolve(carts.testCartTaxInclusive) + } throw new MedusaError(MedusaError.Types.NOT_FOUND, "cart not found") }), addLineItem: jest.fn().mockImplementation((cartId, lineItem) => { diff --git a/packages/medusa/src/services/__mocks__/currency.js b/packages/medusa/src/services/__mocks__/currency.js new file mode 100644 index 0000000000..a4b6b94f07 --- /dev/null +++ b/packages/medusa/src/services/__mocks__/currency.js @@ -0,0 +1,40 @@ +import { IdMap } from "medusa-test-utils" + +export const currency = { + code: IdMap.getId("currency-1"), + symbol: "SYM", + symbol_native: "SYM", + name: "Symbol", +} + +export const CurrencyServiceMock = { + withTransaction: function() { + return this + }, + retrieve: jest.fn().mockImplementation((code) => { + return Promise.resolve({ + ...currency, + code: code, + }) + }), + + update: jest.fn().mockImplementation((code, data) => { + return Promise.resolve({ + ...currency, + ...data, + }) + }), + + listAndCount: jest.fn().mockImplementation(() => { + return Promise.resolve([ + [currency], + 1, + ]) + }) +} + +const mock = jest.fn().mockImplementation(() => { + return CurrencyServiceMock +}) + +export default mock diff --git a/packages/medusa/src/services/__tests__/currency.ts b/packages/medusa/src/services/__tests__/currency.ts new file mode 100644 index 0000000000..270df58308 --- /dev/null +++ b/packages/medusa/src/services/__tests__/currency.ts @@ -0,0 +1,60 @@ +import { IdMap, MockManager, MockRepository } from "medusa-test-utils" +import { EventBusService } from "../index" +import { Currency } from "../../models" +import CurrencyService from "../currency" +import { FlagRouter } from "../../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing" + +const currencyCode = IdMap.getId("currency-1") +const eventBusServiceMock = { + emit: jest.fn(), + withTransaction: function() { + return this + }, +} as unknown as EventBusService +const currencyRepositoryMock = MockRepository({ + findOne: jest.fn().mockImplementation(() => { + return { + code: currencyCode + } + }), + save: jest.fn().mockImplementation((data) => { + return Object.assign(new Currency(), data) + }) +}) + + +describe('CurrencyService', () => { + const currencyService = new CurrencyService({ + manager: MockManager, + currencyRepository: currencyRepositoryMock, + eventBusService: eventBusServiceMock, + featureFlagRouter: new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true + }), + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + it("should retrieve the currency by calling the repository findOne method", async () => { + await currencyService.retrieveByCode(currencyCode) + expect(currencyRepositoryMock.findOne).toHaveBeenCalledWith({ + where: { code: currencyCode.toLowerCase() }, + }) + }) + + it("should update the currency by calling the save method", async () => { + await currencyService.update(currencyCode, { + includes_tax: true, + }) + expect(currencyRepositoryMock.findOne).toHaveBeenCalledWith({ + where: { code: currencyCode.toLowerCase() }, + }) + expect(currencyRepositoryMock.save).toHaveBeenCalledWith({ + code: currencyCode, + includes_tax: true, + }) + }) +}) diff --git a/packages/medusa/src/services/__tests__/discount.js b/packages/medusa/src/services/__tests__/discount.js index c3ca2499b2..c35328a81f 100644 --- a/packages/medusa/src/services/__tests__/discount.js +++ b/packages/medusa/src/services/__tests__/discount.js @@ -1,5 +1,8 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" import DiscountService from "../discount" +import { FlagRouter } from "../../utils/flag-router" + +const featureFlagRouter = new FlagRouter({}) describe("DiscountService", () => { describe("create", () => { @@ -23,6 +26,7 @@ describe("DiscountService", () => { discountRepository, discountRuleRepository, regionService, + featureFlagRouter, }) beforeEach(() => { @@ -160,6 +164,7 @@ describe("DiscountService", () => { const discountService = new DiscountService({ manager: MockManager, discountRepository, + featureFlagRouter, }) beforeEach(() => { @@ -203,6 +208,7 @@ describe("DiscountService", () => { const discountService = new DiscountService({ manager: MockManager, discountRepository, + featureFlagRouter, }) beforeEach(() => { @@ -248,6 +254,7 @@ describe("DiscountService", () => { discountRepository, discountRuleRepository, regionService, + featureFlagRouter, }) beforeEach(() => { @@ -345,6 +352,7 @@ describe("DiscountService", () => { discountRepository, discountRuleRepository, regionService, + featureFlagRouter, }) beforeEach(() => { @@ -418,6 +426,7 @@ describe("DiscountService", () => { discountRepository, discountRuleRepository, regionService, + featureFlagRouter, }) beforeEach(() => { @@ -466,6 +475,7 @@ describe("DiscountService", () => { discountRepository, discountRuleRepository, regionService, + featureFlagRouter, }) beforeEach(() => { @@ -509,6 +519,7 @@ describe("DiscountService", () => { const discountService = new DiscountService({ manager: MockManager, discountRepository, + featureFlagRouter, }) beforeEach(() => { @@ -588,6 +599,7 @@ describe("DiscountService", () => { manager: MockManager, discountRepository, totalsService, + featureFlagRouter, }) beforeEach(() => { @@ -763,6 +775,7 @@ describe("DiscountService", () => { beforeEach(async () => { discountService = new DiscountService({ manager: MockManager, + featureFlagRouter, }) const hasReachedLimitMock = jest.fn().mockImplementation(() => false) const isDisabledMock = jest.fn().mockImplementation(() => false) @@ -892,7 +905,9 @@ describe("DiscountService", () => { }) describe("hasReachedLimit", () => { - const discountService = new DiscountService({}) + const discountService = new DiscountService({ + featureFlagRouter, + }) it("returns true if discount limit is reached", () => { const discount = { @@ -936,7 +951,9 @@ describe("DiscountService", () => { }) describe("isDisabled", () => { - const discountService = new DiscountService({}) + const discountService = new DiscountService({ + featureFlagRouter, + }) it("returns false if discount not disabled", async () => { const discount = { @@ -972,7 +989,9 @@ describe("DiscountService", () => { }) describe("hasNotStarted", () => { - const discountService = new DiscountService({}) + const discountService = new DiscountService({ + featureFlagRouter, + }) it("returns true if discount has a future starts_at date", async () => { const discount = { @@ -1008,7 +1027,9 @@ describe("DiscountService", () => { }) describe("hasExpired", () => { - const discountService = new DiscountService({}) + const discountService = new DiscountService({ + featureFlagRouter, + }) it("returns false if discount has a future ends_at date", async () => { const discount = { @@ -1068,6 +1089,7 @@ describe("DiscountService", () => { const discountService = new DiscountService({ manager: MockManager, + featureFlagRouter, }) discountService.retrieve = retrieveMock @@ -1181,6 +1203,7 @@ describe("DiscountService", () => { manager: MockManager, discountConditionRepository, customerService, + featureFlagRouter, }) it("returns false on undefined customer id", async () => { diff --git a/packages/medusa/src/services/__tests__/line-item.js b/packages/medusa/src/services/__tests__/line-item.js index e9cd5d6b9a..8d2f731e73 100644 --- a/packages/medusa/src/services/__tests__/line-item.js +++ b/packages/medusa/src/services/__tests__/line-item.js @@ -1,284 +1,520 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" +import { FlagRouter } from "../../utils/flag-router" import LineItemService from "../line-item" - -describe("LineItemService", () => { - describe("create", () => { - const lineItemRepository = MockRepository({ - create: (data) => data, - }) - - const cartRepository = MockRepository({ - findOne: () => - Promise.resolve({ - region_id: IdMap.getId("test-region"), - }), - }) - - const regionService = { - withTransaction: function () { - return this - }, - retrieve: () => { - return { - id: IdMap.getId("test-region"), - } - }, - } - - const productVariantService = { - withTransaction: function () { - return this - }, - retrieve: (query) => { - if (query === IdMap.getId("test-giftcard")) { - return { - id: IdMap.getId("test-giftcard"), - title: "Test variant", - product: { - title: "Test product", - thumbnail: "", - is_giftcard: true, - discountable: false, - }, - } - } - return { - id: IdMap.getId("test-variant"), - title: "Test variant", - product: { - title: "Test product", - thumbnail: "", - }, - } - }, - getRegionPrice: () => 100, - } - - const pricingService = { - withTransaction: function () { - return this - }, - getProductVariantPricingById: () => { - return { - calculated_price: 100, - } - }, - getProductVariantPricing: () => { - return { - calculated_price: 100, - } - }, - } - - const lineItemService = new LineItemService({ - manager: MockManager, - pricingService, - lineItemRepository, - productVariantService, - regionService, - cartRepository, - }) - - beforeEach(async () => { - jest.clearAllMocks() - }) - - it("successfully create a line item", async () => { - await lineItemService.create({ - variant_id: IdMap.getId("test-variant"), - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 100, - quantity: 1, - }) - - expect(lineItemRepository.create).toHaveBeenCalledTimes(1) - expect(lineItemRepository.create).toHaveBeenCalledWith({ - variant_id: IdMap.getId("test-variant"), - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 100, - quantity: 1, - }) - }) - - it("successfully create a line item with price and quantity", async () => { - await lineItemService.create({ - variant_id: IdMap.getId("test-variant"), - cart_id: IdMap.getId("test-cart"), - unit_price: 50, - quantity: 2, - }) - - expect(lineItemRepository.create).toHaveBeenCalledTimes(1) - expect(lineItemRepository.create).toHaveBeenCalledWith({ - variant_id: IdMap.getId("test-variant"), - cart_id: IdMap.getId("test-cart"), - unit_price: 50, - quantity: 2, - }) - }) - - it("successfully create a line item giftcard", async () => { - const line = await lineItemService.generate( - IdMap.getId("test-giftcard"), - IdMap.getId("test-region"), - 1 - ) - - await lineItemService.create({ - ...line, - cart_id: IdMap.getId("test-cart"), - }) - - expect(lineItemRepository.create).toHaveBeenCalledTimes(2) - expect(lineItemRepository.create).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ - allow_discounts: false, - variant_id: IdMap.getId("test-giftcard"), - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 100, - quantity: 1, - is_giftcard: true, - should_merge: true, - metadata: {}, +;[true, false].forEach((isTaxInclusiveEnabled) => { + describe(`tax inclusive flag set to: ${isTaxInclusiveEnabled}`, () => { + describe("LineItemService", () => { + describe("create", () => { + const lineItemRepository = MockRepository({ + create: (data) => data, }) - ) - }) - }) - describe("update", () => { - const lineItemRepository = MockRepository({ - findOne: () => - Promise.resolve({ - id: IdMap.getId("test-line-item"), - variant_id: IdMap.getId("test-variant"), - variant: { - id: IdMap.getId("test-variant"), - title: "Test variant", + const cartRepository = MockRepository({ + findOne: () => + Promise.resolve({ + region_id: IdMap.getId("test-region"), + }), + }) + + const regionService = { + withTransaction: function () { + return this }, - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 50, - quantity: 1, - }), - }) - - const lineItemService = new LineItemService({ - manager: MockManager, - lineItemRepository, - }) - - beforeEach(async () => { - jest.clearAllMocks() - }) - - it("successfully updates a line item with quantity", async () => { - await lineItemService.update(IdMap.getId("test-line-item"), { - quantity: 2, - has_shipping: true, - }) - - expect(lineItemRepository.save).toHaveBeenCalledTimes(1) - expect(lineItemRepository.save).toHaveBeenCalledWith({ - id: IdMap.getId("test-line-item"), - variant_id: IdMap.getId("test-variant"), - variant: { - id: IdMap.getId("test-variant"), - title: "Test variant", - }, - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 50, - quantity: 2, - has_shipping: true, - }) - }) - - it("successfully updates a line item with metadata", async () => { - await lineItemService.update(IdMap.getId("test-line-item"), { - metadata: { - testKey: "testValue", - }, - }) - - expect(lineItemRepository.save).toHaveBeenCalledTimes(1) - expect(lineItemRepository.save).toHaveBeenCalledWith({ - id: IdMap.getId("test-line-item"), - variant_id: IdMap.getId("test-variant"), - variant: { - id: IdMap.getId("test-variant"), - title: "Test variant", - }, - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 50, - quantity: 1, - metadata: { - testKey: "testValue", - }, - }) - }) - }) - describe("delete", () => { - const lineItemRepository = MockRepository({ - findOne: () => - Promise.resolve({ - id: IdMap.getId("test-line-item"), - variant_id: IdMap.getId("test-variant"), - variant: { - id: IdMap.getId("test-variant"), - title: "Test variant", + retrieve: () => { + return { + id: IdMap.getId("test-region"), + } }, - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 50, - quantity: 1, - }), - }) + } - const lineItemService = new LineItemService({ - manager: MockManager, - lineItemRepository, - }) + const productVariantService = { + withTransaction: function () { + return this + }, + retrieve: (query) => { + if (query === IdMap.getId("test-giftcard")) { + return { + id: IdMap.getId("test-giftcard"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + is_giftcard: true, + discountable: false, + }, + } + } + return { + id: IdMap.getId("test-variant"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + }, + } + }, + getRegionPrice: () => 100, + } - beforeEach(async () => { - jest.clearAllMocks() - }) + const pricingService = { + withTransaction: function () { + return this + }, + getProductVariantPricingById: () => { + return { + calculated_price: 100, + } + }, + getProductVariantPricing: () => { + return { + calculated_price: 100, + } + }, + } - it("successfully deletes", async () => { - await lineItemService.delete(IdMap.getId("test-line-item")) + const featureFlagRouter = new FlagRouter({ + tax_inclusive_pricing: isTaxInclusiveEnabled, + }) - expect(lineItemRepository.remove).toHaveBeenCalledTimes(1) - expect(lineItemRepository.remove).toHaveBeenCalledWith({ - id: IdMap.getId("test-line-item"), - variant_id: IdMap.getId("test-variant"), - variant: { - id: IdMap.getId("test-variant"), - title: "Test variant", - }, - cart_id: IdMap.getId("test-cart"), - title: "Test product", - description: "Test variant", - thumbnail: "", - unit_price: 50, - quantity: 1, + const lineItemService = new LineItemService({ + manager: MockManager, + pricingService, + lineItemRepository, + productVariantService, + regionService, + cartRepository, + featureFlagRouter, + }) + + beforeEach(async () => { + jest.clearAllMocks() + }) + + it("successfully create a line item", async () => { + await lineItemService.create({ + variant_id: IdMap.getId("test-variant"), + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 100, + quantity: 1, + }) + + expect(lineItemRepository.create).toHaveBeenCalledTimes(1) + expect(lineItemRepository.create).toHaveBeenCalledWith({ + variant_id: IdMap.getId("test-variant"), + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 100, + quantity: 1, + }) + }) + + it("successfully create a line item with price and quantity", async () => { + await lineItemService.create({ + variant_id: IdMap.getId("test-variant"), + cart_id: IdMap.getId("test-cart"), + unit_price: 50, + quantity: 2, + }) + + expect(lineItemRepository.create).toHaveBeenCalledTimes(1) + expect(lineItemRepository.create).toHaveBeenCalledWith({ + variant_id: IdMap.getId("test-variant"), + cart_id: IdMap.getId("test-cart"), + unit_price: 50, + quantity: 2, + }) + }) + + it("successfully create a line item giftcard", async () => { + const line = await lineItemService.generate( + IdMap.getId("test-giftcard"), + IdMap.getId("test-region"), + 1 + ) + + await lineItemService.create({ + ...line, + cart_id: IdMap.getId("test-cart"), + }) + + expect(lineItemRepository.create).toHaveBeenCalledTimes(2) + expect(lineItemRepository.create).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + allow_discounts: false, + variant_id: IdMap.getId("test-giftcard"), + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 100, + quantity: 1, + is_giftcard: true, + should_merge: true, + metadata: {}, + }) + ) + }) + }) + + describe("update", () => { + const lineItemRepository = MockRepository({ + findOne: () => + Promise.resolve({ + id: IdMap.getId("test-line-item"), + variant_id: IdMap.getId("test-variant"), + variant: { + id: IdMap.getId("test-variant"), + title: "Test variant", + }, + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 50, + quantity: 1, + }), + }) + + const lineItemService = new LineItemService({ + manager: MockManager, + lineItemRepository, + }) + + beforeEach(async () => { + jest.clearAllMocks() + }) + + it("successfully updates a line item with quantity", async () => { + await lineItemService.update(IdMap.getId("test-line-item"), { + quantity: 2, + has_shipping: true, + }) + + expect(lineItemRepository.save).toHaveBeenCalledTimes(1) + expect(lineItemRepository.save).toHaveBeenCalledWith({ + id: IdMap.getId("test-line-item"), + variant_id: IdMap.getId("test-variant"), + variant: { + id: IdMap.getId("test-variant"), + title: "Test variant", + }, + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 50, + quantity: 2, + has_shipping: true, + }) + }) + + it("successfully updates a line item with metadata", async () => { + await lineItemService.update(IdMap.getId("test-line-item"), { + metadata: { + testKey: "testValue", + }, + }) + + expect(lineItemRepository.save).toHaveBeenCalledTimes(1) + expect(lineItemRepository.save).toHaveBeenCalledWith({ + id: IdMap.getId("test-line-item"), + variant_id: IdMap.getId("test-variant"), + variant: { + id: IdMap.getId("test-variant"), + title: "Test variant", + }, + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 50, + quantity: 1, + metadata: { + testKey: "testValue", + }, + }) + }) + }) + describe("delete", () => { + const lineItemRepository = MockRepository({ + findOne: () => + Promise.resolve({ + id: IdMap.getId("test-line-item"), + variant_id: IdMap.getId("test-variant"), + variant: { + id: IdMap.getId("test-variant"), + title: "Test variant", + }, + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 50, + quantity: 1, + }), + }) + + const lineItemService = new LineItemService({ + manager: MockManager, + lineItemRepository, + }) + + beforeEach(async () => { + jest.clearAllMocks() + }) + + it("successfully deletes", async () => { + await lineItemService.delete(IdMap.getId("test-line-item")) + + expect(lineItemRepository.remove).toHaveBeenCalledTimes(1) + expect(lineItemRepository.remove).toHaveBeenCalledWith({ + id: IdMap.getId("test-line-item"), + variant_id: IdMap.getId("test-variant"), + variant: { + id: IdMap.getId("test-variant"), + title: "Test variant", + }, + cart_id: IdMap.getId("test-cart"), + title: "Test product", + description: "Test variant", + thumbnail: "", + unit_price: 50, + quantity: 1, + }) + }) + }) + }) + }) +}) + +describe("LineItemService", () => { + describe(`tax inclusive pricing tests `, () => { + describe("generate", () => { + const lineItemRepository = MockRepository({ + create: (data) => data, + }) + + const cartRepository = MockRepository({ + findOne: () => + Promise.resolve({ + region_id: IdMap.getId("test-region"), + }), + }) + + const regionService = { + withTransaction: function () { + return this + }, + retrieve: () => { + return { + id: IdMap.getId("test-region"), + } + }, + } + + const productVariantService = { + withTransaction: function () { + return this + }, + retrieve: (query) => { + if (query === IdMap.getId("test-giftcard")) { + return { + id: IdMap.getId("test-giftcard"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + is_giftcard: true, + discountable: false, + }, + } + } + return { + id: IdMap.getId("test-variant"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + }, + } + }, + getRegionPrice: () => 100, + } + + const pricingService = { + withTransaction: function () { + return this + }, + getProductVariantPricingById: () => { + return { + calculated_price: 100, + calculated_price_includes_tax: true, + } + }, + getProductVariantPricing: () => { + return { + calculated_price: 100, + calculated_price_includes_tax: true, + } + }, + } + + const featureFlagRouter = new FlagRouter({ + tax_inclusive_pricing: true, + }) + + const lineItemService = new LineItemService({ + manager: MockManager, + pricingService, + lineItemRepository, + productVariantService, + regionService, + cartRepository, + featureFlagRouter, + }) + + beforeEach(async () => { + jest.clearAllMocks() + }) + + it("successfully create a line item with tax inclusive set to true", async () => { + await lineItemService.generate( + IdMap.getId("test-variant"), + IdMap.getId("test-region"), + 1 + ) + + expect(lineItemRepository.create).toHaveBeenCalledTimes(1) + expect(lineItemRepository.create).toHaveBeenCalledWith({ + unit_price: 100, + title: "Test product", + description: "Test variant", + thumbnail: "", + variant_id: IdMap.getId("test-variant"), + quantity: 1, + allow_discounts: undefined, + is_giftcard: undefined, + metadata: {}, + should_merge: true, + includes_tax: true, + }) + }) + }) + describe("generate", () => { + const lineItemRepository = MockRepository({ + create: (data) => data, + }) + + const cartRepository = MockRepository({ + findOne: () => + Promise.resolve({ + region_id: IdMap.getId("test-region"), + }), + }) + + const regionService = { + withTransaction: function () { + return this + }, + retrieve: () => { + return { + id: IdMap.getId("test-region"), + } + }, + } + + const productVariantService = { + withTransaction: function () { + return this + }, + retrieve: (query) => { + if (query === IdMap.getId("test-giftcard")) { + return { + id: IdMap.getId("test-giftcard"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + is_giftcard: true, + discountable: false, + }, + } + } + return { + id: IdMap.getId("test-variant"), + title: "Test variant", + product: { + title: "Test product", + thumbnail: "", + }, + } + }, + getRegionPrice: () => 100, + } + + const pricingService = { + withTransaction: function () { + return this + }, + getProductVariantPricingById: () => { + return { + calculated_price: 100, + calculated_price_includes_tax: false, + } + }, + getProductVariantPricing: () => { + return { + calculated_price: 100, + calculated_price_includes_tax: false, + } + }, + } + + const featureFlagRouter = new FlagRouter({ + tax_inclusive_pricing: true, + }) + + const lineItemService = new LineItemService({ + manager: MockManager, + pricingService, + lineItemRepository, + productVariantService, + regionService, + cartRepository, + featureFlagRouter, + }) + + beforeEach(async () => { + jest.clearAllMocks() + }) + + it("successfully create a line item with tax inclusive set to false", async () => { + await lineItemService.generate( + IdMap.getId("test-variant"), + IdMap.getId("test-region"), + 1 + ) + + expect(lineItemRepository.create).toHaveBeenCalledTimes(1) + expect(lineItemRepository.create).toHaveBeenCalledWith({ + unit_price: 100, + title: "Test product", + description: "Test variant", + thumbnail: "", + variant_id: IdMap.getId("test-variant"), + quantity: 1, + allow_discounts: undefined, + is_giftcard: undefined, + metadata: {}, + should_merge: true, + includes_tax: false, + }) }) }) }) diff --git a/packages/medusa/src/services/__tests__/price-list.js b/packages/medusa/src/services/__tests__/price-list.js index 588085e003..5a6a2ffb8c 100644 --- a/packages/medusa/src/services/__tests__/price-list.js +++ b/packages/medusa/src/services/__tests__/price-list.js @@ -1,8 +1,9 @@ import { MedusaError } from "medusa-core-utils" import { IdMap, MockManager, MockRepository } from "medusa-test-utils" -import PriceListService from "../price-list" import { MoneyAmountRepository } from "../../repositories/money-amount" -import { RegionServiceMock } from "../__mocks__/region"; +import { FlagRouter } from "../../utils/flag-router" +import PriceListService from "../price-list" +import { RegionServiceMock } from "../__mocks__/region" const priceListRepository = MockRepository({ findOne: (q) => { @@ -42,6 +43,7 @@ describe("PriceListService", () => { customerGroupService, priceListRepository, moneyAmountRepository, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { @@ -121,16 +123,21 @@ describe("PriceListService", () => { describe("update", () => { const updateRelatedMoneyAmountRepository = MockRepository() - updateRelatedMoneyAmountRepository.create = jest.fn().mockImplementation((rawEntity) => Promise.resolve(rawEntity)) - updateRelatedMoneyAmountRepository.save = jest.fn().mockImplementation(() => Promise.resolve()) - updateRelatedMoneyAmountRepository.updatePriceListPrices = (new MoneyAmountRepository()).updatePriceListPrices + updateRelatedMoneyAmountRepository.create = jest + .fn() + .mockImplementation((rawEntity) => Promise.resolve(rawEntity)) + updateRelatedMoneyAmountRepository.save = jest + .fn() + .mockImplementation(() => Promise.resolve()) + updateRelatedMoneyAmountRepository.updatePriceListPrices = new MoneyAmountRepository().updatePriceListPrices const updateRelatedPriceListService = new PriceListService({ manager: MockManager, customerGroupService, priceListRepository, moneyAmountRepository: updateRelatedMoneyAmountRepository, - regionService: RegionServiceMock + featureFlagRouter: new FlagRouter({}), + regionService: RegionServiceMock, }) it("update only existing price lists and related money amount", async () => { diff --git a/packages/medusa/src/services/__tests__/region.ts b/packages/medusa/src/services/__tests__/region.ts index 921fd8999b..1f6dd984a0 100644 --- a/packages/medusa/src/services/__tests__/region.ts +++ b/packages/medusa/src/services/__tests__/region.ts @@ -1,12 +1,13 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" -import RegionService from "../region" +import { CreateRegionInput } from "../../types/region" +import { FlagRouter } from "../../utils/flag-router" import { EventBusService, FulfillmentProviderService, PaymentProviderService, StoreService, } from "../index" -import { CreateRegionInput } from "../../types/region" +import RegionService from "../region" const eventBusService = { emit: jest.fn(), @@ -87,6 +88,7 @@ describe("RegionService", () => { regionRepository, countryRepository, storeService, + featureFlagRouter: new FlagRouter({}), fulfillmentProviderService, taxProviderRepository, paymentProviderService, @@ -196,6 +198,7 @@ describe("RegionService", () => { manager: MockManager, eventBusService, regionRepository, + featureFlagRouter: new FlagRouter({}), fulfillmentProviderService, taxProviderRepository, paymentProviderService, @@ -250,6 +253,7 @@ describe("RegionService", () => { currencyRepository, countryRepository, storeService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { @@ -312,6 +316,7 @@ describe("RegionService", () => { currencyRepository, countryRepository, storeService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { @@ -367,6 +372,7 @@ describe("RegionService", () => { paymentProviderRepository: ppRepository, currencyRepository, countryRepository, + featureFlagRouter: new FlagRouter({}), storeService, }) @@ -427,6 +433,7 @@ describe("RegionService", () => { paymentProviderRepository: ppRepository, currencyRepository, countryRepository, + featureFlagRouter: new FlagRouter({}), storeService, }) @@ -472,6 +479,7 @@ describe("RegionService", () => { manager: MockManager, eventBusService, regionRepository, + featureFlagRouter: new FlagRouter({}), } as any) beforeEach(async () => { @@ -507,6 +515,7 @@ describe("RegionService", () => { paymentProviderService, fulfillmentProviderRepository: fpRepository, paymentProviderRepository: ppRepository, + featureFlagRouter: new FlagRouter({}), currencyRepository, countryRepository, storeService, @@ -569,6 +578,7 @@ describe("RegionService", () => { manager: MockManager, eventBusService, regionRepository, + featureFlagRouter: new FlagRouter({}), fulfillmentProviderService, taxProviderRepository, paymentProviderService, @@ -626,6 +636,7 @@ describe("RegionService", () => { manager: MockManager, eventBusService, regionRepository, + featureFlagRouter: new FlagRouter({}), } as any) beforeEach(async () => { @@ -661,6 +672,7 @@ describe("RegionService", () => { manager: MockManager, eventBusService, regionRepository, + featureFlagRouter: new FlagRouter({}), } as any) beforeEach(async () => { diff --git a/packages/medusa/src/services/__tests__/shipping-option.js b/packages/medusa/src/services/__tests__/shipping-option.js index 2a30f9aa2b..f483e8c59b 100644 --- a/packages/medusa/src/services/__tests__/shipping-option.js +++ b/packages/medusa/src/services/__tests__/shipping-option.js @@ -1,6 +1,8 @@ import _ from "lodash" import { IdMap, MockRepository, MockManager } from "medusa-test-utils" import ShippingOptionService from "../shipping-option" +import { FlagRouter } from "../../utils/flag-router"; +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing"; describe("ShippingOptionService", () => { describe("retrieve", () => { @@ -13,6 +15,7 @@ describe("ShippingOptionService", () => { const optionService = new ShippingOptionService({ manager: MockManager, shippingOptionRepository, + featureFlagRouter: new FlagRouter({}), }) it("successfully gets shipping option", async () => { @@ -60,6 +63,7 @@ describe("ShippingOptionService", () => { shippingOptionRepository, shippingOptionRequirementRepository, fulfillmentProviderService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -214,6 +218,7 @@ describe("ShippingOptionService", () => { const optionService = new ShippingOptionService({ manager: MockManager, shippingOptionRepository, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -262,6 +267,7 @@ describe("ShippingOptionService", () => { manager: MockManager, shippingOptionRepository, shippingOptionRequirementRepository, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -311,6 +317,7 @@ describe("ShippingOptionService", () => { const optionService = new ShippingOptionService({ manager: MockManager, shippingOptionRequirementRepository, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -363,6 +370,7 @@ describe("ShippingOptionService", () => { shippingOptionRequirementRepository, fulfillmentProviderService, regionService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -526,6 +534,7 @@ describe("ShippingOptionService", () => { shippingOptionRepository, totalsService, fulfillmentProviderService: providerService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -589,4 +598,68 @@ describe("ShippingOptionService", () => { ) }) }) + + describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] createShippingMethod", () => { + const option = (id) => ({ + id, + region_id: IdMap.getId("region"), + price_type: "flat_rate", + amount: 10, + includes_tax: true, + data: { + something: "yes", + }, + requirements: [ + { + type: "min_subtotal", + amount: 100, + }, + ], + }) + const shippingOptionRepository = MockRepository({ + findOne: (q) => { + switch (q.where.id) { + default: + return Promise.resolve(option(q.where.id)) + } + }, + }) + const shippingMethodRepository = MockRepository({ create: (r) => r }) + const totalsService = { + getSubtotal: (c) => { + return c.subtotal + }, + } + + const providerService = { + validateFulfillmentData: jest + .fn() + .mockImplementation((r) => Promise.resolve(r.data)), + getPrice: (d) => d.price, + } + + const optionService = new ShippingOptionService({ + manager: MockManager, + shippingMethodRepository, + shippingOptionRepository, + totalsService, + fulfillmentProviderService: providerService, + featureFlagRouter: new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true + }), + }) + + beforeEach(() => { + jest.clearAllMocks() + }) + + it("should create a shipping method that also includes the taxes", async () => { + await optionService.createShippingMethod("random_id", {}, { price: 10 }) + expect(shippingMethodRepository.save).toHaveBeenCalledWith( + expect.objectContaining({ + includes_tax: true + }) + ) + }) + }) }) diff --git a/packages/medusa/src/services/__tests__/totals.js b/packages/medusa/src/services/__tests__/totals.js index 9f633e0be7..0b524efc54 100644 --- a/packages/medusa/src/services/__tests__/totals.js +++ b/packages/medusa/src/services/__tests__/totals.js @@ -1,5 +1,9 @@ import { IdMap } from "medusa-test-utils" import TotalsService from "../totals" +import { FlagRouter } from "../../utils/flag-router" + +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing" +import { calculatePriceTaxAmount } from "../../utils" const discounts = { total10Percent: { @@ -83,11 +87,21 @@ const applyDiscount = (cart, discount) => { const calculateAdjustment = (cart, lineItem, discount) => { let amount = discount.rule.value * lineItem.quantity - let lineItemPrice = lineItem.unit_price * lineItem.quantity + const taxAmountIncludedInPrice = !lineItem.includes_tax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate: cart.tax_rate / 100, + includesTax: lineItem.includes_tax, + }) + ) + let price = lineItem.unit_price - taxAmountIncludedInPrice + const lineItemPrice = price * lineItem.quantity if (discount.rule.type === "fixed" && discount.rule.allocation === "total") { let subtotal = cart.items.reduce( - (total, item) => total + item.unit_price * item.quantity, + (total, item) => total + price * item.quantity, 0 ) const nominator = Math.min(discount.rule.value, subtotal) @@ -99,13 +113,20 @@ const calculateAdjustment = (cart, lineItem, discount) => { } describe("TotalsService", () => { + const getTaxLinesMock = jest.fn(() => Promise.resolve([{ id: "line1" }])) + const featureFlagRouter = new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: false, + }) + const container = { taxProviderService: { - withTransaction: function () { + withTransaction: function() { return this }, + getTaxLines: getTaxLinesMock, }, taxCalculationStrategy: {}, + featureFlagRouter, } describe("getAllocationItemDiscounts", () => { @@ -278,7 +299,7 @@ describe("TotalsService", () => { it("calculate total percentage discount", async () => { discountCart.discounts.push(discounts.total10Percent) let cart = applyDiscount(discountCart, discounts.total10Percent) - res = totalsService.getDiscountTotal(cart) + res = await totalsService.getDiscountTotal(cart) expect(res).toEqual(28) }) @@ -288,7 +309,7 @@ describe("TotalsService", () => { it("calculate item fixed discount", async () => { discountCart.discounts.push(discounts.item2Fixed) let cart = applyDiscount(discountCart, discounts.item2Fixed) - res = totalsService.getDiscountTotal(cart) + res = await totalsService.getDiscountTotal(cart) expect(res).toEqual(40) }) @@ -296,7 +317,7 @@ describe("TotalsService", () => { it("calculate item percentage discount", async () => { discountCart.discounts.push(discounts.item10Percent) let cart = applyDiscount(discountCart, discounts.item10Percent) - res = totalsService.getDiscountTotal(cart) + res = await totalsService.getDiscountTotal(cart) expect(res).toEqual(28) }) @@ -304,26 +325,26 @@ describe("TotalsService", () => { it("calculate total fixed discount", async () => { discountCart.discounts.push(discounts.total10Fixed) let cart = applyDiscount(discountCart, discounts.total10Fixed) - res = totalsService.getDiscountTotal(cart) + res = await totalsService.getDiscountTotal(cart) expect(res).toEqual(10) }) it("ignores discount if expired", async () => { discountCart.discounts.push(discounts.expiredDiscount) - res = totalsService.getDiscountTotal(discountCart) + res = await totalsService.getDiscountTotal(discountCart) expect(res).toEqual(0) }) it("returns 0 if no discounts are applied", async () => { - res = totalsService.getDiscountTotal(discountCart) + res = await totalsService.getDiscountTotal(discountCart) expect(res).toEqual(0) }) it("returns 0 if no items are in cart", async () => { - res = totalsService.getDiscountTotal({ + res = await totalsService.getDiscountTotal({ items: [], discounts: [discounts.total10Fixed], }) @@ -385,7 +406,7 @@ describe("TotalsService", () => { }) it("calculates refund", async () => { - res = totalsService.getRefundTotal(orderToRefund, [ + res = await totalsService.getRefundTotal(orderToRefund, [ { id: "line2", unit_price: 100, @@ -447,7 +468,7 @@ describe("TotalsService", () => { it("calculates refund with item fixed discount", async () => { orderToRefund.discounts.push(discounts.item2Fixed) let order = applyDiscount(orderToRefund, discounts.item2Fixed) - res = totalsService.getRefundTotal(order, [ + res = await totalsService.getRefundTotal(order, [ { id: "line2", unit_price: 100, @@ -467,7 +488,7 @@ describe("TotalsService", () => { it("calculates refund with item percentage discount", async () => { orderToRefund.discounts.push(discounts.item10Percent) let order = applyDiscount(orderToRefund, discounts.item10Percent) - res = totalsService.getRefundTotal(order, [ + res = await totalsService.getRefundTotal(order, [ { id: "line2", unit_price: 100, @@ -485,8 +506,9 @@ describe("TotalsService", () => { }) it("throws if line items to return is not in order", async () => { - const work = () => - totalsService.getRefundTotal(orderToRefund, [ + let errMsg + await totalsService + .getRefundTotal(orderToRefund, [ { id: "notInOrder", unit_price: 123, @@ -498,14 +520,213 @@ describe("TotalsService", () => { quantity: 1, }, ]) + .catch((e) => (errMsg = e.message)) - expect(work).toThrow("Line item does not exist on order") + expect(errMsg).toBe("Line item does not exist on order") + }) + }) + + describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] getRefundTotal", () => { + let res + const totalsService = new TotalsService({ + ...container, + featureFlagRouter: new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }), + }) + + const orderToRefund = { + id: "refund-order", + tax_rate: 25, + items: [ + { + id: "line", + unit_price: 125, + includes_tax: true, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp1", + }, + quantity: 10, + returned_quantity: 0, + }, + { + id: "line2", + unit_price: 100, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 10, + returned_quantity: 0, + metadata: {}, + }, + { + id: "non-discount", + unit_price: 100, + allow_discounts: false, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 1, + returned_quantity: 0, + metadata: {}, + }, + ], + region_id: "fr", + discounts: [], + } + + beforeEach(() => { + jest.clearAllMocks() + orderToRefund.discounts = [] + }) + + it("calculates refund", async () => { + res = await totalsService.getRefundTotal(orderToRefund, [ + { + id: "line2", + unit_price: 100, + allow_discounts: true, + variant: { + id: "variant", + product_id: "product2", + }, + quantity: 10, + returned_quantity: 0, + metadata: {}, + }, + ]) + + expect(res).toEqual(1250) + }) + + it("calculates refund with line that includes tax", async () => { + res = await totalsService.getRefundTotal(orderToRefund, [ + { + id: "line", + unit_price: 125, + includes_tax: true, + allow_discounts: true, + variant: { + id: "variant", + product_id: "product2", + }, + quantity: 10, + returned_quantity: 0, + metadata: {}, + }, + ]) + + expect(res).toEqual(1250) + }) + + it("calculates refund with item fixed discount", async () => { + orderToRefund.discounts.push(discounts.item2Fixed) + let order = applyDiscount(orderToRefund, discounts.item2Fixed) + res = await totalsService.getRefundTotal(order, [ + { + id: "line2", + unit_price: 100, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 10, + returned_quantity: 0, + }, + ]) + + expect(res).toEqual(1225) + }) + + it("calculates refund with item fixed discount and a line that includes tax", async () => { + orderToRefund.discounts.push(discounts.item2Fixed) + let order = applyDiscount(orderToRefund, discounts.item2Fixed) + res = await totalsService.getRefundTotal(order, [ + { + id: "line", + unit_price: 125, + includes_tax: true, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 10, + returned_quantity: 0, + }, + ]) + + expect(res).toEqual(1225) + }) + + it("calculates refund with item percentage discount", async () => { + orderToRefund.discounts.push(discounts.item10Percent) + let order = applyDiscount(orderToRefund, discounts.item10Percent) + res = await totalsService.getRefundTotal(order, [ + { + id: "line2", + unit_price: 100, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 10, + returned_quantity: 0, + }, + ]) + + expect(res).toEqual(1125) + }) + + it("calculates refund with item percentage discount and a line that includes tax", async () => { + orderToRefund.discounts.push(discounts.item10Percent) + let order = applyDiscount(orderToRefund, discounts.item10Percent) + res = await totalsService.getRefundTotal(order, [ + { + id: "line", + unit_price: 125, + includes_tax: true, + allow_discounts: true, + variant: { + id: "variant", + product_id: "testp2", + }, + quantity: 10, + returned_quantity: 0, + }, + ]) + + expect(res).toEqual(1125) }) }) describe("getShippingTotal", () => { - let res - const totalsService = new TotalsService(container) + const getTaxLinesMock = jest.fn(() => + Promise.resolve([ + { shipping_method_id: IdMap.getId("expensiveShipping") }, + ]) + ) + const calculateMock = jest.fn(() => Promise.resolve(20)) + + const totalsService = new TotalsService({ + ...container, + taxProviderService: { + withTransaction: function() { + return this + }, + getTaxLines: getTaxLinesMock, + }, + taxCalculationStrategy: { + calculate: calculateMock, + }, + }) beforeEach(() => { jest.clearAllMocks() @@ -515,7 +736,7 @@ describe("TotalsService", () => { const order = { shipping_methods: [ { - _id: IdMap.getId("expensiveShipping"), + id: IdMap.getId("expensiveShipping"), name: "Expensive Shipping", price: 100, provider_id: "default_provider", @@ -526,11 +747,12 @@ describe("TotalsService", () => { }, ], } - res = totalsService.getShippingTotal(order) + const total = await totalsService.getShippingTotal(order) - expect(res).toEqual(100) + expect(total).toEqual(100) }) }) + describe("getTaxTotal", () => { let res let totalsService @@ -541,7 +763,7 @@ describe("TotalsService", () => { const cradle = { taxProviderService: { - withTransaction: function () { + withTransaction: function() { return this }, getTaxLines: getTaxLinesMock, @@ -549,6 +771,7 @@ describe("TotalsService", () => { taxCalculationStrategy: { calculate: calculateMock, }, + featureFlagRouter, } beforeEach(() => { @@ -597,13 +820,14 @@ describe("TotalsService", () => { expect(res).toEqual(20) - expect(getAllocationMapMock).toHaveBeenCalledTimes(1) - expect(getAllocationMapMock).toHaveBeenCalledWith(order, {}) + expect(getAllocationMapMock).toHaveBeenCalledTimes(2) + expect(getAllocationMapMock).toHaveBeenNthCalledWith(1, order, {}) expect(getTaxLinesMock).toHaveBeenCalledTimes(0) - expect(calculateMock).toHaveBeenCalledTimes(1) - expect(calculateMock).toHaveBeenCalledWith( + expect(calculateMock).toHaveBeenCalledTimes(3) + expect(calculateMock).toHaveBeenNthCalledWith( + 3, order.items, [{ id: "orderline1" }], { @@ -652,11 +876,15 @@ describe("TotalsService", () => { expect(res).toEqual(20) - expect(getAllocationMapMock).toHaveBeenCalledTimes(1) - expect(getAllocationMapMock).toHaveBeenCalledWith(order, {}) + expect(getAllocationMapMock).toHaveBeenCalledTimes(2) + expect(getAllocationMapMock).toHaveBeenNthCalledWith(2, order, { + exclude_discounts: undefined, + exclude_gift_cards: true, + }) - expect(getTaxLinesMock).toHaveBeenCalledTimes(1) - expect(getTaxLinesMock).toHaveBeenCalledWith( + expect(getTaxLinesMock).toHaveBeenCalledTimes(2) + expect(getTaxLinesMock).toHaveBeenNthCalledWith( + 2, [{ quantity: 2, unit_price: 20 }], { shipping_address: order.shipping_address, @@ -726,4 +954,137 @@ describe("TotalsService", () => { expect(res).toEqual(175) }) }) + + describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] getTotal", () => { + let res + const totalsService = new TotalsService({ + ...container, + featureFlagRouter: new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }), + }) + + beforeEach(() => { + jest.clearAllMocks() + }) + + it("calculates total", async () => { + const order = { + region: { + tax_rate: 25, + }, + items: [ + { + unit_price: 20, + quantity: 2, + }, + { + unit_price: 25, + quantity: 2, + includes_tax: true, + }, + ], + shipping_methods: [ + { + _id: IdMap.getId("expensiveShipping"), + name: "Expensive Shipping", + price: 100, + provider_id: "default_provider", + profile_id: IdMap.getId("default"), + data: { + extra: "hi", + }, + }, + ], + } + const getTaxTotalMock = jest.fn(() => Promise.resolve(45)) + totalsService.getTaxTotal = getTaxTotalMock + res = await totalsService.getTotal(order) + + expect(getTaxTotalMock).toHaveBeenCalledTimes(1) + expect(getTaxTotalMock).toHaveBeenCalledWith(order, undefined) + + expect(res).toEqual(185) + }) + }) + + describe("[MEDUSA_FF_TAX_INCLUSIVE_PRICING] getShippingTotal ", () => { + const shippingMethodData = { + id: IdMap.getId("expensiveShipping"), + name: "Expensive Shipping", + price: 120, + tax_lines: [{ shipping_method_id: IdMap.getId("expensiveShipping") }], + provider_id: "default_provider", + profile_id: IdMap.getId("default"), + data: { + extra: "hi", + }, + } + const calculateMock = jest.fn(() => Promise.resolve(20)) + const totalsService = new TotalsService({ + ...container, + taxCalculationStrategy: { + calculate: calculateMock, + }, + featureFlagRouter: new FlagRouter({ + [TaxInclusivePricingFeatureFlag.key]: true, + }), + }) + + beforeEach(() => { + jest.clearAllMocks() + }) + + it("calculates total with tax lines and being tax inclusive", async () => { + const order = { + object: "order", + shipping_methods: [ + { + ...shippingMethodData, + includes_tax: true, + }, + ], + } + + const total = await totalsService.getShippingTotal(order) + + expect(total).toEqual(100) + }) + + it("calculates total with tax lines and not being tax inclusive", async () => { + const order = { + object: "order", + shipping_methods: [ + { + ...shippingMethodData, + price: 100, + includes_tax: false, + }, + ], + } + + const total = await totalsService.getShippingTotal(order) + + expect(total).toEqual(100) + }) + + it("calculates total with the old system and not being tax inclusive", async () => { + const order = { + object: "order", + tax_rate: 20, + shipping_methods: [ + { + ...shippingMethodData, + price: 100, + includes_tax: false, + tax_lines: [], + }, + ], + } + + const total = await totalsService.getShippingTotal(order) + + expect(total).toEqual(100) + }) + }) }) diff --git a/packages/medusa/src/services/cart.ts b/packages/medusa/src/services/cart.ts index 03177525cc..a9d0840086 100644 --- a/packages/medusa/src/services/cart.ts +++ b/packages/medusa/src/services/cart.ts @@ -243,11 +243,15 @@ class CartService extends TransactionBaseService { break } case "shipping_total": { - totals.shipping_total = this.totalsService_.getShippingTotal(cart) + totals.shipping_total = await this.totalsService_.getShippingTotal( + cart + ) break } case "discount_total": - totals.discount_total = this.totalsService_.getDiscountTotal(cart) + totals.discount_total = await this.totalsService_.getDiscountTotal( + cart + ) break case "tax_total": totals.tax_total = await this.totalsService_.getTaxTotal( @@ -256,13 +260,15 @@ class CartService extends TransactionBaseService { ) break case "gift_card_total": { - const giftCardBreakdown = this.totalsService_.getGiftCardTotal(cart) + const giftCardBreakdown = await this.totalsService_.getGiftCardTotal( + cart + ) totals.gift_card_total = giftCardBreakdown.total totals.gift_card_tax_total = giftCardBreakdown.tax_total break } case "subtotal": - totals.subtotal = this.totalsService_.getSubtotal(cart) + totals.subtotal = await this.totalsService_.getSubtotal(cart) break default: break @@ -518,7 +524,7 @@ class CartService extends TransactionBaseService { .delete(lineItem.id) const result = await this.retrieve(cartId, { - relations: ["items", "discounts", "discounts.rule"], + relations: ["items", "discounts", "discounts.rule", "region"], }) await this.refreshAdjustments_(result) @@ -686,7 +692,7 @@ class CartService extends TransactionBaseService { ) const result = await this.retrieve(cartId, { - relations: ["items", "discounts", "discounts.rule"], + relations: ["items", "discounts", "discounts.rule", "region"], }) await this.refreshAdjustments_(result) @@ -748,7 +754,7 @@ class CartService extends TransactionBaseService { .update(lineItemId, lineItemUpdate) const updatedCart = await this.retrieve(cartId, { - relations: ["items", "discounts", "discounts.rule"], + relations: ["items", "discounts", "discounts.rule", "region"], }) await this.refreshAdjustments_(updatedCart) @@ -919,14 +925,14 @@ class CartService extends TransactionBaseService { ) const hasFreeShipping = cart.discounts.some( - ({ rule }) => rule?.type === "free_shipping" + ({ rule }) => rule?.type === DiscountRuleType.FREE_SHIPPING ) // if we previously had a free shipping discount and then removed it, // we need to update shipping methods to original price if ( previousDiscounts.some( - ({ rule }) => rule.type === "free_shipping" + ({ rule }) => rule.type === DiscountRuleType.FREE_SHIPPING ) && !hasFreeShipping ) { @@ -1229,7 +1235,7 @@ class CartService extends TransactionBaseService { default: if (!sawNotShipping) { sawNotShipping = true - if (rule?.type !== "free_shipping") { + if (rule?.type !== DiscountRuleType.FREE_SHIPPING) { return discount } return discountToParse @@ -1245,7 +1251,7 @@ class CartService extends TransactionBaseService { ) // ignore if free shipping - if (rule?.type !== "free_shipping" && cart?.items) { + if (rule?.type !== DiscountRuleType.FREE_SHIPPING && cart?.items) { await this.refreshAdjustments_(cart) } } @@ -1270,7 +1276,11 @@ class CartService extends TransactionBaseService { ], }) - if (cart.discounts.some(({ rule }) => rule.type === "free_shipping")) { + if ( + cart.discounts.some( + ({ rule }) => rule.type === DiscountRuleType.FREE_SHIPPING + ) + ) { await this.adjustFreeShipping_(cart, false) } @@ -1770,7 +1780,7 @@ class CartService extends TransactionBaseService { // if cart has freeshipping, adjust price if ( updatedCart.discounts.some( - ({ rule }) => rule.type === "free_shipping" + ({ rule }) => rule.type === DiscountRuleType.FREE_SHIPPING ) ) { await this.adjustFreeShipping_(updatedCart, true) @@ -2107,7 +2117,7 @@ class CartService extends TransactionBaseService { ], }) - const calculationContext = this.totalsService_ + const calculationContext = await this.totalsService_ .withTransaction(transactionManager) .getCalculationContext(cart) diff --git a/packages/medusa/src/services/claim.ts b/packages/medusa/src/services/claim.ts index 12089add51..9d2c24701d 100644 --- a/packages/medusa/src/services/claim.ts +++ b/packages/medusa/src/services/claim.ts @@ -381,7 +381,9 @@ export default class ClaimService extends TransactionBaseService { const result: ClaimOrder = await claimRepo.save(created) if (result.additional_items && result.additional_items.length) { - const calcContext = this.totalsService_.getCalculationContext(order) + const calcContext = await this.totalsService_.getCalculationContext( + order + ) const lineItems = await lineItemServiceTx.list({ id: result.additional_items.map((i) => i.id), }) diff --git a/packages/medusa/src/services/currency.ts b/packages/medusa/src/services/currency.ts new file mode 100644 index 0000000000..83b67fa9b0 --- /dev/null +++ b/packages/medusa/src/services/currency.ts @@ -0,0 +1,132 @@ +import { MedusaError } from "medusa-core-utils" +import { EntityManager } from "typeorm" +import { TransactionBaseService } from "../interfaces" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { Currency } from "../models" +import { CurrencyRepository } from "../repositories/currency" +import { FindConfig, Selector } from "../types/common" +import { UpdateCurrencyInput } from "../types/currency" +import { buildQuery } from "../utils" +import { FlagRouter } from "../utils/flag-router" +import EventBusService from "./event-bus" + +type InjectedDependencies = { + manager: EntityManager + currencyRepository: typeof CurrencyRepository + eventBusService: EventBusService + featureFlagRouter: FlagRouter +} + +export default class CurrencyService extends TransactionBaseService { + static readonly Events = { + UPDATED: "currency.updated", + } + + protected manager_: EntityManager + protected transactionManager_: EntityManager | undefined + + protected readonly currencyRepository_: typeof CurrencyRepository + protected readonly eventBusService_: EventBusService + protected readonly featureFlagRouter_: FlagRouter + + constructor({ + manager, + currencyRepository, + eventBusService, + featureFlagRouter, + }: InjectedDependencies) { + super({ manager }) + this.manager_ = manager + this.currencyRepository_ = currencyRepository + this.eventBusService_ = eventBusService + this.featureFlagRouter_ = featureFlagRouter + } + + /** + * Return the currency + * @param code - The code of the currency that must be retrieve + * @return The currency + */ + async retrieveByCode(code: string): Promise { + const currencyRepo = this.manager_.getCustomRepository( + this.currencyRepository_ + ) + + code = code.toLowerCase() + const currency = await currencyRepo.findOne({ + where: { code }, + }) + + if (!currency) { + throw new MedusaError( + MedusaError.Types.NOT_FOUND, + `Currency with code: ${code} was not found` + ) + } + + return currency + } + + /** + * Lists currencies based on the provided parameters and includes the count of + * currencies that match the query. + * @param selector - an object that defines rules to filter currencies + * by + * @param config - object that defines the scope for what should be + * returned + * @return an array containing the currencies as + * the first element and the total count of products that matches the query + * as the second element. + */ + async listAndCount( + selector: Selector, + config: FindConfig = { + skip: 0, + take: 20, + } + ): Promise<[Currency[], number]> { + const productRepo = this.manager_.getCustomRepository( + this.currencyRepository_ + ) + + const query = buildQuery(selector, config) + + return await productRepo.findAndCount(query) + } + + /** + * Update a currency + * @param code - The code of the currency to update + * @param data - The data that must be updated on the currency + * @return The updated currency + */ + async update( + code: string, + data: UpdateCurrencyInput + ): Promise { + return await this.atomicPhase_(async (transactionManager) => { + const currency = await this.retrieveByCode(code) + + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof data.includes_tax !== "undefined") { + currency.includes_tax = data.includes_tax + } + } + + const currencyRepo = transactionManager.getCustomRepository( + this.currencyRepository_ + ) + await currencyRepo.save(currency) + + await this.eventBusService_.emit(CurrencyService.Events.UPDATED, { + code, + }) + + return currency + }) + } +} diff --git a/packages/medusa/src/services/discount.ts b/packages/medusa/src/services/discount.ts index 7a16487c10..3253572bbc 100644 --- a/packages/medusa/src/services/discount.ts +++ b/packages/medusa/src/services/discount.ts @@ -39,6 +39,8 @@ import DiscountConditionService from "./discount-condition" import CustomerService from "./customer" import { TransactionBaseService } from "../interfaces" import { buildQuery, setMetadata } from "../utils" +import { FlagRouter } from "../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" /** * Provides layer to manipulate discounts. @@ -58,6 +60,7 @@ class DiscountService extends TransactionBaseService { protected readonly productService_: ProductService protected readonly regionService_: RegionService protected readonly eventBus_: EventBusService + protected readonly featureFlagRouter_: FlagRouter constructor({ manager, @@ -71,6 +74,7 @@ class DiscountService extends TransactionBaseService { regionService, customerService, eventBusService, + featureFlagRouter, }) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -86,6 +90,7 @@ class DiscountService extends TransactionBaseService { this.regionService_ = regionService this.customerService_ = customerService this.eventBus_ = eventBusService + this.featureFlagRouter_ = featureFlagRouter } /** @@ -579,7 +584,23 @@ class DiscountService extends TransactionBaseService { const { type, value, allocation } = discount.rule - const fullItemPrice = lineItem.unit_price * lineItem.quantity + let fullItemPrice = lineItem.unit_price * lineItem.quantity + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + lineItem.includes_tax + ) { + const lineItemTotals = await this.totalsService_.getLineItemTotals( + lineItem, + cart, + { + include_tax: true, + exclude_gift_cards: true, + } + ) + fullItemPrice = lineItemTotals.subtotal + } if (type === DiscountRuleType.PERCENTAGE) { adjustment = Math.round((fullItemPrice / 100) * value) @@ -590,12 +611,11 @@ class DiscountService extends TransactionBaseService { // when a fixed discount should be applied to the total, // we create line adjustments for each item with an amount // relative to the subtotal - const subtotal = this.totalsService_.getSubtotal(cart, { + const subtotal = await this.totalsService_.getSubtotal(cart, { excludeNonDiscounts: true, }) const nominator = Math.min(value, subtotal) - const itemRelativeToSubtotal = lineItem.unit_price / subtotal - const totalItemPercentage = itemRelativeToSubtotal * lineItem.quantity + const totalItemPercentage = fullItemPrice / subtotal adjustment = Math.round(nominator * totalItemPercentage) } else { adjustment = value * lineItem.quantity diff --git a/packages/medusa/src/services/index.ts b/packages/medusa/src/services/index.ts index c6feace3ec..466ba6ca94 100644 --- a/packages/medusa/src/services/index.ts +++ b/packages/medusa/src/services/index.ts @@ -3,6 +3,7 @@ export { default as BatchJobService } from "./batch-job" export { default as CartService } from "./cart" export { default as ClaimItemService } from "./claim-item" export { default as ClaimService } from "./claim" +export { default as CurrencyService } from "./currency" export { default as CustomShippingOptionService } from "./custom-shipping-option" export { default as CustomerGroupService } from "./customer-group" export { default as CustomerService } from "./customer" diff --git a/packages/medusa/src/services/line-item-adjustment.ts b/packages/medusa/src/services/line-item-adjustment.ts index 4331a1835b..3514ab7125 100644 --- a/packages/medusa/src/services/line-item-adjustment.ts +++ b/packages/medusa/src/services/line-item-adjustment.ts @@ -1,10 +1,13 @@ import { MedusaError } from "medusa-core-utils" import { BaseService } from "medusa-interfaces" import { EntityManager } from "typeorm" -import { Cart } from "../models/cart" -import { LineItem } from "../models/line-item" -import { LineItemAdjustment } from "../models/line-item-adjustment" -import { ProductVariant } from "../models/product-variant" +import { + Cart, + DiscountRuleType, + LineItem, + LineItemAdjustment, + ProductVariant, +} from "../models" import { LineItemAdjustmentRepository } from "../repositories/line-item-adjustment" import { FindConfig } from "../types/common" import { FilterableLineItemAdjustmentProps } from "../types/line-item-adjustment" @@ -70,8 +73,9 @@ class LineItemAdjustmentService extends BaseService { id: string, config: FindConfig = {} ): Promise { - const lineItemAdjustmentRepo: LineItemAdjustmentRepository = - this.manager_.getCustomRepository(this.lineItemAdjustmentRepo_) + const lineItemAdjustmentRepo: LineItemAdjustmentRepository = this.manager_.getCustomRepository( + this.lineItemAdjustmentRepo_ + ) const query = this.buildQuery_({ id }, config) const lineItemAdjustment = await lineItemAdjustmentRepo.findOne(query) @@ -93,8 +97,9 @@ class LineItemAdjustmentService extends BaseService { */ async create(data: Partial): Promise { return await this.atomicPhase_(async (manager: EntityManager) => { - const lineItemAdjustmentRepo: LineItemAdjustmentRepository = - manager.getCustomRepository(this.lineItemAdjustmentRepo_) + const lineItemAdjustmentRepo: LineItemAdjustmentRepository = manager.getCustomRepository( + this.lineItemAdjustmentRepo_ + ) const lineItemAdjustment = lineItemAdjustmentRepo.create(data) @@ -113,8 +118,9 @@ class LineItemAdjustmentService extends BaseService { data: Partial ): Promise { return await this.atomicPhase_(async (manager: EntityManager) => { - const lineItemAdjustmentRepo: LineItemAdjustmentRepository = - manager.getCustomRepository(this.lineItemAdjustmentRepo_) + const lineItemAdjustmentRepo: LineItemAdjustmentRepository = manager.getCustomRepository( + this.lineItemAdjustmentRepo_ + ) const lineItemAdjustment = await this.retrieve(id) @@ -163,8 +169,9 @@ class LineItemAdjustmentService extends BaseService { selectorOrId: string | FilterableLineItemAdjustmentProps ): Promise { return this.atomicPhase_(async (manager) => { - const lineItemAdjustmentRepo: LineItemAdjustmentRepository = - manager.getCustomRepository(this.lineItemAdjustmentRepo_) + const lineItemAdjustmentRepo: LineItemAdjustmentRepository = manager.getCustomRepository( + this.lineItemAdjustmentRepo_ + ) if (typeof selectorOrId === "string") { return await this.delete({ id: selectorOrId }) @@ -206,7 +213,7 @@ class LineItemAdjustmentService extends BaseService { } const [discount] = cart.discounts.filter( - (d) => d.rule.type !== "free_shipping" + (d) => d.rule.type !== DiscountRuleType.FREE_SHIPPING ) // if no discount is applied to the cart then return diff --git a/packages/medusa/src/services/line-item.ts b/packages/medusa/src/services/line-item.ts index 4a0f19ee5c..87ed1dac83 100644 --- a/packages/medusa/src/services/line-item.ts +++ b/packages/medusa/src/services/line-item.ts @@ -2,21 +2,23 @@ import { MedusaError } from "medusa-core-utils" import { BaseService } from "medusa-interfaces" import { EntityManager } from "typeorm" import { DeepPartial } from "typeorm/common/DeepPartial" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { LineItemTaxLine } from "../models" +import { Cart } from "../models/cart" +import { LineItem } from "../models/line-item" +import { LineItemAdjustment } from "../models/line-item-adjustment" +import { CartRepository } from "../repositories/cart" import { LineItemRepository } from "../repositories/line-item" import { LineItemTaxLineRepository } from "../repositories/line-item-tax-line" +import { FindConfig } from "../types/common" +import { FlagRouter } from "../utils/flag-router" import { PricingService, ProductService, - RegionService, ProductVariantService, + RegionService, } from "./index" -import { CartRepository } from "../repositories/cart" -import { LineItem } from "../models/line-item" import LineItemAdjustmentService from "./line-item-adjustment" -import { Cart } from "../models/cart" -import { LineItemAdjustment } from "../models/line-item-adjustment" -import { FindConfig } from "../types/common" -import { LineItemTaxLine } from "../models" type InjectedDependencies = { manager: EntityManager @@ -28,6 +30,7 @@ type InjectedDependencies = { pricingService: PricingService regionService: RegionService lineItemAdjustmentService: LineItemAdjustmentService + featureFlagRouter: FlagRouter } /** @@ -41,7 +44,9 @@ class LineItemService extends BaseService { protected readonly cartRepository_: typeof CartRepository protected readonly productVariantService_: ProductVariantService protected readonly productService_: ProductService + protected readonly pricingService_: PricingService protected readonly regionService_: RegionService + protected readonly featureFlagRouter_: FlagRouter protected readonly lineItemAdjustmentService_: LineItemAdjustmentService constructor({ @@ -54,6 +59,7 @@ class LineItemService extends BaseService { regionService, cartRepository, lineItemAdjustmentService, + featureFlagRouter, }: InjectedDependencies) { super() @@ -66,6 +72,7 @@ class LineItemService extends BaseService { this.regionService_ = regionService this.cartRepository_ = cartRepository this.lineItemAdjustmentService_ = lineItemAdjustmentService + this.featureFlagRouter_ = featureFlagRouter } withTransaction(transactionManager: EntityManager): LineItemService { @@ -83,6 +90,7 @@ class LineItemService extends BaseService { regionService: this.regionService_, cartRepository: this.cartRepository_, lineItemAdjustmentService: this.lineItemAdjustmentService_, + featureFlagRouter: this.featureFlagRouter_, }) cloned.transactionManager_ = transactionManager @@ -197,6 +205,7 @@ class LineItemService extends BaseService { quantity: number, context: { unit_price?: number + includes_tax?: boolean metadata?: Record customer_id?: string cart?: Cart @@ -216,6 +225,9 @@ class LineItemService extends BaseService { ]) let unit_price = Number(context.unit_price) < 0 ? 0 : context.unit_price + + let unitPriceIncludesTax = false + let shouldMerge = false if (context.unit_price === undefined || context.unit_price === null) { @@ -228,7 +240,10 @@ class LineItemService extends BaseService { customer_id: context?.customer_id, include_discount_prices: true, }) - unit_price = variantPricing.calculated_price + + unitPriceIncludesTax = !!variantPricing.calculated_price_includes_tax + + unit_price = variantPricing.calculated_price ?? undefined } const rawLineItem: Partial = { @@ -244,6 +259,14 @@ class LineItemService extends BaseService { should_merge: shouldMerge, } + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + rawLineItem.includes_tax = unitPriceIncludesTax + } + const lineItemRepo = transactionManager.getCustomRepository( this.lineItemRepository_ ) diff --git a/packages/medusa/src/services/order.ts b/packages/medusa/src/services/order.ts index 34d63507f1..ba43df641f 100644 --- a/packages/medusa/src/services/order.ts +++ b/packages/medusa/src/services/order.ts @@ -164,8 +164,9 @@ class OrderService extends TransactionBaseService { const orderRepo = this.manager_.getCustomRepository(this.orderRepository_) const query = buildQuery(selector, config) - const { select, relations, totalsToSelect } = - this.transformQueryForTotals(config) + const { select, relations, totalsToSelect } = this.transformQueryForTotals( + config + ) if (select && select.length) { query.select = select @@ -233,8 +234,9 @@ class OrderService extends TransactionBaseService { } } - const { select, relations, totalsToSelect } = - this.transformQueryForTotals(config) + const { select, relations, totalsToSelect } = this.transformQueryForTotals( + config + ) if (select && select.length) { query.select = select @@ -252,7 +254,9 @@ class OrderService extends TransactionBaseService { return [orders, count] } - protected transformQueryForTotals(config: FindConfig): { + protected transformQueryForTotals( + config: FindConfig + ): { relations: string[] | undefined select: FindConfig["select"] totalsToSelect: FindConfig["select"] @@ -333,8 +337,9 @@ class OrderService extends TransactionBaseService { ): Promise { const orderRepo = this.manager_.getCustomRepository(this.orderRepository_) - const { select, relations, totalsToSelect } = - this.transformQueryForTotals(config) + const { select, relations, totalsToSelect } = this.transformQueryForTotals( + config + ) const query = { where: { id: orderId }, @@ -373,8 +378,9 @@ class OrderService extends TransactionBaseService { ): Promise { const orderRepo = this.manager_.getCustomRepository(this.orderRepository_) - const { select, relations, totalsToSelect } = - this.transformQueryForTotals(config) + const { select, relations, totalsToSelect } = this.transformQueryForTotals( + config + ) const query = { where: { cart_id: cartId }, @@ -412,8 +418,9 @@ class OrderService extends TransactionBaseService { ): Promise { const orderRepo = this.manager_.getCustomRepository(this.orderRepository_) - const { select, relations, totalsToSelect } = - this.transformQueryForTotals(config) + const { select, relations, totalsToSelect } = this.transformQueryForTotals( + config + ) const query = { where: { external_id: externalId }, @@ -851,8 +858,9 @@ class OrderService extends TransactionBaseService { .withTransaction(manager) .createShippingMethod(optionId, data ?? {}, { order, ...config }) - const shippingOptionServiceTx = - this.shippingOptionService_.withTransaction(manager) + const shippingOptionServiceTx = this.shippingOptionService_.withTransaction( + manager + ) const methods = [newMethod] if (shipping_methods.length) { @@ -1023,8 +1031,9 @@ class OrderService extends TransactionBaseService { await inventoryServiceTx.adjustInventory(item.variant_id, item.quantity) } - const paymentProviderServiceTx = - this.paymentProviderService_.withTransaction(manager) + const paymentProviderServiceTx = this.paymentProviderService_.withTransaction( + manager + ) for (const p of order.payments) { await paymentProviderServiceTx.cancelPayment(p) } @@ -1064,8 +1073,9 @@ class OrderService extends TransactionBaseService { ) } - const paymentProviderServiceTx = - this.paymentProviderService_.withTransaction(manager) + const paymentProviderServiceTx = this.paymentProviderService_.withTransaction( + manager + ) const payments: Payment[] = [] for (const p of order.payments) { @@ -1218,7 +1228,7 @@ class OrderService extends TransactionBaseService { const fulfillments = await this.fulfillmentService_ .withTransaction(manager) .createFulfillment( - order as unknown as CreateFulfillmentOrder, + (order as unknown) as CreateFulfillmentOrder, itemsToFulfill, { metadata, @@ -1429,17 +1439,23 @@ class OrderService extends TransactionBaseService { for (const totalField of totalsFields) { switch (totalField) { case "shipping_total": { - order.shipping_total = this.totalsService_.getShippingTotal(order) + order.shipping_total = await this.totalsService_.getShippingTotal( + order + ) break } case "gift_card_total": { - const giftCardBreakdown = this.totalsService_.getGiftCardTotal(order) + const giftCardBreakdown = await this.totalsService_.getGiftCardTotal( + order + ) order.gift_card_total = giftCardBreakdown.total order.gift_card_tax_total = giftCardBreakdown.tax_total break } case "discount_total": { - order.discount_total = this.totalsService_.getDiscountTotal(order) + order.discount_total = await this.totalsService_.getDiscountTotal( + order + ) break } case "tax_total": { @@ -1447,7 +1463,7 @@ class OrderService extends TransactionBaseService { break } case "subtotal": { - order.subtotal = this.totalsService_.getSubtotal(order) + order.subtotal = await this.totalsService_.getSubtotal(order) break } case "total": { @@ -1471,36 +1487,48 @@ class OrderService extends TransactionBaseService { break } case "items.refundable": { - order.items = order.items.map((i) => ({ - ...i, - refundable: this.totalsService_.getLineItemRefund(order, { - ...i, - quantity: i.quantity - (i.returned_quantity || 0), - } as LineItem), - })) as LineItem[] + const items: LineItem[] = [] + for (const item of order.items) { + items.push({ + ...item, + refundable: await this.totalsService_.getLineItemRefund(order, { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + } as LineItem), + } as LineItem) + } + order.items = items break } case "swaps.additional_items.refundable": { for (const s of order.swaps) { - s.additional_items = s.additional_items.map((i) => ({ - ...i, - refundable: this.totalsService_.getLineItemRefund(order, { - ...i, - quantity: i.quantity - (i.returned_quantity || 0), - } as LineItem), - })) as LineItem[] + const items: LineItem[] = [] + for (const item of s.additional_items) { + items.push({ + ...item, + refundable: await this.totalsService_.getLineItemRefund(order, { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + } as LineItem), + } as LineItem) + } + s.additional_items = items } break } case "claims.additional_items.refundable": { for (const c of order.claims) { - c.additional_items = c.additional_items.map((i) => ({ - ...i, - refundable: this.totalsService_.getLineItemRefund(order, { - ...i, - quantity: i.quantity - (i.returned_quantity || 0), - } as LineItem), - })) as LineItem[] + const items: LineItem[] = [] + for (const item of c.additional_items) { + items.push({ + ...item, + refundable: await this.totalsService_.getLineItemRefund(order, { + ...item, + quantity: item.quantity - (item.returned_quantity || 0), + } as LineItem), + } as LineItem) + } + c.additional_items = items } break } diff --git a/packages/medusa/src/services/price-list.ts b/packages/medusa/src/services/price-list.ts index c15b45b30e..e7083b7791 100644 --- a/packages/medusa/src/services/price-list.ts +++ b/packages/medusa/src/services/price-list.ts @@ -1,5 +1,5 @@ import { MedusaError } from "medusa-core-utils" -import { EntityManager, FindOperator } from "typeorm" +import { DeepPartial, EntityManager, FindOperator } from "typeorm" import { CustomerGroupService } from "." import { CustomerGroup, PriceList, Product, ProductVariant } from "../models" import { MoneyAmountRepository } from "../repositories/money-amount" @@ -24,6 +24,8 @@ import { FilterableProductProps } from "../types/product" import ProductVariantService from "./product-variant" import { FilterableProductVariantProps } from "../types/product-variant" import { ProductVariantRepository } from "../repositories/product-variant" +import { FlagRouter } from "../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" type PriceListConstructorProps = { manager: EntityManager @@ -34,6 +36,7 @@ type PriceListConstructorProps = { priceListRepository: typeof PriceListRepository moneyAmountRepository: typeof MoneyAmountRepository productVariantRepository: typeof ProductVariantRepository + featureFlagRouter: FlagRouter } /** @@ -51,6 +54,7 @@ class PriceListService extends TransactionBaseService { protected readonly priceListRepo_: typeof PriceListRepository protected readonly moneyAmountRepo_: typeof MoneyAmountRepository protected readonly productVariantRepo_: typeof ProductVariantRepository + protected readonly featureFlagRouter_: FlagRouter constructor({ manager, @@ -61,6 +65,7 @@ class PriceListService extends TransactionBaseService { priceListRepository, moneyAmountRepository, productVariantRepository, + featureFlagRouter, }: PriceListConstructorProps) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -73,6 +78,7 @@ class PriceListService extends TransactionBaseService { this.priceListRepo_ = priceListRepository this.moneyAmountRepo_ = moneyAmountRepository this.productVariantRepo_ = productVariantRepository + this.featureFlagRouter_ = featureFlagRouter } /** @@ -102,18 +108,34 @@ class PriceListService extends TransactionBaseService { /** * Creates a Price List - * @param {CreatePriceListInput} priceListObject - the Price List to create - * @return {Promise} created Price List + * @param priceListObject - the Price List to create + * @return created Price List */ - async create(priceListObject: CreatePriceListInput): Promise { + async create( + priceListObject: CreatePriceListInput + ): Promise { return await this.atomicPhase_(async (manager: EntityManager) => { const priceListRepo = manager.getCustomRepository(this.priceListRepo_) const moneyAmountRepo = manager.getCustomRepository(this.moneyAmountRepo_) - const { prices, customer_groups, ...rest } = priceListObject + const { prices, customer_groups, includes_tax, ...rest } = priceListObject try { - const entity = priceListRepo.create(rest) + const rawPriceList: DeepPartial = { + ...rest, + } + + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof includes_tax !== "undefined") { + rawPriceList.includes_tax = includes_tax + } + } + + const entity = priceListRepo.create(rawPriceList) const priceList = await priceListRepo.save(entity) @@ -125,11 +147,9 @@ class PriceListService extends TransactionBaseService { await this.upsertCustomerGroups_(priceList.id, customer_groups) } - const result = await this.retrieve(priceList.id, { + return await this.retrieve(priceList.id, { relations: ["prices", "customer_groups"], }) - - return result } catch (error) { throw formatException(error) } @@ -149,7 +169,17 @@ class PriceListService extends TransactionBaseService { const priceList = await this.retrieve(id, { select: ["id"] }) - const { prices, customer_groups, ...rest } = update + const { prices, customer_groups, includes_tax, ...rest } = update + + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof includes_tax !== "undefined") { + priceList.includes_tax = includes_tax + } + } if (prices) { const prices_ = await this.addCurrencyFromRegion(prices) diff --git a/packages/medusa/src/services/pricing.ts b/packages/medusa/src/services/pricing.ts index 6af05e28ef..25d8f5b60e 100644 --- a/packages/medusa/src/services/pricing.ts +++ b/packages/medusa/src/services/pricing.ts @@ -1,21 +1,24 @@ -import { EntityManager } from "typeorm" import { MedusaError } from "medusa-core-utils" +import { EntityManager } from "typeorm" import { ProductVariantService, RegionService, TaxProviderService } from "." -import { Product, ProductVariant, ShippingOption } from "../models" -import { TaxServiceRate } from "../types/tax-service" -import { - ProductVariantPricing, - TaxedPricing, - PricingContext, - PricedProduct, - PricedShippingOption, - PricedVariant, -} from "../types/pricing" import { TransactionBaseService } from "../interfaces" import { IPriceSelectionStrategy, PriceSelectionContext, } from "../interfaces/price-selection-strategy" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { Product, ProductVariant, ShippingOption } from "../models" +import { + PricedProduct, + PricedShippingOption, + PricedVariant, + PricingContext, + ProductVariantPricing, + TaxedPricing, +} from "../types/pricing" +import { TaxServiceRate } from "../types/tax-service" +import { calculatePriceTaxAmount } from "../utils" +import { FlagRouter } from "../utils/flag-router" type InjectedDependencies = { manager: EntityManager @@ -23,6 +26,7 @@ type InjectedDependencies = { taxProviderService: TaxProviderService regionService: RegionService priceSelectionStrategy: IPriceSelectionStrategy + featureFlagRouter: FlagRouter } /** @@ -36,6 +40,7 @@ class PricingService extends TransactionBaseService { protected readonly taxProviderService: TaxProviderService protected readonly priceSelectionStrategy: IPriceSelectionStrategy protected readonly productVariantService: ProductVariantService + protected readonly featureFlagRouter: FlagRouter constructor({ manager, @@ -43,6 +48,7 @@ class PricingService extends TransactionBaseService { taxProviderService, regionService, priceSelectionStrategy, + featureFlagRouter, }: InjectedDependencies) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -52,6 +58,7 @@ class PricingService extends TransactionBaseService { this.taxProviderService = taxProviderService this.priceSelectionStrategy = priceSelectionStrategy this.productVariantService = productVariantService + this.featureFlagRouter = featureFlagRouter } /** @@ -115,17 +122,43 @@ class PricingService extends TransactionBaseService { } if (variantPricing.calculated_price !== null) { - const taxAmount = Math.round(variantPricing.calculated_price * rate) - taxedPricing.calculated_tax = taxAmount + const includesTax = !!( + this.featureFlagRouter.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && variantPricing.calculated_price_includes_tax + ) + taxedPricing.calculated_tax = Math.round( + calculatePriceTaxAmount({ + price: variantPricing.calculated_price, + taxRate: rate, + includesTax, + }) + ) + taxedPricing.calculated_price_incl_tax = - variantPricing.calculated_price + taxAmount + variantPricing.calculated_price_includes_tax + ? variantPricing.calculated_price + : variantPricing.calculated_price + taxedPricing.calculated_tax } if (variantPricing.original_price !== null) { - const taxAmount = Math.round(variantPricing.original_price * rate) - taxedPricing.original_tax = taxAmount + const includesTax = !!( + this.featureFlagRouter.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && variantPricing.original_price_includes_tax + ) + taxedPricing.original_tax = Math.round( + calculatePriceTaxAmount({ + price: variantPricing.original_price, + taxRate: rate, + includesTax, + }) + ) + taxedPricing.original_price_incl_tax = - variantPricing.original_price + taxAmount + variantPricing.original_price_includes_tax + ? variantPricing.original_price + : variantPricing.original_price + taxedPricing.original_tax } return taxedPricing @@ -137,6 +170,9 @@ class PricingService extends TransactionBaseService { context: PricingContext ): Promise { const transactionManager = this.transactionManager_ ?? this.manager_ + + context.price_selection.tax_rates = taxRates + const pricing = await this.priceSelectionStrategy .withTransaction(transactionManager) .calculateVariantPrice(variantId, context.price_selection) @@ -146,6 +182,8 @@ class PricingService extends TransactionBaseService { original_price: pricing.originalPrice, calculated_price: pricing.calculatedPrice, calculated_price_type: pricing.calculatedPriceType, + original_price_includes_tax: pricing.originalPriceIncludesTax, + calculated_price_includes_tax: pricing.calculatedPriceIncludesTax, original_price_incl_tax: null, calculated_price_incl_tax: null, original_tax: null, @@ -418,14 +456,29 @@ class PricingService extends TransactionBaseService { }, 0 ) - const tax = Math.round(price * rate) - const total = price + tax - return { + const includesTax = + this.featureFlagRouter.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && shippingOption.includes_tax + + const taxAmount = Math.round( + calculatePriceTaxAmount({ + taxRate: rate, + price, + includesTax, + }) + ) + const totalInclTax = includesTax ? price : price + taxAmount + + const result: PricedShippingOption = { ...shippingOption, - price_incl_tax: total, + price_incl_tax: totalInclTax, tax_rates: shippingOptionRates, + tax_amount: taxAmount, } + + return result } /** diff --git a/packages/medusa/src/services/region.ts b/packages/medusa/src/services/region.ts index 8524687686..18b4f932e8 100644 --- a/packages/medusa/src/services/region.ts +++ b/packages/medusa/src/services/region.ts @@ -2,22 +2,24 @@ import { DeepPartial, EntityManager } from "typeorm" import { MedusaError } from "medusa-core-utils" -import StoreService from "./store" -import EventBusService from "./event-bus" -import { countries } from "../utils/countries" import { TransactionBaseService } from "../interfaces" -import { RegionRepository } from "../repositories/region" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { Country, Currency, Region } from "../models" import { CountryRepository } from "../repositories/country" import { CurrencyRepository } from "../repositories/currency" -import { PaymentProviderRepository } from "../repositories/payment-provider" import { FulfillmentProviderRepository } from "../repositories/fulfillment-provider" +import { PaymentProviderRepository } from "../repositories/payment-provider" +import { RegionRepository } from "../repositories/region" import { TaxProviderRepository } from "../repositories/tax-provider" -import FulfillmentProviderService from "./fulfillment-provider" -import { Country, Currency, Region } from "../models" import { FindConfig, Selector } from "../types/common" import { CreateRegionInput, UpdateRegionInput } from "../types/region" import { buildQuery, setMetadata } from "../utils" +import { countries } from "../utils/countries" +import { FlagRouter } from "../utils/flag-router" +import EventBusService from "./event-bus" +import FulfillmentProviderService from "./fulfillment-provider" import { PaymentProviderService } from "./index" +import StoreService from "./store" type InjectedDependencies = { manager: EntityManager @@ -25,6 +27,7 @@ type InjectedDependencies = { eventBusService: EventBusService paymentProviderService: PaymentProviderService fulfillmentProviderService: FulfillmentProviderService + featureFlagRouter: FlagRouter regionRepository: typeof RegionRepository countryRepository: typeof CountryRepository @@ -47,6 +50,7 @@ class RegionService extends TransactionBaseService { protected manager_: EntityManager protected transactionManager_: EntityManager | undefined + protected featureFlagRouter_: FlagRouter protected readonly eventBus_: EventBusService protected readonly storeService_: StoreService @@ -71,6 +75,7 @@ class RegionService extends TransactionBaseService { taxProviderRepository, paymentProviderService, fulfillmentProviderService, + featureFlagRouter, }: InjectedDependencies) { super({ manager, @@ -84,6 +89,7 @@ class RegionService extends TransactionBaseService { taxProviderRepository, paymentProviderService, fulfillmentProviderService, + featureFlagRouter, }) this.manager_ = manager @@ -97,6 +103,8 @@ class RegionService extends TransactionBaseService { this.paymentProviderService_ = paymentProviderService this.taxProviderRepository_ = taxProviderRepository this.fulfillmentProviderService_ = fulfillmentProviderService + + this.featureFlagRouter_ = featureFlagRouter } /** @@ -115,10 +123,20 @@ class RegionService extends TransactionBaseService { ) const regionObject = { ...data } as DeepPartial - const { metadata, currency_code, ...toValidate } = data + const { metadata, currency_code, includes_tax, ...toValidate } = data const validated = await this.validateFields(toValidate) + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof includes_tax !== "undefined") { + regionObject.includes_tax = includes_tax + } + } + if (currency_code) { // will throw if currency is not added to store currencies await this.validateCurrency(currency_code) @@ -179,10 +197,20 @@ class RegionService extends TransactionBaseService { const region = await this.retrieve(regionId) - const { metadata, currency_code, ...toValidate } = update + const { metadata, currency_code, includes_tax, ...toValidate } = update const validated = await this.validateFields(toValidate, region.id) + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof includes_tax !== "undefined") { + region.includes_tax = includes_tax + } + } + if (currency_code) { // will throw if currency is not added to store currencies await this.validateCurrency(currency_code) diff --git a/packages/medusa/src/services/return.ts b/packages/medusa/src/services/return.ts index 30308c9e8a..40f79460de 100644 --- a/packages/medusa/src/services/return.ts +++ b/packages/medusa/src/services/return.ts @@ -415,7 +415,7 @@ class ReturnService extends TransactionBaseService { } } else { // Merchant hasn't specified refund amount so we calculate it - toRefund = this.totalsService_.getRefundTotal(order, returnLines) + toRefund = await this.totalsService_.getRefundTotal(order, returnLines) } const method = data.shipping_method @@ -469,7 +469,7 @@ class ReturnService extends TransactionBaseService { ) const calculationContext = - this.totalsService_.getCalculationContext(order) + await this.totalsService_.getCalculationContext(order) const taxLines = await this.taxProviderService_ .withTransaction(manager) diff --git a/packages/medusa/src/services/shipping-option.ts b/packages/medusa/src/services/shipping-option.ts index 83dee51501..0cdcbd3280 100644 --- a/packages/medusa/src/services/shipping-option.ts +++ b/packages/medusa/src/services/shipping-option.ts @@ -22,6 +22,18 @@ import { import { buildQuery, isDefined, setMetadata } from "../utils" import FulfillmentProviderService from "./fulfillment-provider" import RegionService from "./region" +import { FlagRouter } from "../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" + +type InjectedDependencies = { + manager: EntityManager + fulfillmentProviderService: FulfillmentProviderService + regionService: RegionService + shippingOptionRequirementRepository: typeof ShippingOptionRequirementRepository + shippingOptionRepository: typeof ShippingOptionRepository + shippingMethodRepository: typeof ShippingMethodRepository + featureFlagRouter: FlagRouter +} /** * Provides layer to manipulate profiles. @@ -32,6 +44,7 @@ class ShippingOptionService extends TransactionBaseService { protected readonly requirementRepository_: typeof ShippingOptionRequirementRepository protected readonly optionRepository_: typeof ShippingOptionRepository protected readonly methodRepository_: typeof ShippingMethodRepository + protected readonly featureFlagRouter_: FlagRouter protected manager_: EntityManager protected transactionManager_: EntityManager | undefined @@ -43,7 +56,8 @@ class ShippingOptionService extends TransactionBaseService { shippingMethodRepository, fulfillmentProviderService, regionService, - }) { + featureFlagRouter, + }: InjectedDependencies) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -53,6 +67,7 @@ class ShippingOptionService extends TransactionBaseService { this.requirementRepository_ = shippingOptionRequirementRepository this.providerService_ = fulfillmentProviderService this.regionService_ = regionService + this.featureFlagRouter_ = featureFlagRouter } /** @@ -285,6 +300,16 @@ class ShippingOptionService extends TransactionBaseService { price: methodPrice, } + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof option.includes_tax !== "undefined") { + toCreate.includes_tax = option.includes_tax + } + } + if (config.order) { toCreate.order_id = config.order.id } @@ -313,10 +338,10 @@ class ShippingOptionService extends TransactionBaseService { const created = await methodRepo.save(method) - return methodRepo.findOne({ + return (await methodRepo.findOne({ where: { id: created.id }, relations: ["shipping_option"], - }) as unknown as ShippingMethod + })) as ShippingMethod }) } @@ -405,6 +430,16 @@ class ShippingOptionService extends TransactionBaseService { option.amount = data.price_type === "calculated" ? null : data.amount ?? null + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof data.includes_tax !== "undefined") { + option.includes_tax = data.includes_tax + } + } + const isValid = await this.providerService_.validateOption(option) if (!isValid) { @@ -584,6 +619,16 @@ class ShippingOptionService extends TransactionBaseService { option.admin_only = update.admin_only } + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + if (typeof update.includes_tax !== "undefined") { + option.includes_tax = update.includes_tax + } + } + const optionRepo = manager.getCustomRepository(this.optionRepository_) return await optionRepo.save(option) }) diff --git a/packages/medusa/src/services/store.ts b/packages/medusa/src/services/store.ts index 348efb43ed..9a15b39dde 100644 --- a/packages/medusa/src/services/store.ts +++ b/packages/medusa/src/services/store.ts @@ -1,13 +1,13 @@ import { MedusaError } from "medusa-core-utils" import { EntityManager } from "typeorm" import { TransactionBaseService } from "../interfaces" -import { Store } from "../models" +import { Currency, Store } from "../models" import { CurrencyRepository } from "../repositories/currency" import { StoreRepository } from "../repositories/store" import { FindConfig } from "../types/common" import { UpdateStoreInput } from "../types/store" import { buildQuery, setMetadata } from "../utils" -import { currencies, Currency } from "../utils/currencies" +import { currencies } from "../utils/currencies" import EventBusService from "./event-bus" type InjectedDependencies = { diff --git a/packages/medusa/src/services/totals.ts b/packages/medusa/src/services/totals.ts index 5ff720d5e4..59dc48722d 100644 --- a/packages/medusa/src/services/totals.ts +++ b/packages/medusa/src/services/totals.ts @@ -1,18 +1,19 @@ import { MedusaError } from "medusa-core-utils" -import { BaseService } from "medusa-interfaces" import { ITaxCalculationStrategy, TaxCalculationContext, TransactionBaseService, } from "../interfaces" -import { Cart } from "../models/cart" -import { Discount } from "../models/discount" -import { DiscountRuleType } from "../models/discount-rule" -import { LineItem } from "../models/line-item" -import { LineItemTaxLine } from "../models/line-item-tax-line" -import { Order } from "../models/order" -import { ShippingMethod } from "../models/shipping-method" -import { ShippingMethodTaxLine } from "../models/shipping-method-tax-line" +import { + Cart, + Discount, + DiscountRuleType, + LineItem, + LineItemTaxLine, + Order, + ShippingMethod, + ShippingMethodTaxLine, +} from "../models" import { isCart } from "../types/cart" import { isOrder } from "../types/orders" import { @@ -23,12 +24,16 @@ import { } from "../types/totals" import TaxProviderService from "./tax-provider" import { EntityManager } from "typeorm" -import { isDefined } from "../utils" + +import { calculatePriceTaxAmount, isDefined } from "../utils" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" +import { FlagRouter } from "../utils/flag-router" type ShippingMethodTotals = { price: number tax_total: number total: number + subtotal: number original_total: number original_tax_total: number tax_lines: ShippingMethodTaxLine[] @@ -55,6 +60,7 @@ type LineItemTotals = { type LineItemTotalsOptions = { include_tax?: boolean use_tax_lines?: boolean + exclude_gift_cards?: boolean } type GetLineItemTotalOptions = { @@ -67,6 +73,7 @@ type TotalsServiceProps = { taxProviderService: TaxProviderService taxCalculationStrategy: ITaxCalculationStrategy manager: EntityManager + featureFlagRouter: FlagRouter } type GetTotalsOptions = { @@ -94,24 +101,29 @@ class TotalsService extends TransactionBaseService { protected manager_: EntityManager protected transactionManager_: EntityManager - private taxProviderService_: TaxProviderService - private taxCalculationStrategy_: ITaxCalculationStrategy + protected readonly taxProviderService_: TaxProviderService + protected readonly taxCalculationStrategy_: ITaxCalculationStrategy + protected readonly featureFlagRouter_: FlagRouter constructor({ manager, taxProviderService, taxCalculationStrategy, + featureFlagRouter, }: TotalsServiceProps) { super({ taxProviderService, taxCalculationStrategy, manager, + featureFlagRouter, }) this.manager_ = manager this.taxProviderService_ = taxProviderService this.taxCalculationStrategy_ = taxCalculationStrategy + this.manager_ = manager + this.featureFlagRouter_ = featureFlagRouter } /** @@ -124,14 +136,14 @@ class TotalsService extends TransactionBaseService { cartOrOrder: Cart | Order, options: GetTotalsOptions = {} ): Promise { - const subtotal = this.getSubtotal(cartOrOrder) + const subtotal = await this.getSubtotal(cartOrOrder) const taxTotal = (await this.getTaxTotal(cartOrOrder, options.force_taxes)) || 0 - const discountTotal = this.getDiscountTotal(cartOrOrder) + const discountTotal = await this.getDiscountTotal(cartOrOrder) const giftCardTotal = options.exclude_gift_cards ? { total: 0 } - : this.getGiftCardTotal(cartOrOrder) - const shippingTotal = this.getShippingTotal(cartOrOrder) + : await this.getGiftCardTotal(cartOrOrder) + const shippingTotal = await this.getShippingTotal(cartOrOrder) return ( subtotal + taxTotal + shippingTotal - discountTotal - giftCardTotal.total @@ -182,7 +194,7 @@ class TotalsService extends TransactionBaseService { cartOrOrder: Cart | Order, opts: GetShippingMethodTotalsOptions = {} ): Promise { - const calculationContext = this.getCalculationContext(cartOrOrder, { + const calculationContext = await this.getCalculationContext(cartOrOrder, { exclude_shipping: true, }) calculationContext.shipping_methods = [shippingMethod] @@ -191,64 +203,70 @@ class TotalsService extends TransactionBaseService { price: shippingMethod.price, original_total: shippingMethod.price, total: shippingMethod.price, + subtotal: shippingMethod.price, original_tax_total: 0, tax_total: 0, tax_lines: shippingMethod.tax_lines || [], } if (opts.include_tax) { - if (isOrder(cartOrOrder) && cartOrOrder.tax_rate !== null) { + if (isOrder(cartOrOrder) && cartOrOrder.tax_rate != null) { totals.original_tax_total = Math.round( - totals.original_tax_total * (cartOrOrder.tax_rate / 100) + totals.price * (cartOrOrder.tax_rate / 100) ) totals.tax_total = Math.round( - totals.original_tax_total * (cartOrOrder.tax_rate / 100) + totals.price * (cartOrOrder.tax_rate / 100) ) - } else { - let taxLines: ShippingMethodTaxLine[] - if (opts.use_tax_lines || isOrder(cartOrOrder)) { - if (typeof shippingMethod.tax_lines === "undefined") { - throw new MedusaError( - MedusaError.Types.UNEXPECTED_STATE, - "Tax Lines must be joined on shipping method to calculate taxes" - ) + } else if (totals.tax_lines.length === 0) { + const orderLines = await this.taxProviderService_ + .withTransaction(this.manager_) + .getTaxLines(cartOrOrder.items, calculationContext) + + totals.tax_lines = orderLines.filter((ol) => { + if ("shipping_method_id" in ol) { + return ol.shipping_method_id === shippingMethod.id } + return false + }) as ShippingMethodTaxLine[] - taxLines = shippingMethod.tax_lines - } else { - const orderLines = await this.taxProviderService_ - .withTransaction(this.manager_) - .getTaxLines(cartOrOrder.items, calculationContext) - - taxLines = orderLines.filter((ol) => { - if ("shipping_method_id" in ol) { - return ol.shipping_method_id === shippingMethod.id - } - return false - }) as ShippingMethodTaxLine[] + if (totals.tax_lines.length === 0 && isOrder(cartOrOrder)) { + throw new MedusaError( + MedusaError.Types.UNEXPECTED_STATE, + "Tax Lines must be joined on shipping method to calculate taxes" + ) } - totals.tax_lines = taxLines } if (totals.tax_lines.length > 0) { - totals.original_tax_total = - await this.taxCalculationStrategy_.calculate( - [], - totals.tax_lines, - calculationContext - ) + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && shippingMethod.includes_tax + + totals.original_tax_total = await this.taxCalculationStrategy_.calculate( + [], + totals.tax_lines, + calculationContext + ) totals.tax_total = totals.original_tax_total - totals.original_total += totals.original_tax_total - totals.total += totals.tax_total + if (includesTax) { + totals.subtotal -= totals.tax_total + } else { + totals.original_total += totals.original_tax_total + totals.total += totals.tax_total + } } } - if (cartOrOrder.discounts) { - if (cartOrOrder.discounts.some((d) => d.rule.type === "free_shipping")) { - totals.total = 0 - totals.tax_total = 0 - } + const hasFreeShipping = cartOrOrder.discounts?.some( + (d) => d.rule.type === DiscountRuleType.FREE_SHIPPING + ) + + if (hasFreeShipping) { + totals.total = 0 + totals.subtotal = 0 + totals.tax_total = 0 } return totals @@ -260,21 +278,33 @@ class TotalsService extends TransactionBaseService { * @param opts - options * @return the calculated subtotal */ - getSubtotal(cartOrOrder: Cart | Order, opts: SubtotalOptions = {}): number { + async getSubtotal( + cartOrOrder: Cart | Order, + opts: SubtotalOptions = {} + ): Promise { let subtotal = 0 if (!cartOrOrder.items) { return subtotal } - cartOrOrder.items.map((item) => { + const getLineItemSubtotal = async (item: LineItem): Promise => { + const totals = await this.getLineItemTotals(item, cartOrOrder, { + include_tax: true, + exclude_gift_cards: true, + }) + return totals.subtotal + } + + for (const item of cartOrOrder.items) { if (opts.excludeNonDiscounts) { if (item.allow_discounts) { - subtotal += item.unit_price * item.quantity + subtotal += await getLineItemSubtotal(item) } - } else { - subtotal += item.unit_price * item.quantity + continue } - }) + + subtotal += await getLineItemSubtotal(item) + } return this.rounded(subtotal) } @@ -284,11 +314,23 @@ class TotalsService extends TransactionBaseService { * @param cartOrOrder - cart or order to calculate subtotal for * @return shipping total */ - getShippingTotal(cartOrOrder: Cart | Order): number { + async getShippingTotal(cartOrOrder: Cart | Order): Promise { const { shipping_methods } = cartOrOrder - return shipping_methods.reduce((acc, next) => { - return acc + next.price - }, 0) + + let total = 0 + for (const shippingMethod of shipping_methods) { + const totals = await this.getShippingMethodTotals( + shippingMethod, + cartOrOrder, + { + include_tax: true, + } + ) + + total += totals.subtotal + } + + return total } /** @@ -311,8 +353,8 @@ class TotalsService extends TransactionBaseService { return null } - const calculationContext = this.getCalculationContext(cartOrOrder) - const giftCardTotal = this.getGiftCardTotal(cartOrOrder) + const calculationContext = await this.getCalculationContext(cartOrOrder) + const giftCardTotal = await this.getGiftCardTotal(cartOrOrder) let taxLines: (ShippingMethodTaxLine | LineItemTaxLine)[] if (isOrder(cartOrOrder)) { @@ -335,9 +377,9 @@ class TotalsService extends TransactionBaseService { taxLines = taxLines.concat(shippingTaxLines) } else { - const subtotal = this.getSubtotal(cartOrOrder) - const shippingTotal = this.getShippingTotal(cartOrOrder) - const discountTotal = this.getDiscountTotal(cartOrOrder) + const subtotal = await this.getSubtotal(cartOrOrder) + const shippingTotal = await this.getShippingTotal(cartOrOrder) + const discountTotal = await this.getDiscountTotal(cartOrOrder) return this.rounded( (subtotal - discountTotal - giftCardTotal.total + shippingTotal) * (cartOrOrder.tax_rate / 100) @@ -388,17 +430,17 @@ class TotalsService extends TransactionBaseService { * @param options - controls what should be included in allocation map * @return the allocation map for the line items in the cart or order. */ - getAllocationMap( + async getAllocationMap( orderOrCart: Cart | Order, options: AllocationMapOptions = {} - ): LineAllocationsMap { + ): Promise { const allocationMap: LineAllocationsMap = {} if (!options.exclude_discounts) { let lineDiscounts: LineDiscountAmount[] = [] - const discount = orderOrCart.discounts.find( - ({ rule }) => rule.type !== "free_shipping" + const discount = orderOrCart.discounts?.find( + ({ rule }) => rule.type !== DiscountRuleType.FREE_SHIPPING ) if (discount) { lineDiscounts = this.getLineDiscounts(orderOrCart, discount) @@ -424,8 +466,8 @@ class TotalsService extends TransactionBaseService { if (!options.exclude_gift_cards) { let lineGiftCards: LineDiscountAmount[] = [] if (orderOrCart.gift_cards && orderOrCart.gift_cards.length) { - const subtotal = this.getSubtotal(orderOrCart) - const giftCardTotal = this.getGiftCardTotal(orderOrCart) + const subtotal = await this.getSubtotal(orderOrCart) + const giftCardTotal = await this.getGiftCardTotal(orderOrCart) // If the fixed discount exceeds the subtotal we should // calculate a 100% discount @@ -480,20 +522,38 @@ class TotalsService extends TransactionBaseService { * @param lineItem - the line item to calculate the refund amount for. * @return the line item refund amount. */ - getLineItemRefund(order: Order, lineItem: LineItem): number { - const allocationMap = this.getAllocationMap(order) + async getLineItemRefund(order: Order, lineItem: LineItem): Promise { + const allocationMap = await this.getAllocationMap(order) + + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && lineItem.includes_tax const discountAmount = (allocationMap[lineItem.id]?.discount?.unit_amount || 0) * lineItem.quantity - const lineSubtotal = - lineItem.unit_price * lineItem.quantity - discountAmount + let lineSubtotal = lineItem.unit_price * lineItem.quantity - discountAmount /* * Used for backcompat with old tax system */ if (order.tax_rate !== null) { + const taxAmountIncludedInPrice = !includesTax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate: order.tax_rate / 100, + includesTax, + }) + ) + + lineSubtotal = + (lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity - + discountAmount + const taxRate = order.tax_rate / 100 return this.rounded(lineSubtotal * (1 + taxRate)) } @@ -508,6 +568,23 @@ class TotalsService extends TransactionBaseService { ) } + const taxRate = lineItem.tax_lines.reduce((acc, next) => { + return acc + next.rate / 100 + }, 0) + const taxAmountIncludedInPrice = !includesTax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate, + includesTax, + }) + ) + + lineSubtotal = + (lineItem.unit_price - taxAmountIncludedInPrice) * lineItem.quantity - + discountAmount + const taxTotal = lineItem.tax_lines.reduce((acc, next) => { const taxRate = next.rate / 100 return acc + this.rounded(lineSubtotal * taxRate) @@ -524,7 +601,7 @@ class TotalsService extends TransactionBaseService { * @param lineItems - the line items to calculate refund total for * @return the calculated subtotal */ - getRefundTotal(order: Order, lineItems: LineItem[]): number { + async getRefundTotal(order: Order, lineItems: LineItem[]): Promise { let itemIds = order.items.map((i) => i.id) // in case we swap a swap, we need to include swap items @@ -542,16 +619,18 @@ class TotalsService extends TransactionBaseService { } } - const refunds = lineItems.map((i) => { - if (!itemIds.includes(i.id)) { + const refunds: number[] = [] + for (const item of lineItems) { + if (!itemIds.includes(item.id)) { throw new MedusaError( MedusaError.Types.INVALID_DATA, "Line item does not exist on order" ) } - return this.getLineItemRefund(order, i) - }) + const refund = await this.getLineItemRefund(order, item) + refunds.push(refund) + } return this.rounded(refunds.reduce((acc, next) => acc + next, 0)) } @@ -624,7 +703,7 @@ class TotalsService extends TransactionBaseService { lineItem: LineItem, discount: Discount ): number { - const matchingDiscount = lineItem.adjustments.find( + const matchingDiscount = lineItem.adjustments?.find( (adjustment) => adjustment.discount_id === discount.id ) @@ -709,13 +788,24 @@ class TotalsService extends TransactionBaseService { cartOrOrder: Cart | Order, options: LineItemTotalsOptions = {} ): Promise { - const calculationContext = this.getCalculationContext(cartOrOrder, { + const calculationContext = await this.getCalculationContext(cartOrOrder, { exclude_shipping: true, + exclude_gift_cards: options.exclude_gift_cards, }) const lineItemAllocation = calculationContext.allocation_map[lineItem.id] || {} - const subtotal = lineItem.unit_price * lineItem.quantity + let subtotal = lineItem.unit_price * lineItem.quantity + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + lineItem.includes_tax && + options.include_tax + ) { + subtotal = 0 // in that case we need to know the tax rate to compute it later + } + const gift_card_total = lineItemAllocation.gift_card?.amount || 0 const discount_total = (lineItemAllocation.discount?.unit_amount || 0) * lineItem.quantity @@ -735,14 +825,32 @@ class TotalsService extends TransactionBaseService { // Tax Information if (options.include_tax) { - // When we have an order with a null'ed tax rate we know that it is an + // When we have an order with a nulled or undefined tax rate we know that it is an // order from the old tax system. The following is a backward compat // calculation. - if (isOrder(cartOrOrder) && cartOrOrder.tax_rate !== null) { - lineItemTotals.original_tax_total = - subtotal * (cartOrOrder.tax_rate / 100) + if (isOrder(cartOrOrder) && cartOrOrder.tax_rate != null) { + const taxRate = cartOrOrder.tax_rate / 100 + + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && lineItem.includes_tax + const taxIncludedInPrice = !lineItem.includes_tax + ? 0 + : Math.round( + calculatePriceTaxAmount({ + price: lineItem.unit_price, + taxRate: taxRate, + includesTax, + }) + ) + lineItemTotals.subtotal = + (lineItem.unit_price - taxIncludedInPrice) * lineItem.quantity + lineItemTotals.total = lineItemTotals.subtotal + + lineItemTotals.original_tax_total = lineItemTotals.subtotal * taxRate lineItemTotals.tax_total = - (subtotal - discount_total) * (cartOrOrder.tax_rate / 100) + (lineItemTotals.subtotal - discount_total) * taxRate lineItemTotals.total += lineItemTotals.tax_total lineItemTotals.original_total += lineItemTotals.original_tax_total @@ -795,15 +903,27 @@ class TotalsService extends TransactionBaseService { lineItemTotals.tax_lines, calculationContext ) - lineItemTotals.total += lineItemTotals.tax_total - calculationContext.allocation_map = {} // Don't account for discounts - lineItemTotals.original_tax_total = - await this.taxCalculationStrategy_.calculate( - [lineItem], - lineItemTotals.tax_lines, - calculationContext - ) + lineItemTotals.original_tax_total = await this.taxCalculationStrategy_.calculate( + [lineItem], + lineItemTotals.tax_lines, + calculationContext + ) + + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && + lineItem.includes_tax + ) { + lineItemTotals.subtotal += + lineItem.unit_price * lineItem.quantity - + lineItemTotals.original_tax_total + lineItemTotals.total += lineItemTotals.subtotal + lineItemTotals.original_total += lineItemTotals.subtotal + } + + lineItemTotals.total += lineItemTotals.tax_total lineItemTotals.original_total += lineItemTotals.original_tax_total } @@ -851,7 +971,9 @@ class TotalsService extends TransactionBaseService { */ async getGiftCardableAmount(cartOrOrder: Cart | Order): Promise { if (cartOrOrder.region?.gift_cards_taxable) { - return this.getSubtotal(cartOrOrder) - this.getDiscountTotal(cartOrOrder) + const subtotal = await this.getSubtotal(cartOrOrder) + const discountTotal = await this.getDiscountTotal(cartOrOrder) + return subtotal - discountTotal } return await this.getTotal(cartOrOrder, { @@ -864,12 +986,15 @@ class TotalsService extends TransactionBaseService { * @param cartOrOrder - the cart or order to get gift card amount for * @return the gift card amount applied to the cart or order */ - getGiftCardTotal(cartOrOrder: Cart | Order): { + async getGiftCardTotal( + cartOrOrder: Cart | Order + ): Promise<{ total: number tax_total: number - } { - const giftCardable = - this.getSubtotal(cartOrOrder) - this.getDiscountTotal(cartOrOrder) + }> { + const subtotal = await this.getSubtotal(cartOrOrder) + const discountTotal = await this.getDiscountTotal(cartOrOrder) + const giftCardable = subtotal - discountTotal if ("gift_card_transactions" in cartOrOrder) { // gift_card_transactions only exist on orders so we can @@ -936,19 +1061,15 @@ class TotalsService extends TransactionBaseService { * @param cartOrOrder - the cart or order to calculate discounts for * @return the total discounts amount */ - getDiscountTotal(cartOrOrder: Cart | Order): number { - const subtotal = this.getSubtotal(cartOrOrder, { + async getDiscountTotal(cartOrOrder: Cart | Order): Promise { + const subtotal = await this.getSubtotal(cartOrOrder, { excludeNonDiscounts: true, }) - if (!cartOrOrder.discounts || !cartOrOrder.discounts.length) { - return 0 - } - // we only support having free shipping and one other discount, so first // find the discount, which is not free shipping. - const discount = cartOrOrder.discounts.find( - ({ rule }) => rule.type !== "free_shipping" + const discount = cartOrOrder.discounts?.find( + ({ rule }) => rule.type !== DiscountRuleType.FREE_SHIPPING ) if (!discount) { @@ -970,11 +1091,11 @@ class TotalsService extends TransactionBaseService { * @param options - options to gather context by * @return the tax calculation context */ - getCalculationContext( + async getCalculationContext( cartOrOrder: Cart | Order, options: CalculationContextOptions = {} - ): TaxCalculationContext { - const allocationMap = this.getAllocationMap(cartOrOrder, { + ): Promise { + const allocationMap = await this.getAllocationMap(cartOrOrder, { exclude_gift_cards: options.exclude_gift_cards, exclude_discounts: options.exclude_discounts, }) diff --git a/packages/medusa/src/strategies/__tests__/price-selection.js b/packages/medusa/src/strategies/__tests__/price-selection.js index 56a834385f..be02940965 100644 --- a/packages/medusa/src/strategies/__tests__/price-selection.js +++ b/packages/medusa/src/strategies/__tests__/price-selection.js @@ -1,5 +1,249 @@ +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing" +import { FlagRouter } from "../../utils/flag-router" import PriceSelectionStrategy from "../price-selection" +const executeTest = + (flagValue) => + async (title, { variant_id, context, validate, validateException }) => { + const mockMoneyAmountRepository = { + findManyForVariantInRegion: jest + .fn() + .mockImplementation( + async ( + variant_id, + region_id, + currency_code, + customer_id, + useDiscountPrices + ) => { + if (variant_id === "test-basic-variant") { + return [ + [ + { + amount: 100, + region_id, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + }, + ], + 1, + ] + } + if (variant_id === "test-basic-variant-tax-inclusive") { + return [ + [ + { + amount: 100, + region_id, + price_list_id: null, + max_quantity: null, + min_quantity: null, + region: { + includes_tax: true, + }, + }, + { + amount: 120, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + currency: { + includes_tax: true, + }, + }, + ], + 1, + ] + } + if (variant_id === "test-basic-variant-tax-inclusive-currency") { + return [ + [ + { + amount: 100, + region_id, + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + { + amount: 100, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + currency: { + includes_tax: true, + }, + }, + ], + 1, + ] + } + if (variant_id === "test-basic-variant-tax-inclusive-region") { + return [ + [ + { + amount: 100, + region_id, + max_quantity: null, + min_quantity: null, + price_list_id: null, + region: { + includes_tax: true, + }, + }, + { + amount: 100, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + }, + ], + 1, + ] + } + if (variant_id === "test-basic-variant-mixed") { + return [ + [ + { + amount: 100, + region_id, + max_quantity: null, + min_quantity: null, + price_list_id: null, + region: { + includes_tax: false, + }, + }, + { + amount: 95, + currency_code, + price_list_id: "pl_1", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + { + amount: 110, + currency_code, + price_list_id: "pl_2", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale", includes_tax: true }, + }, + { + amount: 150, + currency_code, + price_list_id: "pl_3", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + ], + 1, + ] + } + if (customer_id === "test-customer-1") { + return [ + [ + { + amount: 100, + region_id, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + }, + { + amount: 50, + region_id: region_id, + currency_code: currency_code, + price_list: { type: "sale" }, + max_quantity: null, + min_quantity: null, + }, + ], + 2, + ] + } + if (customer_id === "test-customer-2") { + return [ + [ + { + amount: 100, + region_id, + currency_code, + price_list_id: null, + max_quantity: null, + min_quantity: null, + }, + { + amount: 30, + min_quantity: 10, + max_quantity: 12, + price_list: { type: "sale" }, + region_id: region_id, + currency_code: currency_code, + }, + { + amount: 20, + min_quantity: 3, + max_quantity: 5, + price_list: { type: "sale" }, + region_id: region_id, + currency_code: currency_code, + }, + { + amount: 50, + min_quantity: 5, + max_quantity: 10, + price_list: { type: "sale" }, + region_id: region_id, + currency_code: currency_code, + }, + ], + 4, + ] + } + return [] + } + ), + } + + const mockEntityManager = { + getCustomRepository: (repotype) => mockMoneyAmountRepository, + } + + const featureFlagRouter = new FlagRouter({ + tax_inclusive_pricing: flagValue, + }) + + const selectionStrategy = new PriceSelectionStrategy({ + manager: mockEntityManager, + moneyAmountRepository: mockMoneyAmountRepository, + featureFlagRouter, + }) + + try { + const val = await selectionStrategy.calculateVariantPrice( + variant_id, + context + ) + + validate(val, { mockMoneyAmountRepository, featureFlagRouter }) + } catch (error) { + if (typeof validateException === "function") { + validateException(error, { mockMoneyAmountRepository }) + } else { + throw error + } + } + } + const toTest = [ [ "Variant with only default price", @@ -9,17 +253,41 @@ const toTest = [ region_id: "test-region", currency_code: "dkk", }, - validate: (value, { mockMoneyAmountRepository }) => { - expect( - mockMoneyAmountRepository.findManyForVariantInRegion - ).toHaveBeenCalledWith( - "test-basic-variant", - "test-region", - "dkk", - undefined, - undefined - ) + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + let ffFields = {} + if (featureFlagRouter.isFeatureEnabled("tax_inclusive_pricing")) { + ffFields = { + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + } + } + + if ( + featureFlagRouter.isFeatureEnabled(TaxInclusivePricingFeatureFlag.key) + ) { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant", + "test-region", + "dkk", + undefined, + undefined, + true + ) + } else { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant", + "test-region", + "dkk", + undefined, + undefined + ) + } expect(value).toEqual({ + ...ffFields, originalPrice: 100, calculatedPrice: 100, calculatedPriceType: "default", @@ -63,16 +331,31 @@ const toTest = [ currency_code: "dkk", customer_id: "test-customer-1", }, - validate: (value, { mockMoneyAmountRepository }) => { - expect( - mockMoneyAmountRepository.findManyForVariantInRegion - ).toHaveBeenCalledWith( - "test-variant", - "test-region", - "dkk", - "test-customer-1", - undefined - ) + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + if ( + featureFlagRouter.isFeatureEnabled(TaxInclusivePricingFeatureFlag.key) + ) { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-variant", + "test-region", + "dkk", + "test-customer-1", + undefined, + true + ) + } else { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-variant", + "test-region", + "dkk", + "test-customer-1", + undefined + ) + } }, }, ], @@ -85,8 +368,16 @@ const toTest = [ currency_code: "dkk", customer_id: "test-customer-1", }, - validate: (value, { mockMoneyAmountRepository }) => { + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + let ffFields = {} + if (featureFlagRouter.isFeatureEnabled("tax_inclusive_pricing")) { + ffFields = { + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + } + } expect(value).toEqual({ + ...ffFields, originalPrice: 100, calculatedPrice: 50, calculatedPriceType: "sale", @@ -121,8 +412,16 @@ const toTest = [ currency_code: "dkk", customer_id: "test-customer-2", }, - validate: (value, { mockMoneyAmountRepository }) => { + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + let ffFields = {} + if (featureFlagRouter.isFeatureEnabled("tax_inclusive_pricing")) { + ffFields = { + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + } + } expect(value).toEqual({ + ...ffFields, originalPrice: 100, calculatedPrice: 100, calculatedPriceType: "default", @@ -174,8 +473,16 @@ const toTest = [ customer_id: "test-customer-2", quantity: 7, }, - validate: (value, { mockMoneyAmountRepository }) => { + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + let ffFields = {} + if (featureFlagRouter.isFeatureEnabled("tax_inclusive_pricing")) { + ffFields = { + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + } + } expect(value).toEqual({ + ...ffFields, originalPrice: 100, calculatedPrice: 50, calculatedPriceType: "sale", @@ -226,8 +533,16 @@ const toTest = [ currency_code: "dkk", customer_id: "test-customer-2", }, - validate: (value, { mockMoneyAmountRepository }) => { + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + let ffFields = {} + if (featureFlagRouter.isFeatureEnabled("tax_inclusive_pricing")) { + ffFields = { + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + } + } expect(value).toEqual({ + ...ffFields, originalPrice: 100, calculatedPrice: 100, calculatedPriceType: "default", @@ -271,128 +586,281 @@ const toTest = [ ], ] +const taxInclusiveTesting = [ + [ + "Variant with tax inclusive prices", + { + variant_id: "test-basic-variant-tax-inclusive", + context: { + region_id: "test-region", + currency_code: "dkk", + }, + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant-tax-inclusive", + "test-region", + "dkk", + undefined, + undefined, + true + ) + expect(value).toEqual({ + originalPrice: 100, + calculatedPrice: 100, + originalPriceIncludesTax: true, + calculatedPriceIncludesTax: true, + calculatedPriceType: "default", + prices: [ + { + amount: 100, + max_quantity: null, + min_quantity: null, + price_list_id: null, + region_id: "test-region", + }, + { + amount: 120, + currency_code: "dkk", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + ], + }) + }, + }, + ], + [ + "Variant with mixed pricing tax inclusive prices currency", + { + variant_id: "test-basic-variant-tax-inclusive-currency", + context: { + region_id: "test-region", + currency_code: "dkk", + tax_rates: [{ rate: 25 }], + }, + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant-tax-inclusive-currency", + "test-region", + "dkk", + undefined, + undefined, + true + ) + expect(value).toEqual({ + originalPrice: 100, + calculatedPrice: 100, + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: true, + calculatedPriceType: "default", + prices: [ + { + amount: 100, + region_id: "test-region", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + { + amount: 100, + currency_code: "dkk", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + ], + }) + }, + }, + ], + [ + "Variant with mixed pricing tax inclusive prices region", + { + variant_id: "test-basic-variant-tax-inclusive-region", + context: { + region_id: "test-region", + currency_code: "dkk", + tax_rates: [{ rate: 25 }], + }, + validate: (value, { mockMoneyAmountRepository, featureFlagRouter }) => { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant-tax-inclusive-region", + "test-region", + "dkk", + undefined, + undefined, + true + ) + expect(value).toEqual({ + originalPrice: 100, + calculatedPrice: 100, + originalPriceIncludesTax: true, + calculatedPriceIncludesTax: true, + calculatedPriceType: "default", + prices: [ + { + amount: 100, + region_id: "test-region", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + { + amount: 100, + currency_code: "dkk", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + ], + }) + }, + }, + ], + [ + "Variant with mixed tax prices (favoring tax inclusive)", + { + variant_id: "test-basic-variant-mixed", + context: { + region_id: "test-region", + currency_code: "dkk", + tax_rates: [{ rate: 25 }], + }, + validate: (value, { mockMoneyAmountRepository }) => { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant-mixed", + "test-region", + "dkk", + undefined, + undefined, + true + ) + expect(value).toEqual({ + originalPrice: 100, + calculatedPrice: 110, + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: true, + calculatedPriceType: "sale", + prices: [ + { + amount: 100, + region_id: "test-region", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + { + amount: 95, + currency_code: "dkk", + price_list_id: "pl_1", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + { + amount: 110, + currency_code: "dkk", + price_list_id: "pl_2", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale", includes_tax: true }, + }, + { + amount: 150, + currency_code: "dkk", + price_list_id: "pl_3", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + ], + }) + }, + }, + ], + [ + "Variant with mixed tax price (favoring tax exclusive)", + { + variant_id: "test-basic-variant-mixed", + context: { + region_id: "test-region", + currency_code: "dkk", + tax_rate: 0.05, + }, + validate: (value, { mockMoneyAmountRepository }) => { + expect( + mockMoneyAmountRepository.findManyForVariantInRegion + ).toHaveBeenCalledWith( + "test-basic-variant-mixed", + "test-region", + "dkk", + undefined, + undefined, + true + ) + expect(value).toEqual({ + originalPrice: 100, + calculatedPrice: 95, + originalPriceIncludesTax: false, + calculatedPriceIncludesTax: false, + calculatedPriceType: "sale", + prices: [ + { + amount: 100, + region_id: "test-region", + max_quantity: null, + min_quantity: null, + price_list_id: null, + }, + { + amount: 95, + currency_code: "dkk", + price_list_id: "pl_1", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + { + amount: 110, + currency_code: "dkk", + price_list_id: "pl_2", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale", includes_tax: true }, + }, + { + amount: 150, + currency_code: "dkk", + price_list_id: "pl_3", + max_quantity: null, + min_quantity: null, + price_list: { type: "sale" }, + }, + ], + }) + }, + }, + ], +] + describe("PriceSelectionStrategy", () => { describe("calculateVariantPrice", () => { - test.each(toTest)( - "%s", - async (title, { variant_id, context, validate, validateException }) => { - const mockMoneyAmountRepository = { - findManyForVariantInRegion: jest - .fn() - .mockImplementation( - async ( - variant_id, - region_id, - currency_code, - customer_id, - useDiscountPrices - ) => { - if (variant_id === "test-basic-variant") { - return [ - [ - { - amount: 100, - region_id, - currency_code, - price_list_id: null, - max_quantity: null, - min_quantity: null, - }, - ], - 1, - ] - } - if (customer_id === "test-customer-1") { - return [ - [ - { - amount: 100, - region_id, - currency_code, - price_list_id: null, - max_quantity: null, - min_quantity: null, - }, - { - amount: 50, - region_id: region_id, - currency_code: currency_code, - price_list: { type: "sale" }, - max_quantity: null, - min_quantity: null, - }, - ], - 2, - ] - } - if (customer_id === "test-customer-2") { - return [ - [ - { - amount: 100, - region_id, - currency_code, - price_list_id: null, - max_quantity: null, - min_quantity: null, - }, - { - amount: 30, - min_quantity: 10, - max_quantity: 12, - price_list: { type: "sale" }, - region_id: region_id, - currency_code: currency_code, - }, - { - amount: 20, - min_quantity: 3, - max_quantity: 5, - price_list: { type: "sale" }, - region_id: region_id, - currency_code: currency_code, - }, - { - amount: 50, - min_quantity: 5, - max_quantity: 10, - price_list: { type: "sale" }, - region_id: region_id, - currency_code: currency_code, - }, - ], - 4, - ] - } - return [] - } - ), - } - - const mockEntityManager = { - getCustomRepository: (repotype) => mockMoneyAmountRepository, - } - - const selectionStrategy = new PriceSelectionStrategy({ - manager: mockEntityManager, - moneyAmountRepository: mockMoneyAmountRepository, - }) - - try { - const val = await selectionStrategy.calculateVariantPrice( - variant_id, - context - ) - - validate(val, { mockMoneyAmountRepository }) - } catch (error) { - if (typeof validateException === "function") { - validateException(error, { mockMoneyAmountRepository }) - } else { - throw error - } - } - } - ) + ;[true, false].forEach((flagValue) => { + describe(`with tax inclusive pricing ${flagValue}`, () => { + test.each(toTest)(`%s`, executeTest(flagValue)) + }) + }) + describe("tax inclusive testing", () => { + test.each(taxInclusiveTesting)(`%s`, executeTest(true)) + }) }) }) diff --git a/packages/medusa/src/strategies/__tests__/tax-calculation.js b/packages/medusa/src/strategies/__tests__/tax-calculation.js index 1590a745b3..5ec84540f7 100644 --- a/packages/medusa/src/strategies/__tests__/tax-calculation.js +++ b/packages/medusa/src/strategies/__tests__/tax-calculation.js @@ -1,121 +1,282 @@ import TaxCalculationStrategy from "../tax-calculation" +import TaxInclusivePricingFeatureFlag from "../../loaders/feature-flags/tax-inclusive-pricing" +import { FlagRouter } from "../../utils/flag-router" const toTest = [ - { - title: "calculates correctly without gift card", - - /* - * Subtotal = 2 * 100 = 200 - * Taxable amount = 200 - 10 = 190 - * Taxline 1 = 190 * 0.0825 = 15.675 = 16 - * Taxline 2 = 190 * 0.125 = 13.75 = 14 - * Total tax = 40 - */ - expected: 40, - items: [ - { - id: "item_1", - unit_price: 100, - quantity: 2, - }, - ], - taxLines: [ - { - item_id: "item_1", - name: "Name 1", - rate: 8.25, - }, - { - item_id: "item_1", - name: "Name 2", - rate: 12.5, - }, - ], - context: { - shipping_address: null, - customer: { - email: "test@testson.com", - }, - shipping_methods: [], - region: { - gift_cards_taxable: false, - }, - allocation_map: { - item_1: { - discount: { - amount: 10, - unit_amount: 5, - }, - gift_card: { - amount: 10, - unit_amount: 5, + [ + "calculates correctly without gift card", + { + /* + * Subtotal = 2 * 100 = 200 + * Taxable amount = 200 - 10 = 190 + * Taxline 1 = 190 * 0.0825 = 15.675 = 16 + * Taxline 2 = 190 * 0.125 = 13.75 = 14 + * Total tax = 40 + */ + expected: 40, + items: [ + { + id: "item_1", + unit_price: 100, + quantity: 2, + }, + ], + taxLines: [ + { + item_id: "item_1", + name: "Name 1", + rate: 8.25, + }, + { + item_id: "item_1", + name: "Name 2", + rate: 12.5, + }, + ], + context: { + shipping_address: null, + customer: { + email: "test@testson.com", + }, + shipping_methods: [], + region: { + gift_cards_taxable: false, + }, + allocation_map: { + item_1: { + discount: { + amount: 10, + unit_amount: 5, + }, + gift_card: { + amount: 10, + unit_amount: 5, + }, }, }, }, }, - }, - { - title: "calculates correctly with gift card", - - /* - * Subtotal = 2 * 100 = 200 - * Taxable amount = 200 - 10 = 180 - * Taxline 1 = 180 * 0.0825 = 15 - * Taxline 2 = 180 * 0.125 = 23 - * Total tax = 38 - */ - expected: 40, - items: [ - { - id: "item_1", - unit_price: 100, - quantity: 2, - }, - ], - taxLines: [ - { - item_id: "item_1", - name: "Name 1", - rate: 8.25, - }, - { - item_id: "item_1", - name: "Name 2", - rate: 12.5, - }, - ], - context: { - shipping_address: null, - customer: { - email: "test@testson.com", - }, - region: { - gift_cards_taxable: true, - }, - shipping_methods: [], - allocation_map: { - item_1: { - discount: { - amount: 10, - unit_amount: 5, - }, - gift_card: { - amount: 10, - unit_amount: 5, + ], + [ + "calculates correctly with gift card", + { + /* + * Subtotal = 2 * 100 = 200 + * Taxable amount = 200 - 10 = 180 + * Taxline 1 = 180 * 0.0825 = 15 + * Taxline 2 = 180 * 0.125 = 23 + * Total tax = 38 + */ + expected: 40, + items: [ + { + id: "item_1", + unit_price: 100, + quantity: 2, + }, + ], + taxLines: [ + { + item_id: "item_1", + name: "Name 1", + rate: 8.25, + }, + { + item_id: "item_1", + name: "Name 2", + rate: 12.5, + }, + ], + context: { + shipping_address: null, + customer: { + email: "test@testson.com", + }, + region: { + gift_cards_taxable: true, + }, + shipping_methods: [], + allocation_map: { + item_1: { + discount: { + amount: 10, + unit_amount: 5, + }, + gift_card: { + amount: 10, + unit_amount: 5, + }, }, }, }, }, - }, + ], + [ + "calculates correctly with tax inclusive pricing", + { + /* + * Subtotal = 3 * 100 = 100 + * Taxable amount = 300 + * Taxline 1 = 100 * 0.2 * 2 = 40 + * Taxline 2 = 100 * 0.2 * 1 = 20 + * Total tax = 60 + */ + expected: 60, + flags: { [TaxInclusivePricingFeatureFlag.key]: true }, + items: [ + { + id: "item_1", + unit_price: 120, + quantity: 2, + includes_tax: true, + }, + { + id: "item_2", + unit_price: 100, + quantity: 1, + includes_tax: false, + }, + ], + taxLines: [ + { + item_id: "item_1", + name: "Name 1", + rate: 20, + }, + { + item_id: "item_2", + name: "Name 2", + rate: 20, + }, + ], + context: { + shipping_address: null, + customer: { + email: "test@testson.com", + }, + region: { + gift_cards_taxable: true, + }, + shipping_methods: [], + allocation_map: {}, + }, + }, + ], + [ + "calculates correctly with tax inclusive shipping", + { + expected: 40, + flags: { [TaxInclusivePricingFeatureFlag.key]: true }, + items: [ + { + id: "item_1", + unit_price: 120, + quantity: 1, + includes_tax: true, + }, + ], + taxLines: [ + { + shipping_method_id: "shipping_method_1", + name: "Name 1", + rate: 15, + }, + { + shipping_method_id: "shipping_method_2", + name: "Name 2", + rate: 5, + }, + { + item_id: "item_1", + name: "Name 1", + rate: 20, + }, + ], + context: { + shipping_address: null, + customer: { + email: "test@testson.com", + }, + region: { + gift_cards_taxable: true, + }, + shipping_methods: [ + { id: "shipping_method_1", price: 115, includes_tax: true }, + { id: "shipping_method_2", price: 105, includes_tax: true }, + ], + allocation_map: {}, + }, + }, + ], + [ + "calculates correctly with tax inclusive pricing and shipping", + { + expected: 85, + flags: { [TaxInclusivePricingFeatureFlag.key]: true }, + items: [ + { + id: "item_1", + unit_price: 120, + quantity: 2, + includes_tax: true, + }, + { + id: "item_2", + unit_price: 100, + quantity: 1, + includes_tax: false, + }, + ], + taxLines: [ + { + shipping_method_id: "shipping_method_1", + name: "Name 1", + rate: 15, + }, + { + shipping_method_id: "shipping_method_2", + name: "Name 2", + rate: 10, + }, + { + item_id: "item_1", + name: "Name 1", + rate: 20, + }, + { + item_id: "item_2", + name: "Name 2", + rate: 20, + }, + ], + context: { + shipping_address: null, + customer: { + email: "test@testson.com", + }, + region: { + gift_cards_taxable: true, + }, + shipping_methods: [ + { id: "shipping_method_1", price: 115, includes_tax: true }, + { id: "shipping_method_2", price: 100, includes_tax: false }, + ], + allocation_map: {}, + }, + }, + ], ] describe("TaxCalculationStrategy", () => { describe("calculate", () => { - const calcStrat = new TaxCalculationStrategy() - test.each(toTest)( - "$title", - async ({ items, taxLines, context, expected }) => { + "%s", + async (title, { items, taxLines, context, expected, flags }) => { + const featureFlagRouter = new FlagRouter(flags ?? {}) + const calcStrat = new TaxCalculationStrategy({ + featureFlagRouter, + }) + const val = await calcStrat.calculate(items, taxLines, context) expect(val).toEqual(expected) } diff --git a/packages/medusa/src/strategies/price-selection.ts b/packages/medusa/src/strategies/price-selection.ts index 68e635928b..65b4ea54cb 100644 --- a/packages/medusa/src/strategies/price-selection.ts +++ b/packages/medusa/src/strategies/price-selection.ts @@ -1,3 +1,4 @@ +import { EntityManager } from "typeorm" import { AbstractPriceSelectionStrategy, IPriceSelectionStrategy, @@ -5,18 +6,22 @@ import { PriceSelectionResult, PriceType, } from "../interfaces/price-selection-strategy" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" import { MoneyAmountRepository } from "../repositories/money-amount" -import { EntityManager } from "typeorm" -import { isDefined } from "../utils" +import { TaxServiceRate } from "../types/tax-service" +import { FlagRouter } from "../utils/flag-router" +import { isDefined } from "../utils/is-defined" class PriceSelectionStrategy extends AbstractPriceSelectionStrategy { private moneyAmountRepository_: typeof MoneyAmountRepository + private featureFlagRouter_: FlagRouter private manager_: EntityManager - constructor({ manager, moneyAmountRepository }) { + constructor({ manager, featureFlagRouter, moneyAmountRepository }) { super() this.manager_ = manager this.moneyAmountRepository_ = moneyAmountRepository + this.featureFlagRouter_ = featureFlagRouter } withTransaction(manager: EntityManager): IPriceSelectionStrategy { @@ -27,12 +32,116 @@ class PriceSelectionStrategy extends AbstractPriceSelectionStrategy { return new PriceSelectionStrategy({ manager: manager, moneyAmountRepository: this.moneyAmountRepository_, + featureFlagRouter: this.featureFlagRouter_, }) } async calculateVariantPrice( variant_id: string, context: PriceSelectionContext + ): Promise { + if ( + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + ) { + return this.calculateVariantPrice_new(variant_id, context) + } + return this.calculateVariantPrice_old(variant_id, context) + } + + private async calculateVariantPrice_new( + variant_id: string, + context: PriceSelectionContext + ): Promise { + const moneyRepo = this.manager_.getCustomRepository( + this.moneyAmountRepository_ + ) + + const [prices, count] = await moneyRepo.findManyForVariantInRegion( + variant_id, + context.region_id, + context.currency_code, + context.customer_id, + context.include_discount_prices, + true + ) + + const result: PriceSelectionResult = { + originalPrice: null, + calculatedPrice: null, + prices, + originalPriceIncludesTax: null, + calculatedPriceIncludesTax: null, + } + + if (!count || !context) { + return result + } + + const taxRate = context.tax_rates?.reduce( + (accRate: number, nextTaxRate: TaxServiceRate) => { + return accRate + (nextTaxRate.rate || 0) / 100 + }, + 0 + ) + + for (const ma of prices) { + let isTaxInclusive = ma.currency?.includes_tax || false + + if (ma.price_list?.includes_tax) { + // PriceList specific price so use the PriceList tax setting + isTaxInclusive = ma.price_list.includes_tax + } else if (ma.region?.includes_tax) { + // Region specific price so use the Region tax setting + isTaxInclusive = ma.region.includes_tax + } + + delete ma.currency + delete ma.region + + if ( + context.region_id && + ma.region_id === context.region_id && + ma.price_list_id === null && + ma.min_quantity === null && + ma.max_quantity === null + ) { + result.originalPriceIncludesTax = isTaxInclusive + result.originalPrice = ma.amount + } + + if ( + context.currency_code && + ma.currency_code === context.currency_code && + ma.price_list_id === null && + ma.min_quantity === null && + ma.max_quantity === null && + result.originalPrice === null // region prices take precedence + ) { + result.originalPriceIncludesTax = isTaxInclusive + result.originalPrice = ma.amount + } + + if ( + isValidQuantity(ma, context.quantity) && + isValidAmount(ma.amount, result, isTaxInclusive, taxRate) && + ((context.currency_code && + ma.currency_code === context.currency_code) || + (context.region_id && ma.region_id === context.region_id)) + ) { + result.calculatedPrice = ma.amount + result.calculatedPriceType = ma.price_list?.type || PriceType.DEFAULT + result.calculatedPriceIncludesTax = isTaxInclusive + } + } + + return result + } + + private async calculateVariantPrice_old( + variant_id: string, + context: PriceSelectionContext ): Promise { const moneyRepo = this.manager_.getCustomRepository( this.moneyAmountRepository_ @@ -65,6 +174,9 @@ class PriceSelectionStrategy extends AbstractPriceSelectionStrategy { } for (const ma of prices) { + delete ma.currency + delete ma.region + if ( context.region_id && ma.region_id === context.region_id && @@ -103,6 +215,31 @@ class PriceSelectionStrategy extends AbstractPriceSelectionStrategy { } } +const isValidAmount = ( + amount: number, + result: PriceSelectionResult, + isTaxInclusive: boolean, + taxRate?: number +): boolean => { + if (result.calculatedPrice === null) { + return true + } + + if (isTaxInclusive === result.calculatedPriceIncludesTax) { + // if both or neither are tax inclusive compare equally + return amount < result.calculatedPrice + } + + if (typeof taxRate !== "undefined") { + return isTaxInclusive + ? amount < (1 + taxRate) * result.calculatedPrice + : (1 + taxRate) * amount < result.calculatedPrice + } + + // if we dont have a taxrate we can't compare mixed prices + return false +} + const isValidQuantity = (price, quantity): boolean => (isDefined(quantity) && isValidPriceWithQuantity(price, quantity)) || (typeof quantity === "undefined" && isValidPriceWithoutQuantity(price)) diff --git a/packages/medusa/src/strategies/tax-calculation.ts b/packages/medusa/src/strategies/tax-calculation.ts index 4e1cec1c9b..5780a43620 100644 --- a/packages/medusa/src/strategies/tax-calculation.ts +++ b/packages/medusa/src/strategies/tax-calculation.ts @@ -1,11 +1,21 @@ -import { LineItem } from "../models/line-item" -import { ShippingMethod } from "../models/shipping-method" -import { LineItemTaxLine } from "../models/line-item-tax-line" -import { ShippingMethodTaxLine } from "../models/shipping-method-tax-line" -import { TaxCalculationContext } from "../interfaces/tax-service" -import { ITaxCalculationStrategy } from "../interfaces/tax-calculation-strategy" +import { + LineItem, + LineItemTaxLine, + ShippingMethod, + ShippingMethodTaxLine, +} from "../models" +import { ITaxCalculationStrategy, TaxCalculationContext } from "../interfaces" +import { calculatePriceTaxAmount } from "../utils" +import { FlagRouter } from "../utils/flag-router" +import TaxInclusivePricingFeatureFlag from "../loaders/feature-flags/tax-inclusive-pricing" class TaxCalculationStrategy implements ITaxCalculationStrategy { + protected readonly featureFlagRouter_: FlagRouter + + constructor({ featureFlagRouter }) { + this.featureFlagRouter_ = featureFlagRouter + } + async calculate( items: LineItem[], taxLines: (ShippingMethodTaxLine | LineItemTaxLine)[], @@ -33,18 +43,47 @@ class TaxCalculationStrategy implements ITaxCalculationStrategy { context: TaxCalculationContext ): number { let taxTotal = 0 - for (const i of items) { - const allocations = context.allocation_map[i.id] || {} - let taxableAmount = i.quantity * i.unit_price + for (const item of items) { + const allocations = context.allocation_map[item.id] || {} + + const filteredTaxLines = taxLines.filter((tl) => tl.item_id === item.id) + const includesTax = + this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) && item.includes_tax + + let taxableAmount + if (includesTax) { + const taxRate = filteredTaxLines.reduce( + (accRate: number, nextLineItemTaxLine: LineItemTaxLine) => { + return accRate + (nextLineItemTaxLine.rate || 0) / 100 + }, + 0 + ) + const taxIncludedInPrice = Math.round( + calculatePriceTaxAmount({ + price: item.unit_price, + taxRate, + includesTax, + }) + ) + taxableAmount = (item.unit_price - taxIncludedInPrice) * item.quantity + } else { + taxableAmount = item.unit_price * item.quantity + } taxableAmount -= ((allocations.discount && allocations.discount.unit_amount) || 0) * - i.quantity + item.quantity - const lineRates = taxLines.filter((tl) => tl.item_id === i.id) - for (const lineRate of lineRates) { - taxTotal += Math.round(taxableAmount * (lineRate.rate / 100)) + for (const filteredTaxLine of filteredTaxLines) { + taxTotal += Math.round( + calculatePriceTaxAmount({ + price: taxableAmount, + taxRate: filteredTaxLine.rate / 100, + }) + ) } } return taxTotal @@ -54,12 +93,19 @@ class TaxCalculationStrategy implements ITaxCalculationStrategy { shipping_methods: ShippingMethod[], taxLines: ShippingMethodTaxLine[] ): number { + const taxInclusiveEnabled = this.featureFlagRouter_.isFeatureEnabled( + TaxInclusivePricingFeatureFlag.key + ) + let taxTotal = 0 for (const sm of shipping_methods) { - const amount = sm.price const lineRates = taxLines.filter((tl) => tl.shipping_method_id === sm.id) for (const lineRate of lineRates) { - taxTotal += Math.round(amount * (lineRate.rate / 100)) + taxTotal += calculatePriceTaxAmount({ + price: sm.price, + taxRate: lineRate.rate / 100, + includesTax: taxInclusiveEnabled && sm.includes_tax, + }) } } return taxTotal diff --git a/packages/medusa/src/types/currency.ts b/packages/medusa/src/types/currency.ts new file mode 100644 index 0000000000..20ad33361f --- /dev/null +++ b/packages/medusa/src/types/currency.ts @@ -0,0 +1,3 @@ +export type UpdateCurrencyInput = { + includes_tax?: boolean +} diff --git a/packages/medusa/src/types/price-list.ts b/packages/medusa/src/types/price-list.ts index 0a009d831f..3ff5f32640 100644 --- a/packages/medusa/src/types/price-list.ts +++ b/packages/medusa/src/types/price-list.ts @@ -126,12 +126,19 @@ export type CreatePriceListInput = { customer_groups?: { id: string }[] starts_at?: Date ends_at?: Date + includes_tax?: boolean } export type UpdatePriceListInput = Partial< Pick< PriceList, - "name" | "description" | "starts_at" | "ends_at" | "status" | "type" + | "name" + | "description" + | "starts_at" + | "ends_at" + | "status" + | "type" + | "includes_tax" > > & { prices?: AdminPriceListPricesUpdateReq[] diff --git a/packages/medusa/src/types/pricing.ts b/packages/medusa/src/types/pricing.ts index 77118975e6..f3456e397c 100644 --- a/packages/medusa/src/types/pricing.ts +++ b/packages/medusa/src/types/pricing.ts @@ -1,11 +1,13 @@ -import { MoneyAmount, ProductVariant, Product, ShippingOption } from "../models" -import { TaxServiceRate } from "./tax-service" import { PriceSelectionContext } from "../interfaces/price-selection-strategy" +import { MoneyAmount, Product, ProductVariant, ShippingOption } from "../models" +import { TaxServiceRate } from "./tax-service" export type ProductVariantPricing = { prices: MoneyAmount[] original_price: number | null calculated_price: number | null + original_price_includes_tax?: boolean | null + calculated_price_includes_tax?: boolean | null calculated_price_type?: string | null } & TaxedPricing @@ -26,6 +28,7 @@ export type PricingContext = { export type ShippingOptionPricing = { price_incl_tax: number | null tax_rates: TaxServiceRate[] | null + tax_amount: number } export type PricedShippingOption = Partial & diff --git a/packages/medusa/src/types/region.ts b/packages/medusa/src/types/region.ts index 89a2803356..c665634316 100644 --- a/packages/medusa/src/types/region.ts +++ b/packages/medusa/src/types/region.ts @@ -1,6 +1,3 @@ -import { FindConfig } from "./common" -import { Region } from "../models" - export type UpdateRegionInput = { name?: string currency_code?: string @@ -12,6 +9,7 @@ export type UpdateRegionInput = { payment_providers?: string[] fulfillment_providers?: string[] countries?: string[] + includes_tax?: boolean metadata?: Record } @@ -23,5 +21,6 @@ export type CreateRegionInput = { payment_providers: string[] fulfillment_providers: string[] countries: string[] + includes_tax?: boolean metadata?: Record } diff --git a/packages/medusa/src/types/shipping-options.ts b/packages/medusa/src/types/shipping-options.ts index 18c7b8f673..44c29c3856 100644 --- a/packages/medusa/src/types/shipping-options.ts +++ b/packages/medusa/src/types/shipping-options.ts @@ -44,6 +44,7 @@ export type CreateShippingOptionInput = { profile_id: string provider_id: string data: Record + includes_tax?: boolean amount?: number is_return?: boolean @@ -71,4 +72,5 @@ export type UpdateShippingOptionInput = { provider_id?: string profile_id?: string data?: string + includes_tax?: boolean } diff --git a/packages/medusa/src/utils/__tests__/calculate-price-tax-amount.ts b/packages/medusa/src/utils/__tests__/calculate-price-tax-amount.ts new file mode 100644 index 0000000000..06fbd1503c --- /dev/null +++ b/packages/medusa/src/utils/__tests__/calculate-price-tax-amount.ts @@ -0,0 +1,45 @@ +import { calculatePriceTaxAmount } from "../calculate-price-tax-amount" +import { FlagRouter } from "../../utils/flag-router" + +describe("calculatePriceTaxAmount", () => { + describe("Calculate taxes from a given price", () => { + beforeAll(() => { + jest.spyOn(FlagRouter.prototype, "isFeatureEnabled").mockReturnValue(true) + }) + + it("Tax NOT included", () => { + const tax = calculatePriceTaxAmount({ + price: 150, + taxRate: 0.19, + includesTax: false, + }) + + expect(tax).toBeCloseTo(28.5, 2) + + const tax2 = calculatePriceTaxAmount({ + price: 120, + taxRate: 0.17, + }) + + expect(tax2).toBeCloseTo(20.4, 2) + }) + + it("Tax included", () => { + const tax = calculatePriceTaxAmount({ + price: 115, + taxRate: 0.15, + includesTax: true, + }) + + expect(tax).toBeCloseTo(15, 2) + + const tax2 = calculatePriceTaxAmount({ + price: 2150, + taxRate: 0.17, + includesTax: true, + }) + + expect(tax2).toBeCloseTo(312.39, 2) + }) + }) +}) diff --git a/packages/medusa/src/utils/calculate-price-tax-amount.ts b/packages/medusa/src/utils/calculate-price-tax-amount.ts new file mode 100644 index 0000000000..abc9cbb199 --- /dev/null +++ b/packages/medusa/src/utils/calculate-price-tax-amount.ts @@ -0,0 +1,23 @@ +/** + * Return the tax amount that + * - is includes in the price if it is tax inclusive + * - will be applied on to the price if it is tax exclusive + * @param price + * @param includesTax + * @param taxRate + */ +export function calculatePriceTaxAmount({ + price, + includesTax, + taxRate, +}: { + price: number + includesTax?: boolean + taxRate: number +}): number { + if (includesTax) { + return (taxRate * price) / (1 + taxRate) + } + + return price * taxRate +} diff --git a/packages/medusa/src/utils/feature-flag-decorators.ts b/packages/medusa/src/utils/feature-flag-decorators.ts index 7f8a882a5e..6caa082a8c 100644 --- a/packages/medusa/src/utils/feature-flag-decorators.ts +++ b/packages/medusa/src/utils/feature-flag-decorators.ts @@ -1,9 +1,5 @@ - import { getConfigFile } from "medusa-core-utils" import { Column, ColumnOptions, Entity, EntityOptions } from "typeorm" -import featureFlagsLoader from "../loaders/feature-flags" -import path from "path" -import { ConfigModule } from "../types/global" -import { FlagRouter } from "./flag-router" +import { featureFlagRouter } from "../loaders/feature-flags" /** * If that file is required in a non node environment then the setImmediate timer does not exists. @@ -28,8 +24,6 @@ export function FeatureFlagColumn( ): PropertyDecorator { return function (target, propertyName) { setImmediate_((): any => { - const featureFlagRouter = getFeatureFlagRouter() - if (!featureFlagRouter.isFeatureEnabled(featureFlag)) { return } @@ -45,8 +39,6 @@ export function FeatureFlagDecorators( ): PropertyDecorator { return function (target, propertyName) { setImmediate_((): any => { - const featureFlagRouter = getFeatureFlagRouter() - if (!featureFlagRouter.isFeatureEnabled(featureFlag)) { return } @@ -65,19 +57,8 @@ export function FeatureFlagEntity( ): ClassDecorator { return function (target: Function): void { target["isFeatureEnabled"] = function (): boolean { - const featureFlagRouter = getFeatureFlagRouter() - return featureFlagRouter.isFeatureEnabled(featureFlag) } Entity(name, options)(target) } } - -function getFeatureFlagRouter(): FlagRouter { - const { configModule } = getConfigFile( - path.resolve("."), - `medusa-config` - ) as { configModule: ConfigModule } - - return featureFlagsLoader(configModule) -} diff --git a/packages/medusa/src/utils/flag-router.ts b/packages/medusa/src/utils/flag-router.ts index 12bb8f0229..974377221a 100644 --- a/packages/medusa/src/utils/flag-router.ts +++ b/packages/medusa/src/utils/flag-router.ts @@ -1,7 +1,7 @@ import { FeatureFlagsResponse, IFlagRouter } from "../types/feature-flags" export class FlagRouter implements IFlagRouter { - private flags: Record = {} + private readonly flags: Record = {} constructor(flags: Record) { this.flags = flags diff --git a/packages/medusa/src/utils/index.ts b/packages/medusa/src/utils/index.ts index d56c80bac0..e353050941 100644 --- a/packages/medusa/src/utils/index.ts +++ b/packages/medusa/src/utils/index.ts @@ -4,3 +4,4 @@ export * from "./validate-id" export * from "./generate-entity-id" export * from "./remove-undefined-properties" export * from "./is-defined" +export * from "./calculate-price-tax-amount"