feat(medusa): Attach or update cart sales channel (#1873)

What

Allow to create a cart with a sales channel, otherwise the default one is attached.
Also allow to update the sales channel on an existing cart and in that case the line items that does not belongs to the new sales channel attached are removed

How

Updating existing end points and service method to integrate the new requirements

Tests

Add new integration tests

Fixes CORE-270
Fixes CORE-272

Co-authored-by: Oliver Windall Juhl <59018053+olivermrbl@users.noreply.github.com>
This commit is contained in:
Adrien de Peretti
2022-07-27 18:54:05 +02:00
committed by GitHub
parent 8dd85e5f03
commit df66378535
17 changed files with 595 additions and 169 deletions

View File

@@ -0,0 +1,7 @@
---
"@medusajs/medusa": patch
---
Adds support for:
- Attaching Sales Channel to cart as part of creation
- Updating Sales Channel on a cart and removing inapplicable line items

View File

@@ -110,18 +110,6 @@ Object {
}
`;
exports[`sales channels GET /store/cart/:id with saleschannel returns cart with sales channel for single cart 1`] = `
Object {
"created_at": Any<String>,
"deleted_at": null,
"description": "test description",
"id": Any<String>,
"is_disabled": false,
"name": "test name",
"updated_at": Any<String>,
}
`;
exports[`sales channels POST /admin/sales-channels successfully creates a sales channel 1`] = `
Object {
"sales_channel": ObjectContaining {

View File

@@ -311,11 +311,8 @@ describe("sales channels", () => {
await simpleSalesChannelFactory(dbConnection, {
name: "Default channel",
id: "test-channel",
is_default: true,
})
await dbConnection.manager.query(
`UPDATE store SET default_sales_channel_id = 'test-channel'`
)
} catch (e) {
console.error(e)
}
@@ -620,45 +617,6 @@ describe("sales channels", () => {
})
})
describe("GET /store/cart/:id with saleschannel", () => {
let cart
beforeEach(async () => {
try {
await adminSeeder(dbConnection)
cart = await simpleCartFactory(dbConnection, {
sales_channel: {
name: "test name",
description: "test description",
},
})
} catch (err) {
console.log(err)
}
})
afterEach(async () => {
const db = useDb()
await db.teardown()
})
it("returns cart with sales channel for single cart", async () => {
const api = useApi()
const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toMatchSnapshot({
id: expect.any(String),
name: "test name",
description: "test description",
is_disabled: false,
created_at: expect.any(String),
updated_at: expect.any(String),
})
})
})
describe("DELETE /admin/sales-channels/:id/products/batch", () => {
let salesChannel
let product

View File

@@ -0,0 +1,13 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`sales channels GET /store/cart/:id returns cart with sales channel for single cart 1`] = `
Object {
"created_at": Any<String>,
"deleted_at": null,
"description": "test description",
"id": Any<String>,
"is_disabled": false,
"name": "test name",
"updated_at": Any<String>,
}
`;

View File

@@ -83,7 +83,7 @@ Object {
"cart": Object {
"billing_address_id": "test-billing-address",
"completed_at": null,
"context": null,
"context": Object {},
"created_at": Any<String>,
"customer_id": "test-customer",
"deleted_at": null,
@@ -259,7 +259,7 @@ Object {
"cart": Object {
"billing_address_id": "test-billing-address",
"completed_at": null,
"context": null,
"context": Object {},
"created_at": Any<String>,
"customer_id": "test-customer",
"deleted_at": null,

View File

@@ -6,12 +6,7 @@ const {
ShippingProfile,
Product,
ProductVariant,
MoneyAmount,
LineItem,
Payment,
Cart,
ShippingMethod,
Swap,
} = require("@medusajs/medusa")
const setupServer = require("../../../helpers/setup-server")

View File

@@ -0,0 +1,298 @@
const path = require("path")
const { useApi } = require("../../../helpers/use-api")
const { useDb } = require("../../../helpers/use-db")
const adminSeeder = require("../../helpers/admin-seeder")
const {
simpleSalesChannelFactory,
simpleCartFactory, simpleRegionFactory, simpleProductFactory,
} = require("../../factories")
const startServerWithEnvironment =
require("../../../helpers/start-server-with-environment").default
const adminReqConfig = {
headers: {
Authorization: "Bearer test_token",
},
}
jest.setTimeout(50000)
describe("sales channels", () => {
let medusaProcess
let dbConnection
beforeAll(async () => {
const cwd = path.resolve(path.join(__dirname, "..", ".."))
const [process, connection] = await startServerWithEnvironment({
cwd,
env: { MEDUSA_FF_SALES_CHANNELS: true },
verbose: false,
})
dbConnection = connection
medusaProcess = process
})
afterAll(async () => {
const db = useDb()
await db.shutdown()
medusaProcess.kill()
})
describe("POST /store/cart/", () => {
let salesChannel
let disabledSalesChannel
beforeEach(async () => {
try {
await adminSeeder(dbConnection)
await simpleRegionFactory(dbConnection, {
name: "Test region",
tax_rate: 0,
})
await simpleSalesChannelFactory(dbConnection, {
name: "Default Sales Channel",
description: "Created by Medusa",
is_default: true
})
disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, {
name: "disabled cart sales channel",
description: "disabled cart sales channel description",
is_disabled: true,
})
salesChannel = await simpleSalesChannelFactory(dbConnection, {
name: "cart sales channel",
description: "cart sales channel description",
})
} catch (err) {
console.log(err)
}
})
afterEach(async () => {
const db = useDb()
await db.teardown()
})
it("returns a cart with the default sales channel", async () => {
const api = useApi()
const response = await api.post(`/store/carts`, {}, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toEqual(
expect.objectContaining({
name: "Default Sales Channel",
description: "Created by Medusa",
})
)
})
it("returns a cart with the given sales channel", async () => {
const api = useApi()
const response = await api.post(`/store/carts`, { sales_channel_id: salesChannel.id }, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toEqual(
expect.objectContaining({
name: salesChannel.name,
description: salesChannel.description,
})
)
})
it("throw if the given sales channel is disabled", async () => {
const api = useApi()
const err = await api.post(
`/store/carts`,
{ sales_channel_id: disabledSalesChannel.id },
adminReqConfig
).catch(err => err)
expect(err.response.status).toEqual(400)
expect(err.response.data.message).toBe(`Unable to assign the cart to a disabled Sales Channel "disabled cart sales channel"`)
})
})
describe("POST /store/cart/:id", () => {
let salesChannel1, salesChannel2, disabledSalesChannel
let product1, product2
let cart
beforeEach(async () => {
try {
await adminSeeder(dbConnection)
await simpleRegionFactory(dbConnection, {
name: "Test region",
currency_code: "usd",
tax_rate: 0,
})
salesChannel1 = await simpleSalesChannelFactory(dbConnection, {
name: "salesChannel1",
description: "salesChannel1",
})
salesChannel2 = await simpleSalesChannelFactory(dbConnection, {
name: "salesChannel2",
description: "salesChannel2",
})
disabledSalesChannel = await simpleSalesChannelFactory(dbConnection, {
name: "disabled cart sales channel",
description: "disabled cart sales channel description",
is_disabled: true,
})
product1 = await simpleProductFactory(
dbConnection,
{
title: "prod 1",
sales_channels: [salesChannel1],
variants: [
{
id: "test-variant",
prices: [
{
amount: 50,
currency: "usd",
variant_id: "test-variant",
},
],
},
],
},
)
product2 = await simpleProductFactory(
dbConnection,
{
sales_channels: [salesChannel2],
variants: [
{
id: "test-variant-2",
prices: [
{
amount: 100,
currency: "usd",
variant_id: "test-variant-2",
},
],
},
],
},
)
cart = await simpleCartFactory(
dbConnection,
{
sales_channel: salesChannel1,
line_items: [
{
variant_id: "test-variant",
unit_price: 50,
},
],
},
)
} catch (err) {
console.log(err)
}
})
afterEach(async () => {
const db = useDb()
await db.teardown()
})
it(
"updates a cart sales channels should remove the items that does not belongs to the new sales channel",
async () => {
const api = useApi()
let response = await api.get(`/store/carts/${cart.id}`, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toEqual(
expect.objectContaining({
name: salesChannel1.name,
description: salesChannel1.description,
})
)
expect(response.data.cart.items.length).toBe(1)
expect(response.data.cart.items[0].variant.product).toEqual(
expect.objectContaining({
id: product1.id,
title: product1.title,
})
)
response = await api.post(`/store/carts/${cart.id}`, { sales_channel_id: salesChannel2.id }, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toEqual(
expect.objectContaining({
name: salesChannel2.name,
description: salesChannel2.description,
})
)
expect(response.data.cart.items.length).toBe(0)
}
)
it("throw if the given sales channel is disabled", async () => {
const api = useApi()
const err = await api.post(
`/store/carts/${cart.id}`,
{ sales_channel_id: disabledSalesChannel.id },
adminReqConfig
).catch(err => err)
expect(err.response.status).toEqual(400)
expect(err.response.data.message).toBe("Unable to assign the cart to a disabled Sales Channel \"disabled cart sales channel\"")
})
})
describe("GET /store/cart/:id", () => {
let cart
beforeEach(async () => {
try {
await adminSeeder(dbConnection)
cart = await simpleCartFactory(dbConnection, {
sales_channel: {
name: "test name",
description: "test description",
},
})
} catch (err) {
console.log(err)
}
})
afterEach(async () => {
const db = useDb()
await db.teardown()
})
it("returns cart with sales channel for single cart", async () => {
const api = useApi()
const response = await api.get(`/store/carts/${cart.id}`, adminReqConfig)
expect(response.data.cart.sales_channel).toBeTruthy()
expect(response.data.cart.sales_channel).toMatchSnapshot({
id: expect.any(String),
name: "test name",
description: "test description",
is_disabled: false,
created_at: expect.any(String),
updated_at: expect.any(String),
})
})
})
})

View File

@@ -34,7 +34,7 @@ export type CartFactoryData = {
export const simpleCartFactory = async (
connection: Connection,
data: CartFactoryData = {},
seed: number
seed?: number
): Promise<Cart> => {
if (typeof seed !== "undefined") {
faker.seed(seed)

View File

@@ -8,6 +8,7 @@ export type SalesChannelFactoryData = {
is_disabled?: boolean
id?: string
products?: Product[],
is_default?: boolean
}
export const simpleSalesChannelFactory = async (
@@ -36,12 +37,19 @@ export const simpleSalesChannelFactory = async (
for (const product of data.products) {
promises.push(
manager.query(`
INSERT INTO product_sales_channel (product_id, sales_channel_id) VALUES ('${product.id}', '${salesChannel.id}');
INSERT INTO product_sales_channel (product_id, sales_channel_id)
VALUES ('${product.id}', '${salesChannel.id}');
`)
)
}
await Promise.all(promises)
}
if (data.is_default) {
await manager.query(
`UPDATE store SET default_sales_channel_id = '${salesChannel.id}'`
)
}
return salesChannel
}

View File

@@ -25,12 +25,12 @@ describe("POST /store/carts", () => {
it("calls CartService create", () => {
expect(CartServiceMock.create).toHaveBeenCalledTimes(1)
expect(CartServiceMock.create).toHaveBeenCalledWith({
region_id: IdMap.getId("testRegion"),
context: {
ip: "::ffff:127.0.0.1",
user_agent: "node-superagent/3.8.3",
clientId: "test",
},
region_id: IdMap.getId("testRegion"),
})
})

View File

@@ -11,11 +11,11 @@ import { MedusaError } from "medusa-core-utils"
import reqIp from "request-ip"
import { EntityManager } from "typeorm"
import { defaultStoreCartFields, defaultStoreCartRelations } from "."
import { CartService, LineItemService } from "../../../../services"
import { validator } from "../../../../utils/validator"
import { AddressPayload } from "../../../../types/common"
import { defaultStoreCartFields, defaultStoreCartRelations, } from "."
import { CartService, LineItemService, RegionService } from "../../../../services"
import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels";
import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators";
/**
* @oas [post] /carts
@@ -33,6 +33,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
* region_id:
* type: string
* description: The id of the Region to create the Cart in.
* sales_channel_id:
* type: string
* description: [EXPERIMENTAL] The id of the Sales channel to create the Cart in.
* country_code:
* type: string
* description: "The 2 character ISO country code to create the Cart in."
@@ -63,7 +66,7 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
* $ref: "#/components/schemas/cart"
*/
export default async (req, res) => {
const validated = await validator(StorePostCartReq, req.body)
const validated = req.validatedBody as StorePostCartReq
const reqContext = {
ip: reqIp.getClientIp(req),
@@ -72,18 +75,17 @@ export default async (req, res) => {
const lineItemService: LineItemService = req.scope.resolve("lineItemService")
const cartService: CartService = req.scope.resolve("cartService")
const regionService: RegionService = req.scope.resolve("regionService")
const entityManager: EntityManager = req.scope.resolve("manager")
await entityManager.transaction(async (manager) => {
// Add a default region if no region has been specified
let regionId: string
if (typeof validated.region_id !== "undefined") {
regionId = validated.region_id
} else {
const regionService = req.scope.resolve("regionService")
const regions = await regionService.withTransaction(manager).list({})
const regions = await regionService
.withTransaction(manager)
.list({})
if (!regions?.length) {
throw new MedusaError(
@@ -95,36 +97,15 @@ export default async (req, res) => {
regionId = regions[0].id
}
const toCreate: {
region_id: string
context: object
customer_id?: string
email?: string
shipping_address?: Partial<AddressPayload>
} = {
region_id: regionId,
let cart = await cartService.withTransaction(manager).create({
...validated,
context: {
...reqContext,
...validated.context,
},
}
region_id: regionId,
})
if (req.user && req.user.customer_id) {
const customerService = req.scope.resolve("customerService")
const customer = await customerService
.withTransaction(manager)
.retrieve(req.user.customer_id)
toCreate["customer_id"] = customer.id
toCreate["email"] = customer.email
}
if (validated.country_code) {
toCreate["shipping_address"] = {
country_code: validated.country_code.toLowerCase(),
}
}
let cart = await cartService.withTransaction(manager).create(toCreate)
if (validated.items) {
await Promise.all(
validated.items.map(async (i) => {
@@ -160,6 +141,7 @@ export class Item {
@IsInt()
quantity: number
}
export class StorePostCartReq {
@IsOptional()
@IsString()
@@ -177,4 +159,10 @@ export class StorePostCartReq {
@IsOptional()
context?: object
@FeatureFlagDecorators(SalesChannelFeatureFlag.key, [
IsString(),
IsOptional(),
])
sales_channel_id?: string
}

View File

@@ -2,7 +2,9 @@ import { Router } from "express"
import "reflect-metadata"
import { Cart, Order, Swap } from "../../../../"
import { DeleteResponse, EmptyQueryParams } from "../../../../types/common"
import middlewares, { transformQuery } from "../../../middlewares"
import middlewares, { transformBody, transformQuery } from "../../../middlewares"
import { StorePostCartsCartReq } from "./update-cart";
import { StorePostCartReq } from "./create-cart";
const route = Router()
export default (app, container) => {
@@ -11,9 +13,8 @@ export default (app, container) => {
app.use("/carts", route)
const relations = [...defaultStoreCartRelations]
if (featureFlagRouter.isFeatureEnabled("sales_channels")) {
relations.push("sales_channel")
defaultStoreCartRelations.push("sales_channel")
}
// Inject plugin routes
@@ -25,7 +26,7 @@ export default (app, container) => {
route.get(
"/:id",
transformQuery(EmptyQueryParams, {
defaultRelations: relations,
defaultRelations: defaultStoreCartRelations,
defaultFields: defaultStoreCartFields,
isList: false,
}),
@@ -35,10 +36,15 @@ export default (app, container) => {
route.post(
"/",
middlewareService.usePreCartCreation(),
transformBody(StorePostCartReq),
middlewares.wrap(require("./create-cart").default)
)
route.post("/:id", middlewares.wrap(require("./update-cart").default))
route.post(
"/:id",
transformBody(StorePostCartsCartReq),
middlewares.wrap(require("./update-cart").default)
)
route.post(
"/:id/complete",

View File

@@ -8,11 +8,11 @@ import {
} from "class-validator"
import { defaultStoreCartFields, defaultStoreCartRelations } from "."
import { CartService } from "../../../../services"
import { CartUpdateProps } from "../../../../types/cart"
import { AddressPayload } from "../../../../types/common"
import { validator } from "../../../../utils/validator"
import { IsType } from "../../../../utils/validators/is-type"
import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
import { FeatureFlagDecorators } from "../../../../utils/feature-flag-decorators";
import SalesChannelFeatureFlag from "../../../../loaders/feature-flags/sales-channels";
/**
* @oas [post] /store/carts/{id}
@@ -35,6 +35,9 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
* email:
* type: string
* description: "An email to be used on the Cart."
* sales_channel_id:
* type: string
* description: The id of the Sales channel to update the Cart with.
* billing_address:
* description: "The Address to be used for billing purposes."
* anyOf:
@@ -83,30 +86,11 @@ import { decorateLineItemsWithTotals } from "./decorate-line-items-with-totals"
*/
export default async (req, res) => {
const { id } = req.params
const validated = await validator(StorePostCartsCartReq, req.body)
const validated = req.validatedBody as StorePostCartsCartReq
const cartService: CartService = req.scope.resolve("cartService")
await cartService.update(id, validated)
// Update the cart
const { shipping_address, billing_address, ...rest } = validated
const cartDataToUpdate: CartUpdateProps = { ...rest }
if (typeof shipping_address === "string") {
cartDataToUpdate.shipping_address_id = shipping_address
} else {
cartDataToUpdate.shipping_address = shipping_address
}
if (typeof billing_address === "string") {
cartDataToUpdate.billing_address_id = billing_address
} else {
cartDataToUpdate.billing_address = billing_address
}
await cartService.update(id, cartDataToUpdate)
// If the cart has payment sessions update these
const updated = await cartService.retrieve(id, {
relations: ["payment_sessions", "shipping_methods"],
})
@@ -173,4 +157,10 @@ export class StorePostCartsCartReq {
@IsOptional()
context?: object
@FeatureFlagDecorators(SalesChannelFeatureFlag.key, [
IsString(),
IsOptional(),
])
sales_channel_id?: string
}

View File

@@ -4,6 +4,7 @@ import { IdMap, MockManager, MockRepository } from "medusa-test-utils"
import CartService from "../cart"
import { InventoryServiceMock } from "../__mocks__/inventory"
import { LineItemAdjustmentServiceMock } from "../__mocks__/line-item-adjustment"
import { FlagRouter } from "../../utils/flag-router";
const eventBusService = {
emit: jest.fn(),
@@ -46,6 +47,7 @@ describe("CartService", () => {
manager: MockManager,
totalsService,
cartRepository,
featureFlagRouter: new FlagRouter({}),
})
result = await cartService.retrieve(IdMap.getId("emptyCart"))
})
@@ -76,6 +78,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -136,6 +139,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -239,6 +243,7 @@ describe("CartService", () => {
customerService,
regionService,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -271,6 +276,7 @@ describe("CartService", () => {
customer_id: IdMap.getId("customer"),
email: "email@test.com",
customer: expect.any(Object),
context: expect.any(Object),
})
expect(cartRepository.save).toHaveBeenCalledTimes(1)
@@ -315,6 +321,7 @@ describe("CartService", () => {
expect(cartRepository.create).toHaveBeenCalledTimes(1)
expect(cartRepository.create).toHaveBeenCalledWith({
context: {},
region_id: IdMap.getId("testRegion"),
shipping_address: {
first_name: "LeBron",
@@ -400,6 +407,7 @@ describe("CartService", () => {
shippingOptionService,
inventoryService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -614,6 +622,7 @@ describe("CartService", () => {
shippingOptionService,
eventBusService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -719,6 +728,7 @@ describe("CartService", () => {
cartRepository,
totalsService,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -806,6 +816,7 @@ describe("CartService", () => {
eventBusService,
inventoryService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -887,6 +898,7 @@ describe("CartService", () => {
cartRepository,
eventBusService,
customerService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -967,6 +979,7 @@ describe("CartService", () => {
cartRepository,
addressRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1028,6 +1041,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1182,6 +1196,7 @@ describe("CartService", () => {
eventBusService,
paymentSessionRepository: MockRepository(),
priceSelectionStrategy: priceSelectionStrat,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1269,6 +1284,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1383,6 +1399,7 @@ describe("CartService", () => {
cartRepository,
paymentProviderService,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1573,6 +1590,7 @@ describe("CartService", () => {
lineItemService,
eventBusService,
customShippingOptionService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(() => {
@@ -1927,6 +1945,7 @@ describe("CartService", () => {
discountService,
eventBusService,
lineItemAdjustmentService: LineItemAdjustmentServiceMock,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(async () => {
@@ -2214,6 +2233,7 @@ describe("CartService", () => {
totalsService,
cartRepository,
eventBusService,
featureFlagRouter: new FlagRouter({}),
})
beforeEach(async () => {

View File

@@ -3,14 +3,18 @@ import { MedusaError, Validator } from "medusa-core-utils"
import { DeepPartial, EntityManager, In } from "typeorm"
import { TransactionBaseService } from "../interfaces"
import { IPriceSelectionStrategy } from "../interfaces/price-selection-strategy"
import { DiscountRuleType } from "../models"
import { Address } from "../models/address"
import { Cart } from "../models/cart"
import { CustomShippingOption } from "../models/custom-shipping-option"
import { Customer } from "../models/customer"
import { Discount } from "../models/discount"
import { LineItem } from "../models/line-item"
import { ShippingMethod } from "../models/shipping-method"
import {
DiscountRuleType,
Address,
Cart,
CustomShippingOption,
Customer,
Discount,
LineItem,
ShippingMethod,
User,
SalesChannel,
} from "../models"
import { AddressRepository } from "../repositories/address"
import { CartRepository } from "../repositories/cart"
import { LineItemRepository } from "../repositories/line-item"
@@ -39,6 +43,10 @@ import RegionService from "./region"
import ShippingOptionService from "./shipping-option"
import TaxProviderService from "./tax-provider"
import TotalsService from "./totals"
import SalesChannelFeatureFlag from "../loaders/feature-flags/sales-channels"
import { FlagRouter } from "../utils/flag-router"
import SalesChannelService from "./sales-channel"
import StoreService from "./store"
type InjectedDependencies = {
manager: EntityManager
@@ -48,9 +56,12 @@ type InjectedDependencies = {
paymentSessionRepository: typeof PaymentSessionRepository
lineItemRepository: typeof LineItemRepository
eventBusService: EventBusService
salesChannelService: SalesChannelService
taxProviderService: TaxProviderService
paymentProviderService: PaymentProviderService
productService: ProductService
storeService: StoreService
featureFlagRouter: FlagRouter
productVariantService: ProductVariantService
regionService: RegionService
lineItemService: LineItemService
@@ -90,6 +101,9 @@ class CartService extends TransactionBaseService<CartService> {
protected readonly eventBus_: EventBusService
protected readonly productVariantService_: ProductVariantService
protected readonly productService_: ProductService
protected readonly featureFlagRouter_: FlagRouter
protected readonly storeService_: StoreService
protected readonly salesChannelService_: SalesChannelService
protected readonly regionService_: RegionService
protected readonly lineItemService_: LineItemService
protected readonly paymentProviderService_: PaymentProviderService
@@ -127,6 +141,9 @@ class CartService extends TransactionBaseService<CartService> {
customShippingOptionService,
lineItemAdjustmentService,
priceSelectionStrategy,
salesChannelService,
featureFlagRouter,
storeService,
}: InjectedDependencies) {
// eslint-disable-next-line prefer-rest-params
super(arguments[0])
@@ -153,6 +170,9 @@ class CartService extends TransactionBaseService<CartService> {
this.taxProviderService_ = taxProviderService
this.lineItemAdjustmentService_ = lineItemAdjustmentService
this.priceSelectionStrategy_ = priceSelectionStrategy
this.salesChannelService_ = salesChannelService
this.featureFlagRouter_ = featureFlagRouter
this.storeService_ = storeService
}
protected transformQueryForTotals_(
@@ -331,15 +351,17 @@ class CartService extends TransactionBaseService<CartService> {
this.addressRepository_
)
const { region_id } = data
if (!region_id) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`A region_id must be provided when creating a cart`
)
const rawCart: DeepPartial<Cart> = {
context: data.context ?? {},
}
const rawCart: DeepPartial<Cart> = {}
if (
this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key)
) {
rawCart.sales_channel_id = (
await this.getValidatedSalesChannel(data.sales_channel_id)
).id
}
if (data.email) {
const customer = await this.createOrFetchUserFromEmail_(data.email)
@@ -348,15 +370,21 @@ class CartService extends TransactionBaseService<CartService> {
rawCart.email = customer.email
}
if (!data.region_id) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`A region_id must be provided when creating a cart`
)
}
rawCart.region_id = data.region_id
const region = await this.regionService_
.withTransaction(transactionManager)
.retrieve(region_id, {
.retrieve(data.region_id, {
relations: ["countries"],
})
const regCountries = region.countries.map(({ iso_2 }) => iso_2)
rawCart.region_id = region.id
if (!data.shipping_address && !data.shipping_address_id) {
if (region.countries.length === 1) {
rawCart.shipping_address = addressRepo.create({
@@ -399,10 +427,8 @@ class CartService extends TransactionBaseService<CartService> {
typeof data[remainingField] !== "undefined" &&
remainingField !== "object"
) {
/* TODO: See how to fix the error TS2590 properly while keeping the DeepPartial type */
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
rawCart[remainingField] = data[remainingField]
const key = remainingField as string
rawCart[key] = data[remainingField]
}
}
@@ -418,6 +444,32 @@ class CartService extends TransactionBaseService<CartService> {
)
}
protected async getValidatedSalesChannel(
salesChannelId?: string
): Promise<SalesChannel | never> {
let salesChannel: SalesChannel
if (typeof salesChannelId !== "undefined") {
salesChannel = await this.salesChannelService_
.withTransaction(this.manager_)
.retrieve(salesChannelId)
} else {
salesChannel = (
await this.storeService_.withTransaction(this.manager_).retrieve({
relations: ["default_sales_channel"],
})
).default_sales_channel
}
if (salesChannel.is_disabled) {
throw new MedusaError(
MedusaError.Types.INVALID_DATA,
`Unable to assign the cart to a disabled Sales Channel "${salesChannel.name}"`
)
}
return salesChannel
}
/**
* Removes a line item from the cart.
* @param cartId - the id of the cart that we will remove from
@@ -721,6 +773,30 @@ class CartService extends TransactionBaseService<CartService> {
const cartRepo = transactionManager.getCustomRepository(
this.cartRepository_
)
const relations = [
"items",
"shipping_methods",
"shipping_address",
"billing_address",
"gift_cards",
"customer",
"region",
"payment_sessions",
"region.countries",
"discounts",
"discounts.rule",
"discounts.regions",
]
if (
this.featureFlagRouter_.isFeatureEnabled(
SalesChannelFeatureFlag.key
) &&
data.sales_channel_id
) {
relations.push("items.variant", "items.variant.product")
}
const cart = await this.retrieve(cartId, {
select: [
"subtotal",
@@ -729,20 +805,7 @@ class CartService extends TransactionBaseService<CartService> {
"discount_total",
"total",
],
relations: [
"items",
"shipping_methods",
"shipping_address",
"billing_address",
"gift_cards",
"customer",
"region",
"payment_sessions",
"region.countries",
"discounts",
"discounts.rule",
"discounts.regions",
],
relations,
})
if (data.customer_id) {
@@ -764,8 +827,12 @@ class CartService extends TransactionBaseService<CartService> {
}
if (typeof data.region_id !== "undefined") {
const shippingAddress =
typeof data.shipping_address !== "string"
? data.shipping_address
: {}
const countryCode =
(data.country_code || data.shipping_address?.country_code) ?? null
(data.country_code || shippingAddress?.country_code) ?? null
await this.setRegion_(cart, data.region_id, countryCode)
}
@@ -784,6 +851,18 @@ class CartService extends TransactionBaseService<CartService> {
await this.updateShippingAddress_(cart, shippingAddress, addrRepo)
}
if (
this.featureFlagRouter_.isFeatureEnabled(SalesChannelFeatureFlag.key)
) {
if (
typeof data.sales_channel_id !== "undefined" &&
data.sales_channel_id != cart.sales_channel_id
) {
await this.onSalesChannelChange(cart, data.sales_channel_id)
cart.sales_channel_id = data.sales_channel_id
}
}
if (typeof data.discounts !== "undefined") {
const previousDiscounts = [...cart.discounts]
cart.discounts.length = 0
@@ -861,6 +940,42 @@ class CartService extends TransactionBaseService<CartService> {
)
}
/**
* Remove the cart line item that does not belongs to the newly assigned sales channel
* @param cart - The cart being updated
* @param newSalesChannelId - The new sales channel being assigned to the cart
* @protected
*/
protected async onSalesChannelChange(
cart: Cart,
newSalesChannelId: string
): Promise<void> {
await this.getValidatedSalesChannel(newSalesChannelId)
const productIds = cart.items.map((item) => item.variant.product_id)
const productsToKeep = await this.productService_
.withTransaction(this.manager_)
.filterProductsBySalesChannel(productIds, newSalesChannelId, {
select: ["id", "sales_channels"],
take: productIds.length,
})
const productIdsToKeep = new Set<string>(
productsToKeep.map((product) => product.id)
)
const itemsToRemove = cart.items.filter((item) => {
return !productIdsToKeep.has(item.variant.product_id)
})
if (itemsToRemove.length) {
const results = await Promise.all(
itemsToRemove.map((item) => {
return this.removeLineItem(cart.id, item.id)
})
)
cart.items = results.pop()?.items ?? []
}
}
/**
* Sets the customer id of a cart
* @param cart - the cart to add email to

View File

@@ -2,7 +2,13 @@ import { MedusaError } from "medusa-core-utils"
import { EntityManager } from "typeorm"
import { SearchService } from "."
import { TransactionBaseService } from "../interfaces"
import { Product, ProductTag, ProductType, ProductVariant } from "../models"
import {
Product,
ProductTag,
ProductType,
ProductVariant,
SalesChannel,
} from "../models"
import { ImageRepository } from "../repositories/image"
import {
FindWithoutRelationsOptions,
@@ -284,6 +290,37 @@ class ProductService extends TransactionBaseService<ProductService> {
return product.variants
}
async filterProductsBySalesChannel(
productIds: string[],
salesChannelId: string,
config: FindProductConfig = {
skip: 0,
take: 50,
}
): Promise<Product[]> {
const givenRelations = config.relations ?? []
const requiredRelations = ["sales_channels"]
const relationsSet = new Set([...givenRelations, ...requiredRelations])
const products = await this.list(
{
id: productIds,
},
{
...config,
relations: [...relationsSet],
}
)
const productSalesChannelsMap = new Map<string, SalesChannel[]>(
products.map((product) => [product.id, product.sales_channels])
)
return products.filter((product) => {
return productSalesChannelsMap
.get(product.id)
?.some((sc) => sc.id === salesChannelId)
})
}
async listTypes(): Promise<ProductType[]> {
const manager = this.manager_
const productTypeRepository = manager.getCustomRepository(

View File

@@ -43,7 +43,7 @@ class Discount {
}
export type CartCreateProps = {
region_id: string
region_id?: string
email?: string
billing_address_id?: string
billing_address?: Partial<AddressPayload>
@@ -55,6 +55,8 @@ export type CartCreateProps = {
type?: CartType
context?: object
metadata?: object
sales_channel_id?: string
country_code?: string
}
export type CartUpdateProps = {
@@ -63,8 +65,8 @@ export type CartUpdateProps = {
email?: string
shipping_address_id?: string
billing_address_id?: string
billing_address?: AddressPayload
shipping_address?: AddressPayload
billing_address?: AddressPayload | string
shipping_address?: AddressPayload | string
completed_at?: Date
payment_authorized_at?: Date
gift_cards?: GiftCard[]
@@ -72,4 +74,5 @@ export type CartUpdateProps = {
customer_id?: string
context?: object
metadata?: Record<string, unknown>
sales_channel_id?: string
}