From df66378535727152bb329c71c38d614e5b642599 Mon Sep 17 00:00:00 2001 From: Adrien de Peretti Date: Wed, 27 Jul 2022 18:54:05 +0200 Subject: [PATCH] feat(medusa): Attach or update cart sales channel (#1873) What Allow to create a cart with a sales channel, otherwise the default one is attached. Also allow to update the sales channel on an existing cart and in that case the line items that does not belongs to the new sales channel attached are removed How Updating existing end points and service method to integrate the new requirements Tests Add new integration tests Fixes CORE-270 Fixes CORE-272 Co-authored-by: Oliver Windall Juhl <59018053+olivermrbl@users.noreply.github.com> --- .changeset/lovely-news-attend.md | 7 + .../__snapshots__/sales-channels.js.snap | 12 - .../api/__tests__/admin/sales-channels.js | 44 +-- .../__snapshots__/sales-channels.js.snap | 13 + .../store/__snapshots__/swaps.js.snap | 4 +- .../api/__tests__/store/orders.js | 5 - .../api/__tests__/store/sales-channels.js | 298 ++++++++++++++++++ .../api/factories/simple-cart-factory.ts | 2 +- .../factories/simple-sales-channel-factory.ts | 10 +- .../store/carts/__tests__/create-cart.js | 2 +- .../src/api/routes/store/carts/create-cart.ts | 58 ++-- .../src/api/routes/store/carts/index.ts | 16 +- .../src/api/routes/store/carts/update-cart.ts | 36 +-- .../medusa/src/services/__tests__/cart.js | 20 ++ packages/medusa/src/services/cart.ts | 189 ++++++++--- packages/medusa/src/services/product.ts | 39 ++- packages/medusa/src/types/cart.ts | 9 +- 17 files changed, 595 insertions(+), 169 deletions(-) create mode 100644 .changeset/lovely-news-attend.md create mode 100644 integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap create mode 100644 integration-tests/api/__tests__/store/sales-channels.js diff --git a/.changeset/lovely-news-attend.md b/.changeset/lovely-news-attend.md new file mode 100644 index 0000000000..43c9b7dd05 --- /dev/null +++ b/.changeset/lovely-news-attend.md @@ -0,0 +1,7 @@ +--- +"@medusajs/medusa": patch +--- + +Adds support for: +- Attaching Sales Channel to cart as part of creation +- Updating Sales Channel on a cart and removing inapplicable line items diff --git a/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap b/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap index 38b5e2546c..0beb306585 100644 --- a/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap +++ b/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap @@ -110,18 +110,6 @@ Object { } `; -exports[`sales channels GET /store/cart/:id with saleschannel returns cart with sales channel for single cart 1`] = ` -Object { - "created_at": Any, - "deleted_at": null, - "description": "test description", - "id": Any, - "is_disabled": false, - "name": "test name", - "updated_at": Any, -} -`; - exports[`sales channels POST /admin/sales-channels successfully creates a sales channel 1`] = ` Object { "sales_channel": ObjectContaining { diff --git a/integration-tests/api/__tests__/admin/sales-channels.js b/integration-tests/api/__tests__/admin/sales-channels.js index 8c40fdc192..b6b42a45e2 100644 --- a/integration-tests/api/__tests__/admin/sales-channels.js +++ b/integration-tests/api/__tests__/admin/sales-channels.js @@ -311,11 +311,8 @@ describe("sales channels", () => { await simpleSalesChannelFactory(dbConnection, { name: "Default channel", id: "test-channel", + is_default: true, }) - - await dbConnection.manager.query( - `UPDATE store SET default_sales_channel_id = 'test-channel'` - ) } catch (e) { console.error(e) } @@ -620,45 +617,6 @@ describe("sales channels", () => { }) }) - describe("GET /store/cart/:id with saleschannel", () => { - let cart - beforeEach(async () => { - try { - await adminSeeder(dbConnection) - - cart = await simpleCartFactory(dbConnection, { - sales_channel: { - name: "test name", - description: "test description", - }, - }) - } catch (err) { - console.log(err) - } - }) - - afterEach(async () => { - const db = useDb() - await db.teardown() - }) - - it("returns cart with sales channel for single cart", async () => { - const api = useApi() - - const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) - - expect(response.data.cart.sales_channel).toBeTruthy() - expect(response.data.cart.sales_channel).toMatchSnapshot({ - id: expect.any(String), - name: "test name", - description: "test description", - is_disabled: false, - created_at: expect.any(String), - updated_at: expect.any(String), - }) - }) - }) - describe("DELETE /admin/sales-channels/:id/products/batch", () => { let salesChannel let product diff --git a/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap b/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap new file mode 100644 index 0000000000..24ecef7c7f --- /dev/null +++ b/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap @@ -0,0 +1,13 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`sales channels GET /store/cart/:id returns cart with sales channel for single cart 1`] = ` +Object { + "created_at": Any, + "deleted_at": null, + "description": "test description", + "id": Any, + "is_disabled": false, + "name": "test name", + "updated_at": Any, +} +`; diff --git a/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap b/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap index 654c95ac5e..f1397e5121 100644 --- a/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap +++ b/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap @@ -83,7 +83,7 @@ Object { "cart": Object { "billing_address_id": "test-billing-address", "completed_at": null, - "context": null, + "context": Object {}, "created_at": Any, "customer_id": "test-customer", "deleted_at": null, @@ -259,7 +259,7 @@ Object { "cart": Object { "billing_address_id": "test-billing-address", "completed_at": null, - "context": null, + "context": Object {}, "created_at": Any, "customer_id": "test-customer", "deleted_at": null, diff --git a/integration-tests/api/__tests__/store/orders.js b/integration-tests/api/__tests__/store/orders.js index 12112307f9..c6b6922d32 100644 --- a/integration-tests/api/__tests__/store/orders.js +++ b/integration-tests/api/__tests__/store/orders.js @@ -6,12 +6,7 @@ const { ShippingProfile, Product, ProductVariant, - MoneyAmount, LineItem, - Payment, - Cart, - ShippingMethod, - Swap, } = require("@medusajs/medusa") const setupServer = require("../../../helpers/setup-server") diff --git a/integration-tests/api/__tests__/store/sales-channels.js b/integration-tests/api/__tests__/store/sales-channels.js new file mode 100644 index 0000000000..3c641f202c --- /dev/null +++ b/integration-tests/api/__tests__/store/sales-channels.js @@ -0,0 +1,298 @@ +const path = require("path") + +const { useApi } = require("../../../helpers/use-api") +const { useDb } = require("../../../helpers/use-db") + +const adminSeeder = require("../../helpers/admin-seeder") +const { + simpleSalesChannelFactory, + simpleCartFactory, simpleRegionFactory, simpleProductFactory, +} = require("../../factories") + +const startServerWithEnvironment = + require("../../../helpers/start-server-with-environment").default + +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} + +jest.setTimeout(50000) + +describe("sales channels", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_SALES_CHANNELS: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /store/cart/", () => { + let salesChannel + let disabledSalesChannel + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await simpleRegionFactory(dbConnection, { + name: "Test region", + tax_rate: 0, + }) + await simpleSalesChannelFactory(dbConnection, { + name: "Default Sales Channel", + description: "Created by Medusa", + is_default: true + }) + disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "disabled cart sales channel", + description: "disabled cart sales channel description", + is_disabled: true, + }) + salesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "cart sales channel", + description: "cart sales channel description", + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("returns a cart with the default sales channel", async () => { + const api = useApi() + + const response = await api.post(`/store/carts`, {}, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: "Default Sales Channel", + description: "Created by Medusa", + }) + ) + }) + + it("returns a cart with the given sales channel", async () => { + const api = useApi() + + const response = await api.post(`/store/carts`, { sales_channel_id: salesChannel.id }, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel.name, + description: salesChannel.description, + }) + ) + }) + + it("throw if the given sales channel is disabled", async () => { + const api = useApi() + + const err = await api.post( + `/store/carts`, + { sales_channel_id: disabledSalesChannel.id }, + adminReqConfig + ).catch(err => err) + + expect(err.response.status).toEqual(400) + expect(err.response.data.message).toBe(`Unable to assign the cart to a disabled Sales Channel "disabled cart sales channel"`) + }) + }) + + describe("POST /store/cart/:id", () => { + let salesChannel1, salesChannel2, disabledSalesChannel + let product1, product2 + let cart + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await simpleRegionFactory(dbConnection, { + name: "Test region", + currency_code: "usd", + tax_rate: 0, + }) + + salesChannel1 = await simpleSalesChannelFactory(dbConnection, { + name: "salesChannel1", + description: "salesChannel1", + }) + salesChannel2 = await simpleSalesChannelFactory(dbConnection, { + name: "salesChannel2", + description: "salesChannel2", + }) + disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "disabled cart sales channel", + description: "disabled cart sales channel description", + is_disabled: true, + }) + + product1 = await simpleProductFactory( + dbConnection, + { + title: "prod 1", + sales_channels: [salesChannel1], + variants: [ + { + id: "test-variant", + prices: [ + { + amount: 50, + currency: "usd", + variant_id: "test-variant", + }, + ], + }, + ], + }, + ) + product2 = await simpleProductFactory( + dbConnection, + { + sales_channels: [salesChannel2], + variants: [ + { + id: "test-variant-2", + prices: [ + { + amount: 100, + currency: "usd", + variant_id: "test-variant-2", + }, + ], + }, + ], + }, + ) + + cart = await simpleCartFactory( + dbConnection, + { + sales_channel: salesChannel1, + line_items: [ + { + variant_id: "test-variant", + unit_price: 50, + }, + ], + }, + ) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it( + "updates a cart sales channels should remove the items that does not belongs to the new sales channel", + async () => { + const api = useApi() + + let response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel1.name, + description: salesChannel1.description, + }) + ) + expect(response.data.cart.items.length).toBe(1) + expect(response.data.cart.items[0].variant.product).toEqual( + expect.objectContaining({ + id: product1.id, + title: product1.title, + }) + ) + + response = await api.post(`/store/carts/${cart.id}`, { sales_channel_id: salesChannel2.id }, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel2.name, + description: salesChannel2.description, + }) + ) + expect(response.data.cart.items.length).toBe(0) + } + ) + + it("throw if the given sales channel is disabled", async () => { + const api = useApi() + + const err = await api.post( + `/store/carts/${cart.id}`, + { sales_channel_id: disabledSalesChannel.id }, + adminReqConfig + ).catch(err => err) + + expect(err.response.status).toEqual(400) + expect(err.response.data.message).toBe("Unable to assign the cart to a disabled Sales Channel \"disabled cart sales channel\"") + }) + }) + + describe("GET /store/cart/:id", () => { + let cart + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + + cart = await simpleCartFactory(dbConnection, { + sales_channel: { + name: "test name", + description: "test description", + }, + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("returns cart with sales channel for single cart", async () => { + const api = useApi() + + const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toMatchSnapshot({ + id: expect.any(String), + name: "test name", + description: "test description", + is_disabled: false, + created_at: expect.any(String), + updated_at: expect.any(String), + }) + }) + }) +}) diff --git a/integration-tests/api/factories/simple-cart-factory.ts b/integration-tests/api/factories/simple-cart-factory.ts index 49d353e772..742d4c84b4 100644 --- a/integration-tests/api/factories/simple-cart-factory.ts +++ b/integration-tests/api/factories/simple-cart-factory.ts @@ -34,7 +34,7 @@ export type CartFactoryData = { export const simpleCartFactory = async ( connection: Connection, data: CartFactoryData = {}, - seed: number + seed?: number ): Promise => { if (typeof seed !== "undefined") { faker.seed(seed) diff --git a/integration-tests/api/factories/simple-sales-channel-factory.ts b/integration-tests/api/factories/simple-sales-channel-factory.ts index cf7f0f88c3..095efc7c8e 100644 --- a/integration-tests/api/factories/simple-sales-channel-factory.ts +++ b/integration-tests/api/factories/simple-sales-channel-factory.ts @@ -8,6 +8,7 @@ export type SalesChannelFactoryData = { is_disabled?: boolean id?: string products?: Product[], + is_default?: boolean } export const simpleSalesChannelFactory = async ( @@ -36,12 +37,19 @@ export const simpleSalesChannelFactory = async ( for (const product of data.products) { promises.push( manager.query(` - INSERT INTO product_sales_channel (product_id, sales_channel_id) VALUES ('${product.id}', '${salesChannel.id}'); + INSERT INTO product_sales_channel (product_id, sales_channel_id) + VALUES ('${product.id}', '${salesChannel.id}'); `) ) } await Promise.all(promises) } + if (data.is_default) { + await manager.query( + `UPDATE store SET default_sales_channel_id = '${salesChannel.id}'` + ) + } + return salesChannel } diff --git a/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js b/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js index 203bebcb51..4b0735e710 100644 --- a/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js +++ b/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js @@ -25,12 +25,12 @@ describe("POST /store/carts", () => { it("calls CartService create", () => { expect(CartServiceMock.create).toHaveBeenCalledTimes(1) expect(CartServiceMock.create).toHaveBeenCalledWith({ - region_id: IdMap.getId("testRegion"), context: { ip: "::ffff:127.0.0.1", user_agent: "node-superagent/3.8.3", clientId: "test", }, + region_id: IdMap.getId("testRegion"), }) }) diff --git a/packages/medusa/src/api/routes/store/carts/create-cart.ts b/packages/medusa/src/api/routes/store/carts/create-cart.ts index 4fdaa9d9c8..1bb8e54f8c 100644 --- a/packages/medusa/src/api/routes/store/carts/create-cart.ts +++ b/packages/medusa/src/api/routes/store/carts/create-cart.ts @@ -11,11 +11,11 @@ import { MedusaError } from "medusa-core-utils" import reqIp from "request-ip" import { EntityManager } from "typeorm" -import { defaultStoreCartFields, defaultStoreCartRelations } from "." -import { CartService, LineItemService } from "../../../../services" -import { validator } from "../../../../utils/validator" -import { AddressPayload } from "../../../../types/common" +import { defaultStoreCartFields, defaultStoreCartRelations, } from "." +import { CartService, LineItemService, RegionService } from "../../../../services" import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" +import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels"; +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators"; /** * @oas [post] /carts @@ -33,6 +33,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * region_id: * type: string * description: The id of the Region to create the Cart in. + * sales_channel_id: + * type: string + * description: [EXPERIMENTAL] The id of the Sales channel to create the Cart in. * country_code: * type: string * description: "The 2 character ISO country code to create the Cart in." @@ -63,7 +66,7 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * $ref: "#/components/schemas/cart" */ export default async (req, res) => { - const validated = await validator(StorePostCartReq, req.body) + const validated = req.validatedBody as StorePostCartReq const reqContext = { ip: reqIp.getClientIp(req), @@ -72,18 +75,17 @@ export default async (req, res) => { const lineItemService: LineItemService = req.scope.resolve("lineItemService") const cartService: CartService = req.scope.resolve("cartService") - + const regionService: RegionService = req.scope.resolve("regionService") const entityManager: EntityManager = req.scope.resolve("manager") await entityManager.transaction(async (manager) => { - // Add a default region if no region has been specified let regionId: string - if (typeof validated.region_id !== "undefined") { regionId = validated.region_id } else { - const regionService = req.scope.resolve("regionService") - const regions = await regionService.withTransaction(manager).list({}) + const regions = await regionService + .withTransaction(manager) + .list({}) if (!regions?.length) { throw new MedusaError( @@ -95,36 +97,15 @@ export default async (req, res) => { regionId = regions[0].id } - const toCreate: { - region_id: string - context: object - customer_id?: string - email?: string - shipping_address?: Partial - } = { - region_id: regionId, + let cart = await cartService.withTransaction(manager).create({ + ...validated, context: { ...reqContext, ...validated.context, }, - } + region_id: regionId, + }) - if (req.user && req.user.customer_id) { - const customerService = req.scope.resolve("customerService") - const customer = await customerService - .withTransaction(manager) - .retrieve(req.user.customer_id) - toCreate["customer_id"] = customer.id - toCreate["email"] = customer.email - } - - if (validated.country_code) { - toCreate["shipping_address"] = { - country_code: validated.country_code.toLowerCase(), - } - } - - let cart = await cartService.withTransaction(manager).create(toCreate) if (validated.items) { await Promise.all( validated.items.map(async (i) => { @@ -160,6 +141,7 @@ export class Item { @IsInt() quantity: number } + export class StorePostCartReq { @IsOptional() @IsString() @@ -177,4 +159,10 @@ export class StorePostCartReq { @IsOptional() context?: object + + @FeatureFlagDecorators(SalesChannelFeatureFlag.key, [ + IsString(), + IsOptional(), + ]) + sales_channel_id?: string } diff --git a/packages/medusa/src/api/routes/store/carts/index.ts b/packages/medusa/src/api/routes/store/carts/index.ts index ccecd4db8f..f92727b911 100644 --- a/packages/medusa/src/api/routes/store/carts/index.ts +++ b/packages/medusa/src/api/routes/store/carts/index.ts @@ -2,7 +2,9 @@ import { Router } from "express" import "reflect-metadata" import { Cart, Order, Swap } from "../../../../" import { DeleteResponse, EmptyQueryParams } from "../../../../types/common" -import middlewares, { transformQuery } from "../../../middlewares" +import middlewares, { transformBody, transformQuery } from "../../../middlewares" +import { StorePostCartsCartReq } from "./update-cart"; +import { StorePostCartReq } from "./create-cart"; const route = Router() export default (app, container) => { @@ -11,9 +13,8 @@ export default (app, container) => { app.use("/carts", route) - const relations = [...defaultStoreCartRelations] if (featureFlagRouter.isFeatureEnabled("sales_channels")) { - relations.push("sales_channel") + defaultStoreCartRelations.push("sales_channel") } // Inject plugin routes @@ -25,7 +26,7 @@ export default (app, container) => { route.get( "/:id", transformQuery(EmptyQueryParams, { - defaultRelations: relations, + defaultRelations: defaultStoreCartRelations, defaultFields: defaultStoreCartFields, isList: false, }), @@ -35,10 +36,15 @@ export default (app, container) => { route.post( "/", middlewareService.usePreCartCreation(), + transformBody(StorePostCartReq), middlewares.wrap(require("./create-cart").default) ) - route.post("/:id", middlewares.wrap(require("./update-cart").default)) + route.post( + "/:id", + transformBody(StorePostCartsCartReq), + middlewares.wrap(require("./update-cart").default) + ) route.post( "/:id/complete", diff --git a/packages/medusa/src/api/routes/store/carts/update-cart.ts b/packages/medusa/src/api/routes/store/carts/update-cart.ts index 9459d712da..8803681436 100644 --- a/packages/medusa/src/api/routes/store/carts/update-cart.ts +++ b/packages/medusa/src/api/routes/store/carts/update-cart.ts @@ -8,11 +8,11 @@ import { } from "class-validator" import { defaultStoreCartFields, defaultStoreCartRelations } from "." import { CartService } from "../../../../services" -import { CartUpdateProps } from "../../../../types/cart" import { AddressPayload } from "../../../../types/common" -import { validator } from "../../../../utils/validator" import { IsType } from "../../../../utils/validators/is-type" import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators"; +import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels"; /** * @oas [post] /store/carts/{id} @@ -35,6 +35,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * email: * type: string * description: "An email to be used on the Cart." + * sales_channel_id: + * type: string + * description: The id of the Sales channel to update the Cart with. * billing_address: * description: "The Address to be used for billing purposes." * anyOf: @@ -83,30 +86,11 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" */ export default async (req, res) => { const { id } = req.params - - const validated = await validator(StorePostCartsCartReq, req.body) + const validated = req.validatedBody as StorePostCartsCartReq const cartService: CartService = req.scope.resolve("cartService") + await cartService.update(id, validated) - // Update the cart - const { shipping_address, billing_address, ...rest } = validated - - const cartDataToUpdate: CartUpdateProps = { ...rest } - if (typeof shipping_address === "string") { - cartDataToUpdate.shipping_address_id = shipping_address - } else { - cartDataToUpdate.shipping_address = shipping_address - } - - if (typeof billing_address === "string") { - cartDataToUpdate.billing_address_id = billing_address - } else { - cartDataToUpdate.billing_address = billing_address - } - - await cartService.update(id, cartDataToUpdate) - - // If the cart has payment sessions update these const updated = await cartService.retrieve(id, { relations: ["payment_sessions", "shipping_methods"], }) @@ -173,4 +157,10 @@ export class StorePostCartsCartReq { @IsOptional() context?: object + + @FeatureFlagDecorators(SalesChannelFeatureFlag.key, [ + IsString(), + IsOptional(), + ]) + sales_channel_id?: string } diff --git a/packages/medusa/src/services/__tests__/cart.js b/packages/medusa/src/services/__tests__/cart.js index a8b99b5d03..f134cb98aa 100644 --- a/packages/medusa/src/services/__tests__/cart.js +++ b/packages/medusa/src/services/__tests__/cart.js @@ -4,6 +4,7 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" import CartService from "../cart" import { InventoryServiceMock } from "../__mocks__/inventory" import { LineItemAdjustmentServiceMock } from "../__mocks__/line-item-adjustment" +import { FlagRouter } from "../../utils/flag-router"; const eventBusService = { emit: jest.fn(), @@ -46,6 +47,7 @@ describe("CartService", () => { manager: MockManager, totalsService, cartRepository, + featureFlagRouter: new FlagRouter({}), }) result = await cartService.retrieve(IdMap.getId("emptyCart")) }) @@ -76,6 +78,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -136,6 +139,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -239,6 +243,7 @@ describe("CartService", () => { customerService, regionService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -271,6 +276,7 @@ describe("CartService", () => { customer_id: IdMap.getId("customer"), email: "email@test.com", customer: expect.any(Object), + context: expect.any(Object), }) expect(cartRepository.save).toHaveBeenCalledTimes(1) @@ -315,6 +321,7 @@ describe("CartService", () => { expect(cartRepository.create).toHaveBeenCalledTimes(1) expect(cartRepository.create).toHaveBeenCalledWith({ + context: {}, region_id: IdMap.getId("testRegion"), shipping_address: { first_name: "LeBron", @@ -400,6 +407,7 @@ describe("CartService", () => { shippingOptionService, inventoryService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -614,6 +622,7 @@ describe("CartService", () => { shippingOptionService, eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -719,6 +728,7 @@ describe("CartService", () => { cartRepository, totalsService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -806,6 +816,7 @@ describe("CartService", () => { eventBusService, inventoryService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -887,6 +898,7 @@ describe("CartService", () => { cartRepository, eventBusService, customerService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -967,6 +979,7 @@ describe("CartService", () => { cartRepository, addressRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1028,6 +1041,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1182,6 +1196,7 @@ describe("CartService", () => { eventBusService, paymentSessionRepository: MockRepository(), priceSelectionStrategy: priceSelectionStrat, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1269,6 +1284,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1383,6 +1399,7 @@ describe("CartService", () => { cartRepository, paymentProviderService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1573,6 +1590,7 @@ describe("CartService", () => { lineItemService, eventBusService, customShippingOptionService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1927,6 +1945,7 @@ describe("CartService", () => { discountService, eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { @@ -2214,6 +2233,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { diff --git a/packages/medusa/src/services/cart.ts b/packages/medusa/src/services/cart.ts index 9ef89f349d..24da409701 100644 --- a/packages/medusa/src/services/cart.ts +++ b/packages/medusa/src/services/cart.ts @@ -3,14 +3,18 @@ import { MedusaError, Validator } from "medusa-core-utils" import { DeepPartial, EntityManager, In } from "typeorm" import { TransactionBaseService } from "../interfaces" import { IPriceSelectionStrategy } from "../interfaces/price-selection-strategy" -import { DiscountRuleType } from "../models" -import { Address } from "../models/address" -import { Cart } from "../models/cart" -import { CustomShippingOption } from "../models/custom-shipping-option" -import { Customer } from "../models/customer" -import { Discount } from "../models/discount" -import { LineItem } from "../models/line-item" -import { ShippingMethod } from "../models/shipping-method" +import { + DiscountRuleType, + Address, + Cart, + CustomShippingOption, + Customer, + Discount, + LineItem, + ShippingMethod, + User, + SalesChannel, +} from "../models" import { AddressRepository } from "../repositories/address" import { CartRepository } from "../repositories/cart" import { LineItemRepository } from "../repositories/line-item" @@ -39,6 +43,10 @@ import RegionService from "./region" import ShippingOptionService from "./shipping-option" import TaxProviderService from "./tax-provider" import TotalsService from "./totals" +import SalesChannelFeatureFlag from "../loaders/feature-flags/sales-channels" +import { FlagRouter } from "../utils/flag-router" +import SalesChannelService from "./sales-channel" +import StoreService from "./store" type InjectedDependencies = { manager: EntityManager @@ -48,9 +56,12 @@ type InjectedDependencies = { paymentSessionRepository: typeof PaymentSessionRepository lineItemRepository: typeof LineItemRepository eventBusService: EventBusService + salesChannelService: SalesChannelService taxProviderService: TaxProviderService paymentProviderService: PaymentProviderService productService: ProductService + storeService: StoreService + featureFlagRouter: FlagRouter productVariantService: ProductVariantService regionService: RegionService lineItemService: LineItemService @@ -90,6 +101,9 @@ class CartService extends TransactionBaseService { protected readonly eventBus_: EventBusService protected readonly productVariantService_: ProductVariantService protected readonly productService_: ProductService + protected readonly featureFlagRouter_: FlagRouter + protected readonly storeService_: StoreService + protected readonly salesChannelService_: SalesChannelService protected readonly regionService_: RegionService protected readonly lineItemService_: LineItemService protected readonly paymentProviderService_: PaymentProviderService @@ -127,6 +141,9 @@ class CartService extends TransactionBaseService { customShippingOptionService, lineItemAdjustmentService, priceSelectionStrategy, + salesChannelService, + featureFlagRouter, + storeService, }: InjectedDependencies) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -153,6 +170,9 @@ class CartService extends TransactionBaseService { this.taxProviderService_ = taxProviderService this.lineItemAdjustmentService_ = lineItemAdjustmentService this.priceSelectionStrategy_ = priceSelectionStrategy + this.salesChannelService_ = salesChannelService + this.featureFlagRouter_ = featureFlagRouter + this.storeService_ = storeService } protected transformQueryForTotals_( @@ -331,15 +351,17 @@ class CartService extends TransactionBaseService { this.addressRepository_ ) - const { region_id } = data - if (!region_id) { - throw new MedusaError( - MedusaError.Types.INVALID_DATA, - `A region_id must be provided when creating a cart` - ) + const rawCart: DeepPartial = { + context: data.context ?? {}, } - const rawCart: DeepPartial = {} + if ( + this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key) + ) { + rawCart.sales_channel_id = ( + await this.getValidatedSalesChannel(data.sales_channel_id) + ).id + } if (data.email) { const customer = await this.createOrFetchUserFromEmail_(data.email) @@ -348,15 +370,21 @@ class CartService extends TransactionBaseService { rawCart.email = customer.email } + if (!data.region_id) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `A region_id must be provided when creating a cart` + ) + } + + rawCart.region_id = data.region_id const region = await this.regionService_ .withTransaction(transactionManager) - .retrieve(region_id, { + .retrieve(data.region_id, { relations: ["countries"], }) const regCountries = region.countries.map(({ iso_2 }) => iso_2) - rawCart.region_id = region.id - if (!data.shipping_address && !data.shipping_address_id) { if (region.countries.length === 1) { rawCart.shipping_address = addressRepo.create({ @@ -399,10 +427,8 @@ class CartService extends TransactionBaseService { typeof data[remainingField] !== "undefined" && remainingField !== "object" ) { - /* TODO: See how to fix the error TS2590 properly while keeping the DeepPartial type */ - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - rawCart[remainingField] = data[remainingField] + const key = remainingField as string + rawCart[key] = data[remainingField] } } @@ -418,6 +444,32 @@ class CartService extends TransactionBaseService { ) } + protected async getValidatedSalesChannel( + salesChannelId?: string + ): Promise { + let salesChannel: SalesChannel + if (typeof salesChannelId !== "undefined") { + salesChannel = await this.salesChannelService_ + .withTransaction(this.manager_) + .retrieve(salesChannelId) + } else { + salesChannel = ( + await this.storeService_.withTransaction(this.manager_).retrieve({ + relations: ["default_sales_channel"], + }) + ).default_sales_channel + } + + if (salesChannel.is_disabled) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `Unable to assign the cart to a disabled Sales Channel "${salesChannel.name}"` + ) + } + + return salesChannel + } + /** * Removes a line item from the cart. * @param cartId - the id of the cart that we will remove from @@ -721,6 +773,30 @@ class CartService extends TransactionBaseService { const cartRepo = transactionManager.getCustomRepository( this.cartRepository_ ) + const relations = [ + "items", + "shipping_methods", + "shipping_address", + "billing_address", + "gift_cards", + "customer", + "region", + "payment_sessions", + "region.countries", + "discounts", + "discounts.rule", + "discounts.regions", + ] + + if ( + this.featureFlagRouter_.isFeatureEnabled( + SalesChannelFeatureFlag.key + ) && + data.sales_channel_id + ) { + relations.push("items.variant", "items.variant.product") + } + const cart = await this.retrieve(cartId, { select: [ "subtotal", @@ -729,20 +805,7 @@ class CartService extends TransactionBaseService { "discount_total", "total", ], - relations: [ - "items", - "shipping_methods", - "shipping_address", - "billing_address", - "gift_cards", - "customer", - "region", - "payment_sessions", - "region.countries", - "discounts", - "discounts.rule", - "discounts.regions", - ], + relations, }) if (data.customer_id) { @@ -764,8 +827,12 @@ class CartService extends TransactionBaseService { } if (typeof data.region_id !== "undefined") { + const shippingAddress = + typeof data.shipping_address !== "string" + ? data.shipping_address + : {} const countryCode = - (data.country_code || data.shipping_address?.country_code) ?? null + (data.country_code || shippingAddress?.country_code) ?? null await this.setRegion_(cart, data.region_id, countryCode) } @@ -784,6 +851,18 @@ class CartService extends TransactionBaseService { await this.updateShippingAddress_(cart, shippingAddress, addrRepo) } + if ( + this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key) + ) { + if ( + typeof data.sales_channel_id !== "undefined" && + data.sales_channel_id != cart.sales_channel_id + ) { + await this.onSalesChannelChange(cart, data.sales_channel_id) + cart.sales_channel_id = data.sales_channel_id + } + } + if (typeof data.discounts !== "undefined") { const previousDiscounts = [...cart.discounts] cart.discounts.length = 0 @@ -861,6 +940,42 @@ class CartService extends TransactionBaseService { ) } + /** + * Remove the cart line item that does not belongs to the newly assigned sales channel + * @param cart - The cart being updated + * @param newSalesChannelId - The new sales channel being assigned to the cart + * @protected + */ + protected async onSalesChannelChange( + cart: Cart, + newSalesChannelId: string + ): Promise { + await this.getValidatedSalesChannel(newSalesChannelId) + + const productIds = cart.items.map((item) => item.variant.product_id) + const productsToKeep = await this.productService_ + .withTransaction(this.manager_) + .filterProductsBySalesChannel(productIds, newSalesChannelId, { + select: ["id", "sales_channels"], + take: productIds.length, + }) + const productIdsToKeep = new Set( + productsToKeep.map((product) => product.id) + ) + const itemsToRemove = cart.items.filter((item) => { + return !productIdsToKeep.has(item.variant.product_id) + }) + + if (itemsToRemove.length) { + const results = await Promise.all( + itemsToRemove.map((item) => { + return this.removeLineItem(cart.id, item.id) + }) + ) + cart.items = results.pop()?.items ?? [] + } + } + /** * Sets the customer id of a cart * @param cart - the cart to add email to diff --git a/packages/medusa/src/services/product.ts b/packages/medusa/src/services/product.ts index 33b1b2e467..1674632c20 100644 --- a/packages/medusa/src/services/product.ts +++ b/packages/medusa/src/services/product.ts @@ -2,7 +2,13 @@ import { MedusaError } from "medusa-core-utils" import { EntityManager } from "typeorm" import { SearchService } from "." import { TransactionBaseService } from "../interfaces" -import { Product, ProductTag, ProductType, ProductVariant } from "../models" +import { + Product, + ProductTag, + ProductType, + ProductVariant, + SalesChannel, +} from "../models" import { ImageRepository } from "../repositories/image" import { FindWithoutRelationsOptions, @@ -284,6 +290,37 @@ class ProductService extends TransactionBaseService { return product.variants } + async filterProductsBySalesChannel( + productIds: string[], + salesChannelId: string, + config: FindProductConfig = { + skip: 0, + take: 50, + } + ): Promise { + const givenRelations = config.relations ?? [] + const requiredRelations = ["sales_channels"] + const relationsSet = new Set([...givenRelations, ...requiredRelations]) + + const products = await this.list( + { + id: productIds, + }, + { + ...config, + relations: [...relationsSet], + } + ) + const productSalesChannelsMap = new Map( + products.map((product) => [product.id, product.sales_channels]) + ) + return products.filter((product) => { + return productSalesChannelsMap + .get(product.id) + ?.some((sc) => sc.id === salesChannelId) + }) + } + async listTypes(): Promise { const manager = this.manager_ const productTypeRepository = manager.getCustomRepository( diff --git a/packages/medusa/src/types/cart.ts b/packages/medusa/src/types/cart.ts index 624ea605f4..099eff88f0 100644 --- a/packages/medusa/src/types/cart.ts +++ b/packages/medusa/src/types/cart.ts @@ -43,7 +43,7 @@ class Discount { } export type CartCreateProps = { - region_id: string + region_id?: string email?: string billing_address_id?: string billing_address?: Partial @@ -55,6 +55,8 @@ export type CartCreateProps = { type?: CartType context?: object metadata?: object + sales_channel_id?: string + country_code?: string } export type CartUpdateProps = { @@ -63,8 +65,8 @@ export type CartUpdateProps = { email?: string shipping_address_id?: string billing_address_id?: string - billing_address?: AddressPayload - shipping_address?: AddressPayload + billing_address?: AddressPayload | string + shipping_address?: AddressPayload | string completed_at?: Date payment_authorized_at?: Date gift_cards?: GiftCard[] @@ -72,4 +74,5 @@ export type CartUpdateProps = { customer_id?: string context?: object metadata?: Record + sales_channel_id?: string }