diff --git a/packages/modules/product/integration-tests/__tests__/product-module-service/product-categories.spec.ts b/packages/modules/product/integration-tests/__tests__/product-module-service/product-categories.spec.ts index 2ae79f2252..6c42faee1a 100644 --- a/packages/modules/product/integration-tests/__tests__/product-module-service/product-categories.spec.ts +++ b/packages/modules/product/integration-tests/__tests__/product-module-service/product-categories.spec.ts @@ -133,6 +133,118 @@ moduleIntegrationTestRunner({ }), ]) }) + + describe("with tree inclusion", () => { + let root, child1, child2, child1a, child2a, child2a1 + + beforeEach(async () => { + root = await service.createProductCategories({ + name: "Root", + }) + + child1 = await service.createProductCategories({ + name: "Child 1", + parent_category_id: root.id, + }) + + child1a = await service.createProductCategories({ + name: "Child 1 a", + parent_category_id: child1.id, + }) + + child2 = await service.createProductCategories({ + name: "Child 2", + parent_category_id: root.id, + }) + + child2a = await service.createProductCategories({ + name: "Child 2 a", + parent_category_id: child2.id, + is_internal: true, + }) + + child2a1 = await service.createProductCategories({ + name: "Child 2 a 1", + parent_category_id: child2a.id, + }) + }) + + it("should return all descendants of a category", async () => { + const results = await service.listProductCategories( + { + id: root.id, + include_descendants_tree: true, + is_internal: false, + }, + { + select: ["id"], + take: 1, + } + ) + + expect(results).toEqual([ + expect.objectContaining({ + id: root.id, + category_children: [ + expect.objectContaining({ + id: child1.id, + category_children: [ + expect.objectContaining({ id: child1a.id }), + ], + }), + expect.objectContaining({ + id: child2.id, + // child2a & child2a1 should not show up as we're scoping by internal + category_children: [], + }), + ], + }), + ]) + }) + + it("should return all ancestors of a category", async () => { + const results = await service.listProductCategories( + { + id: child1a.id, + include_ancestors_tree: true, + is_internal: false, + }, + { + select: ["id"], + take: 1, + } + ) + + expect(results).toEqual([ + expect.objectContaining({ + id: child1a.id, + parent_category: expect.objectContaining({ + id: child1.id, + parent_category: expect.objectContaining({ id: root.id }), + }), + }), + ]) + + const results2 = await service.listProductCategories( + { + id: child2a1.id, + include_ancestors_tree: true, + is_internal: false, + }, + { + select: ["id"], + take: 1, + } + ) + // If the where query includes scoped categories, we hide from the tree + expect(results2).toEqual([ + expect.objectContaining({ + id: child2a1.id, + parent_category: undefined, + }), + ]) + }) + }) }) describe("listAndCountCategories", () => { diff --git a/packages/modules/product/src/repositories/product-category.ts b/packages/modules/product/src/repositories/product-category.ts index 9d4101dd0d..43cde671c4 100644 --- a/packages/modules/product/src/repositories/product-category.ts +++ b/packages/modules/product/src/repositories/product-category.ts @@ -6,9 +6,9 @@ import { } from "@medusajs/types" import { DALUtils, isDefined, MedusaError } from "@medusajs/utils" import { + LoadStrategy, FilterQuery as MikroFilterQuery, FindOptions as MikroOptions, - LoadStrategy, } from "@mikro-orm/core" import { SqlEntityManager } from "@mikro-orm/postgresql" import { ProductCategory } from "@models" @@ -72,7 +72,6 @@ export class ProductCategoryRepository extends DALUtils.MikroOrmBaseTreeReposito context: Context = {} ): Promise { const manager = super.getActiveManager(context) - const findOptions_ = this.buildFindOptions(findOptions, transformOptions) const productCategories = await manager.find( @@ -136,12 +135,14 @@ export class ProductCategoryRepository extends DALUtils.MikroOrmBaseTreeReposito relationIndex = findOptions.options?.populate?.indexOf("category_children") const shouldPopulateChildren = relationIndex !== -1 + if (shouldPopulateChildren && include.descendants) { findOptions.options!.populate!.splice(relationIndex as number, 1) } const mpaths: any[] = [] const parentMpaths = new Set() + for (const cat of productCategories) { if (include.descendants) { mpaths.push({ mpath: { $like: `${cat.mpath}%` } }) @@ -158,37 +159,34 @@ export class ProductCategoryRepository extends DALUtils.MikroOrmBaseTreeReposito mpaths.push({ mpath: Array.from(parentMpaths) }) - const whereOptions = { - ...findOptions.where, - $or: mpaths, - } + const where = { ...findOptions.where, $or: mpaths } + const options = { + ...findOptions.options, + limit: undefined, + offset: 0, + } as MikroOptions - if ("parent_category_id" in whereOptions) { - delete whereOptions.parent_category_id - } + delete where.id + delete where.mpath + delete where.parent_category_id - if ("id" in whereOptions) { - delete whereOptions.id - } - - let allCategories = await manager.find( - ProductCategory, - whereOptions as MikroFilterQuery, - findOptions.options as MikroOptions + const categoriesInTree = await this.serialize( + await manager.find(ProductCategory, where, options) ) - allCategories = JSON.parse(JSON.stringify(allCategories)) + const categoriesById = new Map(categoriesInTree.map((cat) => [cat.id, cat])) - const categoriesById = new Map(allCategories.map((cat) => [cat.id, cat])) - - allCategories.forEach((cat: any) => { + categoriesInTree.forEach((cat: any) => { if (cat.parent_category_id && include.ancestors) { cat.parent_category = categoriesById.get(cat.parent_category_id) + cat.parent_category_id = categoriesById.get(cat.parent_category_id)?.[ + "id" + ] } }) const populateChildren = (category, level = 0) => { - const categories = allCategories.filter( + const categories = categoriesInTree.filter( (child) => child.parent_category_id === category.id ) @@ -231,7 +229,6 @@ export class ProductCategoryRepository extends DALUtils.MikroOrmBaseTreeReposito context: Context = {} ): Promise<[ProductCategory[], number]> { const manager = super.getActiveManager(context) - const findOptions_ = this.buildFindOptions(findOptions, transformOptions) const [productCategories, count] = await manager.findAndCount(