diff --git a/.changeset/lovely-news-attend.md b/.changeset/lovely-news-attend.md new file mode 100644 index 0000000000..43c9b7dd05 --- /dev/null +++ b/.changeset/lovely-news-attend.md @@ -0,0 +1,7 @@ +--- +"@medusajs/medusa": patch +--- + +Adds support for: +- Attaching Sales Channel to cart as part of creation +- Updating Sales Channel on a cart and removing inapplicable line items diff --git a/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap b/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap index 38b5e2546c..0beb306585 100644 --- a/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap +++ b/integration-tests/api/__tests__/admin/__snapshots__/sales-channels.js.snap @@ -110,18 +110,6 @@ Object { } `; -exports[`sales channels GET /store/cart/:id with saleschannel returns cart with sales channel for single cart 1`] = ` -Object { - "created_at": Any, - "deleted_at": null, - "description": "test description", - "id": Any, - "is_disabled": false, - "name": "test name", - "updated_at": Any, -} -`; - exports[`sales channels POST /admin/sales-channels successfully creates a sales channel 1`] = ` Object { "sales_channel": ObjectContaining { diff --git a/integration-tests/api/__tests__/admin/sales-channels.js b/integration-tests/api/__tests__/admin/sales-channels.js index 8c40fdc192..b6b42a45e2 100644 --- a/integration-tests/api/__tests__/admin/sales-channels.js +++ b/integration-tests/api/__tests__/admin/sales-channels.js @@ -311,11 +311,8 @@ describe("sales channels", () => { await simpleSalesChannelFactory(dbConnection, { name: "Default channel", id: "test-channel", + is_default: true, }) - - await dbConnection.manager.query( - `UPDATE store SET default_sales_channel_id = 'test-channel'` - ) } catch (e) { console.error(e) } @@ -620,45 +617,6 @@ describe("sales channels", () => { }) }) - describe("GET /store/cart/:id with saleschannel", () => { - let cart - beforeEach(async () => { - try { - await adminSeeder(dbConnection) - - cart = await simpleCartFactory(dbConnection, { - sales_channel: { - name: "test name", - description: "test description", - }, - }) - } catch (err) { - console.log(err) - } - }) - - afterEach(async () => { - const db = useDb() - await db.teardown() - }) - - it("returns cart with sales channel for single cart", async () => { - const api = useApi() - - const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) - - expect(response.data.cart.sales_channel).toBeTruthy() - expect(response.data.cart.sales_channel).toMatchSnapshot({ - id: expect.any(String), - name: "test name", - description: "test description", - is_disabled: false, - created_at: expect.any(String), - updated_at: expect.any(String), - }) - }) - }) - describe("DELETE /admin/sales-channels/:id/products/batch", () => { let salesChannel let product diff --git a/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap b/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap new file mode 100644 index 0000000000..24ecef7c7f --- /dev/null +++ b/integration-tests/api/__tests__/store/__snapshots__/sales-channels.js.snap @@ -0,0 +1,13 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`sales channels GET /store/cart/:id returns cart with sales channel for single cart 1`] = ` +Object { + "created_at": Any, + "deleted_at": null, + "description": "test description", + "id": Any, + "is_disabled": false, + "name": "test name", + "updated_at": Any, +} +`; diff --git a/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap b/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap index 654c95ac5e..f1397e5121 100644 --- a/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap +++ b/integration-tests/api/__tests__/store/__snapshots__/swaps.js.snap @@ -83,7 +83,7 @@ Object { "cart": Object { "billing_address_id": "test-billing-address", "completed_at": null, - "context": null, + "context": Object {}, "created_at": Any, "customer_id": "test-customer", "deleted_at": null, @@ -259,7 +259,7 @@ Object { "cart": Object { "billing_address_id": "test-billing-address", "completed_at": null, - "context": null, + "context": Object {}, "created_at": Any, "customer_id": "test-customer", "deleted_at": null, diff --git a/integration-tests/api/__tests__/store/orders.js b/integration-tests/api/__tests__/store/orders.js index 12112307f9..c6b6922d32 100644 --- a/integration-tests/api/__tests__/store/orders.js +++ b/integration-tests/api/__tests__/store/orders.js @@ -6,12 +6,7 @@ const { ShippingProfile, Product, ProductVariant, - MoneyAmount, LineItem, - Payment, - Cart, - ShippingMethod, - Swap, } = require("@medusajs/medusa") const setupServer = require("../../../helpers/setup-server") diff --git a/integration-tests/api/__tests__/store/sales-channels.js b/integration-tests/api/__tests__/store/sales-channels.js new file mode 100644 index 0000000000..3c641f202c --- /dev/null +++ b/integration-tests/api/__tests__/store/sales-channels.js @@ -0,0 +1,298 @@ +const path = require("path") + +const { useApi } = require("../../../helpers/use-api") +const { useDb } = require("../../../helpers/use-db") + +const adminSeeder = require("../../helpers/admin-seeder") +const { + simpleSalesChannelFactory, + simpleCartFactory, simpleRegionFactory, simpleProductFactory, +} = require("../../factories") + +const startServerWithEnvironment = + require("../../../helpers/start-server-with-environment").default + +const adminReqConfig = { + headers: { + Authorization: "Bearer test_token", + }, +} + +jest.setTimeout(50000) + +describe("sales channels", () => { + let medusaProcess + let dbConnection + + beforeAll(async () => { + const cwd = path.resolve(path.join(__dirname, "..", "..")) + const [process, connection] = await startServerWithEnvironment({ + cwd, + env: { MEDUSA_FF_SALES_CHANNELS: true }, + verbose: false, + }) + dbConnection = connection + medusaProcess = process + }) + + afterAll(async () => { + const db = useDb() + await db.shutdown() + + medusaProcess.kill() + }) + + describe("POST /store/cart/", () => { + let salesChannel + let disabledSalesChannel + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await simpleRegionFactory(dbConnection, { + name: "Test region", + tax_rate: 0, + }) + await simpleSalesChannelFactory(dbConnection, { + name: "Default Sales Channel", + description: "Created by Medusa", + is_default: true + }) + disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "disabled cart sales channel", + description: "disabled cart sales channel description", + is_disabled: true, + }) + salesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "cart sales channel", + description: "cart sales channel description", + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("returns a cart with the default sales channel", async () => { + const api = useApi() + + const response = await api.post(`/store/carts`, {}, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: "Default Sales Channel", + description: "Created by Medusa", + }) + ) + }) + + it("returns a cart with the given sales channel", async () => { + const api = useApi() + + const response = await api.post(`/store/carts`, { sales_channel_id: salesChannel.id }, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel.name, + description: salesChannel.description, + }) + ) + }) + + it("throw if the given sales channel is disabled", async () => { + const api = useApi() + + const err = await api.post( + `/store/carts`, + { sales_channel_id: disabledSalesChannel.id }, + adminReqConfig + ).catch(err => err) + + expect(err.response.status).toEqual(400) + expect(err.response.data.message).toBe(`Unable to assign the cart to a disabled Sales Channel "disabled cart sales channel"`) + }) + }) + + describe("POST /store/cart/:id", () => { + let salesChannel1, salesChannel2, disabledSalesChannel + let product1, product2 + let cart + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + await simpleRegionFactory(dbConnection, { + name: "Test region", + currency_code: "usd", + tax_rate: 0, + }) + + salesChannel1 = await simpleSalesChannelFactory(dbConnection, { + name: "salesChannel1", + description: "salesChannel1", + }) + salesChannel2 = await simpleSalesChannelFactory(dbConnection, { + name: "salesChannel2", + description: "salesChannel2", + }) + disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, { + name: "disabled cart sales channel", + description: "disabled cart sales channel description", + is_disabled: true, + }) + + product1 = await simpleProductFactory( + dbConnection, + { + title: "prod 1", + sales_channels: [salesChannel1], + variants: [ + { + id: "test-variant", + prices: [ + { + amount: 50, + currency: "usd", + variant_id: "test-variant", + }, + ], + }, + ], + }, + ) + product2 = await simpleProductFactory( + dbConnection, + { + sales_channels: [salesChannel2], + variants: [ + { + id: "test-variant-2", + prices: [ + { + amount: 100, + currency: "usd", + variant_id: "test-variant-2", + }, + ], + }, + ], + }, + ) + + cart = await simpleCartFactory( + dbConnection, + { + sales_channel: salesChannel1, + line_items: [ + { + variant_id: "test-variant", + unit_price: 50, + }, + ], + }, + ) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it( + "updates a cart sales channels should remove the items that does not belongs to the new sales channel", + async () => { + const api = useApi() + + let response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel1.name, + description: salesChannel1.description, + }) + ) + expect(response.data.cart.items.length).toBe(1) + expect(response.data.cart.items[0].variant.product).toEqual( + expect.objectContaining({ + id: product1.id, + title: product1.title, + }) + ) + + response = await api.post(`/store/carts/${cart.id}`, { sales_channel_id: salesChannel2.id }, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toEqual( + expect.objectContaining({ + name: salesChannel2.name, + description: salesChannel2.description, + }) + ) + expect(response.data.cart.items.length).toBe(0) + } + ) + + it("throw if the given sales channel is disabled", async () => { + const api = useApi() + + const err = await api.post( + `/store/carts/${cart.id}`, + { sales_channel_id: disabledSalesChannel.id }, + adminReqConfig + ).catch(err => err) + + expect(err.response.status).toEqual(400) + expect(err.response.data.message).toBe("Unable to assign the cart to a disabled Sales Channel \"disabled cart sales channel\"") + }) + }) + + describe("GET /store/cart/:id", () => { + let cart + + beforeEach(async () => { + try { + await adminSeeder(dbConnection) + + cart = await simpleCartFactory(dbConnection, { + sales_channel: { + name: "test name", + description: "test description", + }, + }) + } catch (err) { + console.log(err) + } + }) + + afterEach(async () => { + const db = useDb() + await db.teardown() + }) + + it("returns cart with sales channel for single cart", async () => { + const api = useApi() + + const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig) + + expect(response.data.cart.sales_channel).toBeTruthy() + expect(response.data.cart.sales_channel).toMatchSnapshot({ + id: expect.any(String), + name: "test name", + description: "test description", + is_disabled: false, + created_at: expect.any(String), + updated_at: expect.any(String), + }) + }) + }) +}) diff --git a/integration-tests/api/factories/simple-cart-factory.ts b/integration-tests/api/factories/simple-cart-factory.ts index 49d353e772..742d4c84b4 100644 --- a/integration-tests/api/factories/simple-cart-factory.ts +++ b/integration-tests/api/factories/simple-cart-factory.ts @@ -34,7 +34,7 @@ export type CartFactoryData = { export const simpleCartFactory = async ( connection: Connection, data: CartFactoryData = {}, - seed: number + seed?: number ): Promise => { if (typeof seed !== "undefined") { faker.seed(seed) diff --git a/integration-tests/api/factories/simple-sales-channel-factory.ts b/integration-tests/api/factories/simple-sales-channel-factory.ts index cf7f0f88c3..095efc7c8e 100644 --- a/integration-tests/api/factories/simple-sales-channel-factory.ts +++ b/integration-tests/api/factories/simple-sales-channel-factory.ts @@ -8,6 +8,7 @@ export type SalesChannelFactoryData = { is_disabled?: boolean id?: string products?: Product[], + is_default?: boolean } export const simpleSalesChannelFactory = async ( @@ -36,12 +37,19 @@ export const simpleSalesChannelFactory = async ( for (const product of data.products) { promises.push( manager.query(` - INSERT INTO product_sales_channel (product_id, sales_channel_id) VALUES ('${product.id}', '${salesChannel.id}'); + INSERT INTO product_sales_channel (product_id, sales_channel_id) + VALUES ('${product.id}', '${salesChannel.id}'); `) ) } await Promise.all(promises) } + if (data.is_default) { + await manager.query( + `UPDATE store SET default_sales_channel_id = '${salesChannel.id}'` + ) + } + return salesChannel } diff --git a/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js b/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js index 203bebcb51..4b0735e710 100644 --- a/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js +++ b/packages/medusa/src/api/routes/store/carts/__tests__/create-cart.js @@ -25,12 +25,12 @@ describe("POST /store/carts", () => { it("calls CartService create", () => { expect(CartServiceMock.create).toHaveBeenCalledTimes(1) expect(CartServiceMock.create).toHaveBeenCalledWith({ - region_id: IdMap.getId("testRegion"), context: { ip: "::ffff:127.0.0.1", user_agent: "node-superagent/3.8.3", clientId: "test", }, + region_id: IdMap.getId("testRegion"), }) }) diff --git a/packages/medusa/src/api/routes/store/carts/create-cart.ts b/packages/medusa/src/api/routes/store/carts/create-cart.ts index 4fdaa9d9c8..1bb8e54f8c 100644 --- a/packages/medusa/src/api/routes/store/carts/create-cart.ts +++ b/packages/medusa/src/api/routes/store/carts/create-cart.ts @@ -11,11 +11,11 @@ import { MedusaError } from "medusa-core-utils" import reqIp from "request-ip" import { EntityManager } from "typeorm" -import { defaultStoreCartFields, defaultStoreCartRelations } from "." -import { CartService, LineItemService } from "../../../../services" -import { validator } from "../../../../utils/validator" -import { AddressPayload } from "../../../../types/common" +import { defaultStoreCartFields, defaultStoreCartRelations, } from "." +import { CartService, LineItemService, RegionService } from "../../../../services" import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" +import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels"; +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators"; /** * @oas [post] /carts @@ -33,6 +33,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * region_id: * type: string * description: The id of the Region to create the Cart in. + * sales_channel_id: + * type: string + * description: [EXPERIMENTAL] The id of the Sales channel to create the Cart in. * country_code: * type: string * description: "The 2 character ISO country code to create the Cart in." @@ -63,7 +66,7 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * $ref: "#/components/schemas/cart" */ export default async (req, res) => { - const validated = await validator(StorePostCartReq, req.body) + const validated = req.validatedBody as StorePostCartReq const reqContext = { ip: reqIp.getClientIp(req), @@ -72,18 +75,17 @@ export default async (req, res) => { const lineItemService: LineItemService = req.scope.resolve("lineItemService") const cartService: CartService = req.scope.resolve("cartService") - + const regionService: RegionService = req.scope.resolve("regionService") const entityManager: EntityManager = req.scope.resolve("manager") await entityManager.transaction(async (manager) => { - // Add a default region if no region has been specified let regionId: string - if (typeof validated.region_id !== "undefined") { regionId = validated.region_id } else { - const regionService = req.scope.resolve("regionService") - const regions = await regionService.withTransaction(manager).list({}) + const regions = await regionService + .withTransaction(manager) + .list({}) if (!regions?.length) { throw new MedusaError( @@ -95,36 +97,15 @@ export default async (req, res) => { regionId = regions[0].id } - const toCreate: { - region_id: string - context: object - customer_id?: string - email?: string - shipping_address?: Partial - } = { - region_id: regionId, + let cart = await cartService.withTransaction(manager).create({ + ...validated, context: { ...reqContext, ...validated.context, }, - } + region_id: regionId, + }) - if (req.user && req.user.customer_id) { - const customerService = req.scope.resolve("customerService") - const customer = await customerService - .withTransaction(manager) - .retrieve(req.user.customer_id) - toCreate["customer_id"] = customer.id - toCreate["email"] = customer.email - } - - if (validated.country_code) { - toCreate["shipping_address"] = { - country_code: validated.country_code.toLowerCase(), - } - } - - let cart = await cartService.withTransaction(manager).create(toCreate) if (validated.items) { await Promise.all( validated.items.map(async (i) => { @@ -160,6 +141,7 @@ export class Item { @IsInt() quantity: number } + export class StorePostCartReq { @IsOptional() @IsString() @@ -177,4 +159,10 @@ export class StorePostCartReq { @IsOptional() context?: object + + @FeatureFlagDecorators(SalesChannelFeatureFlag.key, [ + IsString(), + IsOptional(), + ]) + sales_channel_id?: string } diff --git a/packages/medusa/src/api/routes/store/carts/index.ts b/packages/medusa/src/api/routes/store/carts/index.ts index ccecd4db8f..f92727b911 100644 --- a/packages/medusa/src/api/routes/store/carts/index.ts +++ b/packages/medusa/src/api/routes/store/carts/index.ts @@ -2,7 +2,9 @@ import { Router } from "express" import "reflect-metadata" import { Cart, Order, Swap } from "../../../../" import { DeleteResponse, EmptyQueryParams } from "../../../../types/common" -import middlewares, { transformQuery } from "../../../middlewares" +import middlewares, { transformBody, transformQuery } from "../../../middlewares" +import { StorePostCartsCartReq } from "./update-cart"; +import { StorePostCartReq } from "./create-cart"; const route = Router() export default (app, container) => { @@ -11,9 +13,8 @@ export default (app, container) => { app.use("/carts", route) - const relations = [...defaultStoreCartRelations] if (featureFlagRouter.isFeatureEnabled("sales_channels")) { - relations.push("sales_channel") + defaultStoreCartRelations.push("sales_channel") } // Inject plugin routes @@ -25,7 +26,7 @@ export default (app, container) => { route.get( "/:id", transformQuery(EmptyQueryParams, { - defaultRelations: relations, + defaultRelations: defaultStoreCartRelations, defaultFields: defaultStoreCartFields, isList: false, }), @@ -35,10 +36,15 @@ export default (app, container) => { route.post( "/", middlewareService.usePreCartCreation(), + transformBody(StorePostCartReq), middlewares.wrap(require("./create-cart").default) ) - route.post("/:id", middlewares.wrap(require("./update-cart").default)) + route.post( + "/:id", + transformBody(StorePostCartsCartReq), + middlewares.wrap(require("./update-cart").default) + ) route.post( "/:id/complete", diff --git a/packages/medusa/src/api/routes/store/carts/update-cart.ts b/packages/medusa/src/api/routes/store/carts/update-cart.ts index 9459d712da..8803681436 100644 --- a/packages/medusa/src/api/routes/store/carts/update-cart.ts +++ b/packages/medusa/src/api/routes/store/carts/update-cart.ts @@ -8,11 +8,11 @@ import { } from "class-validator" import { defaultStoreCartFields, defaultStoreCartRelations } from "." import { CartService } from "../../../../services" -import { CartUpdateProps } from "../../../../types/cart" import { AddressPayload } from "../../../../types/common" -import { validator } from "../../../../utils/validator" import { IsType } from "../../../../utils/validators/is-type" import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" +import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators"; +import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels"; /** * @oas [post] /store/carts/{id} @@ -35,6 +35,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" * email: * type: string * description: "An email to be used on the Cart." + * sales_channel_id: + * type: string + * description: The id of the Sales channel to update the Cart with. * billing_address: * description: "The Address to be used for billing purposes." * anyOf: @@ -83,30 +86,11 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals" */ export default async (req, res) => { const { id } = req.params - - const validated = await validator(StorePostCartsCartReq, req.body) + const validated = req.validatedBody as StorePostCartsCartReq const cartService: CartService = req.scope.resolve("cartService") + await cartService.update(id, validated) - // Update the cart - const { shipping_address, billing_address, ...rest } = validated - - const cartDataToUpdate: CartUpdateProps = { ...rest } - if (typeof shipping_address === "string") { - cartDataToUpdate.shipping_address_id = shipping_address - } else { - cartDataToUpdate.shipping_address = shipping_address - } - - if (typeof billing_address === "string") { - cartDataToUpdate.billing_address_id = billing_address - } else { - cartDataToUpdate.billing_address = billing_address - } - - await cartService.update(id, cartDataToUpdate) - - // If the cart has payment sessions update these const updated = await cartService.retrieve(id, { relations: ["payment_sessions", "shipping_methods"], }) @@ -173,4 +157,10 @@ export class StorePostCartsCartReq { @IsOptional() context?: object + + @FeatureFlagDecorators(SalesChannelFeatureFlag.key, [ + IsString(), + IsOptional(), + ]) + sales_channel_id?: string } diff --git a/packages/medusa/src/services/__tests__/cart.js b/packages/medusa/src/services/__tests__/cart.js index a8b99b5d03..f134cb98aa 100644 --- a/packages/medusa/src/services/__tests__/cart.js +++ b/packages/medusa/src/services/__tests__/cart.js @@ -4,6 +4,7 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils" import CartService from "../cart" import { InventoryServiceMock } from "../__mocks__/inventory" import { LineItemAdjustmentServiceMock } from "../__mocks__/line-item-adjustment" +import { FlagRouter } from "../../utils/flag-router"; const eventBusService = { emit: jest.fn(), @@ -46,6 +47,7 @@ describe("CartService", () => { manager: MockManager, totalsService, cartRepository, + featureFlagRouter: new FlagRouter({}), }) result = await cartService.retrieve(IdMap.getId("emptyCart")) }) @@ -76,6 +78,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -136,6 +139,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -239,6 +243,7 @@ describe("CartService", () => { customerService, regionService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -271,6 +276,7 @@ describe("CartService", () => { customer_id: IdMap.getId("customer"), email: "email@test.com", customer: expect.any(Object), + context: expect.any(Object), }) expect(cartRepository.save).toHaveBeenCalledTimes(1) @@ -315,6 +321,7 @@ describe("CartService", () => { expect(cartRepository.create).toHaveBeenCalledTimes(1) expect(cartRepository.create).toHaveBeenCalledWith({ + context: {}, region_id: IdMap.getId("testRegion"), shipping_address: { first_name: "LeBron", @@ -400,6 +407,7 @@ describe("CartService", () => { shippingOptionService, inventoryService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -614,6 +622,7 @@ describe("CartService", () => { shippingOptionService, eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -719,6 +728,7 @@ describe("CartService", () => { cartRepository, totalsService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -806,6 +816,7 @@ describe("CartService", () => { eventBusService, inventoryService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -887,6 +898,7 @@ describe("CartService", () => { cartRepository, eventBusService, customerService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -967,6 +979,7 @@ describe("CartService", () => { cartRepository, addressRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1028,6 +1041,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1182,6 +1196,7 @@ describe("CartService", () => { eventBusService, paymentSessionRepository: MockRepository(), priceSelectionStrategy: priceSelectionStrat, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1269,6 +1284,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1383,6 +1399,7 @@ describe("CartService", () => { cartRepository, paymentProviderService, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1573,6 +1590,7 @@ describe("CartService", () => { lineItemService, eventBusService, customShippingOptionService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(() => { @@ -1927,6 +1945,7 @@ describe("CartService", () => { discountService, eventBusService, lineItemAdjustmentService: LineItemAdjustmentServiceMock, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { @@ -2214,6 +2233,7 @@ describe("CartService", () => { totalsService, cartRepository, eventBusService, + featureFlagRouter: new FlagRouter({}), }) beforeEach(async () => { diff --git a/packages/medusa/src/services/cart.ts b/packages/medusa/src/services/cart.ts index 9ef89f349d..24da409701 100644 --- a/packages/medusa/src/services/cart.ts +++ b/packages/medusa/src/services/cart.ts @@ -3,14 +3,18 @@ import { MedusaError, Validator } from "medusa-core-utils" import { DeepPartial, EntityManager, In } from "typeorm" import { TransactionBaseService } from "../interfaces" import { IPriceSelectionStrategy } from "../interfaces/price-selection-strategy" -import { DiscountRuleType } from "../models" -import { Address } from "../models/address" -import { Cart } from "../models/cart" -import { CustomShippingOption } from "../models/custom-shipping-option" -import { Customer } from "../models/customer" -import { Discount } from "../models/discount" -import { LineItem } from "../models/line-item" -import { ShippingMethod } from "../models/shipping-method" +import { + DiscountRuleType, + Address, + Cart, + CustomShippingOption, + Customer, + Discount, + LineItem, + ShippingMethod, + User, + SalesChannel, +} from "../models" import { AddressRepository } from "../repositories/address" import { CartRepository } from "../repositories/cart" import { LineItemRepository } from "../repositories/line-item" @@ -39,6 +43,10 @@ import RegionService from "./region" import ShippingOptionService from "./shipping-option" import TaxProviderService from "./tax-provider" import TotalsService from "./totals" +import SalesChannelFeatureFlag from "../loaders/feature-flags/sales-channels" +import { FlagRouter } from "../utils/flag-router" +import SalesChannelService from "./sales-channel" +import StoreService from "./store" type InjectedDependencies = { manager: EntityManager @@ -48,9 +56,12 @@ type InjectedDependencies = { paymentSessionRepository: typeof PaymentSessionRepository lineItemRepository: typeof LineItemRepository eventBusService: EventBusService + salesChannelService: SalesChannelService taxProviderService: TaxProviderService paymentProviderService: PaymentProviderService productService: ProductService + storeService: StoreService + featureFlagRouter: FlagRouter productVariantService: ProductVariantService regionService: RegionService lineItemService: LineItemService @@ -90,6 +101,9 @@ class CartService extends TransactionBaseService { protected readonly eventBus_: EventBusService protected readonly productVariantService_: ProductVariantService protected readonly productService_: ProductService + protected readonly featureFlagRouter_: FlagRouter + protected readonly storeService_: StoreService + protected readonly salesChannelService_: SalesChannelService protected readonly regionService_: RegionService protected readonly lineItemService_: LineItemService protected readonly paymentProviderService_: PaymentProviderService @@ -127,6 +141,9 @@ class CartService extends TransactionBaseService { customShippingOptionService, lineItemAdjustmentService, priceSelectionStrategy, + salesChannelService, + featureFlagRouter, + storeService, }: InjectedDependencies) { // eslint-disable-next-line prefer-rest-params super(arguments[0]) @@ -153,6 +170,9 @@ class CartService extends TransactionBaseService { this.taxProviderService_ = taxProviderService this.lineItemAdjustmentService_ = lineItemAdjustmentService this.priceSelectionStrategy_ = priceSelectionStrategy + this.salesChannelService_ = salesChannelService + this.featureFlagRouter_ = featureFlagRouter + this.storeService_ = storeService } protected transformQueryForTotals_( @@ -331,15 +351,17 @@ class CartService extends TransactionBaseService { this.addressRepository_ ) - const { region_id } = data - if (!region_id) { - throw new MedusaError( - MedusaError.Types.INVALID_DATA, - `A region_id must be provided when creating a cart` - ) + const rawCart: DeepPartial = { + context: data.context ?? {}, } - const rawCart: DeepPartial = {} + if ( + this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key) + ) { + rawCart.sales_channel_id = ( + await this.getValidatedSalesChannel(data.sales_channel_id) + ).id + } if (data.email) { const customer = await this.createOrFetchUserFromEmail_(data.email) @@ -348,15 +370,21 @@ class CartService extends TransactionBaseService { rawCart.email = customer.email } + if (!data.region_id) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `A region_id must be provided when creating a cart` + ) + } + + rawCart.region_id = data.region_id const region = await this.regionService_ .withTransaction(transactionManager) - .retrieve(region_id, { + .retrieve(data.region_id, { relations: ["countries"], }) const regCountries = region.countries.map(({ iso_2 }) => iso_2) - rawCart.region_id = region.id - if (!data.shipping_address && !data.shipping_address_id) { if (region.countries.length === 1) { rawCart.shipping_address = addressRepo.create({ @@ -399,10 +427,8 @@ class CartService extends TransactionBaseService { typeof data[remainingField] !== "undefined" && remainingField !== "object" ) { - /* TODO: See how to fix the error TS2590 properly while keeping the DeepPartial type */ - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - rawCart[remainingField] = data[remainingField] + const key = remainingField as string + rawCart[key] = data[remainingField] } } @@ -418,6 +444,32 @@ class CartService extends TransactionBaseService { ) } + protected async getValidatedSalesChannel( + salesChannelId?: string + ): Promise { + let salesChannel: SalesChannel + if (typeof salesChannelId !== "undefined") { + salesChannel = await this.salesChannelService_ + .withTransaction(this.manager_) + .retrieve(salesChannelId) + } else { + salesChannel = ( + await this.storeService_.withTransaction(this.manager_).retrieve({ + relations: ["default_sales_channel"], + }) + ).default_sales_channel + } + + if (salesChannel.is_disabled) { + throw new MedusaError( + MedusaError.Types.INVALID_DATA, + `Unable to assign the cart to a disabled Sales Channel "${salesChannel.name}"` + ) + } + + return salesChannel + } + /** * Removes a line item from the cart. * @param cartId - the id of the cart that we will remove from @@ -721,6 +773,30 @@ class CartService extends TransactionBaseService { const cartRepo = transactionManager.getCustomRepository( this.cartRepository_ ) + const relations = [ + "items", + "shipping_methods", + "shipping_address", + "billing_address", + "gift_cards", + "customer", + "region", + "payment_sessions", + "region.countries", + "discounts", + "discounts.rule", + "discounts.regions", + ] + + if ( + this.featureFlagRouter_.isFeatureEnabled( + SalesChannelFeatureFlag.key + ) && + data.sales_channel_id + ) { + relations.push("items.variant", "items.variant.product") + } + const cart = await this.retrieve(cartId, { select: [ "subtotal", @@ -729,20 +805,7 @@ class CartService extends TransactionBaseService { "discount_total", "total", ], - relations: [ - "items", - "shipping_methods", - "shipping_address", - "billing_address", - "gift_cards", - "customer", - "region", - "payment_sessions", - "region.countries", - "discounts", - "discounts.rule", - "discounts.regions", - ], + relations, }) if (data.customer_id) { @@ -764,8 +827,12 @@ class CartService extends TransactionBaseService { } if (typeof data.region_id !== "undefined") { + const shippingAddress = + typeof data.shipping_address !== "string" + ? data.shipping_address + : {} const countryCode = - (data.country_code || data.shipping_address?.country_code) ?? null + (data.country_code || shippingAddress?.country_code) ?? null await this.setRegion_(cart, data.region_id, countryCode) } @@ -784,6 +851,18 @@ class CartService extends TransactionBaseService { await this.updateShippingAddress_(cart, shippingAddress, addrRepo) } + if ( + this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key) + ) { + if ( + typeof data.sales_channel_id !== "undefined" && + data.sales_channel_id != cart.sales_channel_id + ) { + await this.onSalesChannelChange(cart, data.sales_channel_id) + cart.sales_channel_id = data.sales_channel_id + } + } + if (typeof data.discounts !== "undefined") { const previousDiscounts = [...cart.discounts] cart.discounts.length = 0 @@ -861,6 +940,42 @@ class CartService extends TransactionBaseService { ) } + /** + * Remove the cart line item that does not belongs to the newly assigned sales channel + * @param cart - The cart being updated + * @param newSalesChannelId - The new sales channel being assigned to the cart + * @protected + */ + protected async onSalesChannelChange( + cart: Cart, + newSalesChannelId: string + ): Promise { + await this.getValidatedSalesChannel(newSalesChannelId) + + const productIds = cart.items.map((item) => item.variant.product_id) + const productsToKeep = await this.productService_ + .withTransaction(this.manager_) + .filterProductsBySalesChannel(productIds, newSalesChannelId, { + select: ["id", "sales_channels"], + take: productIds.length, + }) + const productIdsToKeep = new Set( + productsToKeep.map((product) => product.id) + ) + const itemsToRemove = cart.items.filter((item) => { + return !productIdsToKeep.has(item.variant.product_id) + }) + + if (itemsToRemove.length) { + const results = await Promise.all( + itemsToRemove.map((item) => { + return this.removeLineItem(cart.id, item.id) + }) + ) + cart.items = results.pop()?.items ?? [] + } + } + /** * Sets the customer id of a cart * @param cart - the cart to add email to diff --git a/packages/medusa/src/services/product.ts b/packages/medusa/src/services/product.ts index 33b1b2e467..1674632c20 100644 --- a/packages/medusa/src/services/product.ts +++ b/packages/medusa/src/services/product.ts @@ -2,7 +2,13 @@ import { MedusaError } from "medusa-core-utils" import { EntityManager } from "typeorm" import { SearchService } from "." import { TransactionBaseService } from "../interfaces" -import { Product, ProductTag, ProductType, ProductVariant } from "../models" +import { + Product, + ProductTag, + ProductType, + ProductVariant, + SalesChannel, +} from "../models" import { ImageRepository } from "../repositories/image" import { FindWithoutRelationsOptions, @@ -284,6 +290,37 @@ class ProductService extends TransactionBaseService { return product.variants } + async filterProductsBySalesChannel( + productIds: string[], + salesChannelId: string, + config: FindProductConfig = { + skip: 0, + take: 50, + } + ): Promise { + const givenRelations = config.relations ?? [] + const requiredRelations = ["sales_channels"] + const relationsSet = new Set([...givenRelations, ...requiredRelations]) + + const products = await this.list( + { + id: productIds, + }, + { + ...config, + relations: [...relationsSet], + } + ) + const productSalesChannelsMap = new Map( + products.map((product) => [product.id, product.sales_channels]) + ) + return products.filter((product) => { + return productSalesChannelsMap + .get(product.id) + ?.some((sc) => sc.id === salesChannelId) + }) + } + async listTypes(): Promise { const manager = this.manager_ const productTypeRepository = manager.getCustomRepository( diff --git a/packages/medusa/src/types/cart.ts b/packages/medusa/src/types/cart.ts index 624ea605f4..099eff88f0 100644 --- a/packages/medusa/src/types/cart.ts +++ b/packages/medusa/src/types/cart.ts @@ -43,7 +43,7 @@ class Discount { } export type CartCreateProps = { - region_id: string + region_id?: string email?: string billing_address_id?: string billing_address?: Partial @@ -55,6 +55,8 @@ export type CartCreateProps = { type?: CartType context?: object metadata?: object + sales_channel_id?: string + country_code?: string } export type CartUpdateProps = { @@ -63,8 +65,8 @@ export type CartUpdateProps = { email?: string shipping_address_id?: string billing_address_id?: string - billing_address?: AddressPayload - shipping_address?: AddressPayload + billing_address?: AddressPayload | string + shipping_address?: AddressPayload | string completed_at?: Date payment_authorized_at?: Date gift_cards?: GiftCard[] @@ -72,4 +74,5 @@ export type CartUpdateProps = { customer_id?: string context?: object metadata?: Record + sales_channel_id?: string }