feat(medusa): Rollout index engine behind feature flag (#11431)

**What**
- Add index engine feature flag
- Apply it to the `store/products` endpoint as well as `admin/products`
- Query builder various fixes
- Search capabilities on the full data of every entity. The `q` search will be applied to all involved joined tables for selection/where clauses

Co-authored-by: Carlos R. L. Rodrigues <37986729+carlos-r-l-rodrigues@users.noreply.github.com>
This commit is contained in:
Adrien de Peretti
2025-02-18 14:49:57 +01:00
committed by GitHub
parent 3b69f5a105
commit 448dbcb596
27 changed files with 881 additions and 135 deletions

View File

@@ -2,6 +2,8 @@ export const schema = `
type Product @Listeners(values: ["product.created", "product.updated", "product.deleted"]) {
id: String
title: String
created_at: DateTime
deep: InternalNested
variants: [ProductVariant]
}

View File

@@ -133,7 +133,7 @@ describe("IndexModuleService syncIndexConfig", function () {
afterEach(afterEach_)
it("should full sync all entities when the config has changed", async () => {
it.only("should full sync all entities when the config has changed", async () => {
await setTimeout(1000)
const currentMetadata = await indexMetadataService.list()
@@ -148,7 +148,7 @@ describe("IndexModuleService syncIndexConfig", function () {
}),
expect.objectContaining({
entity: "Product",
fields: "id,title",
fields: "created_at,id,title",
status: "done",
}),
expect.objectContaining({

View File

@@ -199,7 +199,7 @@ describe("DataSynchronizer", () => {
filters: {
id: [testProductId],
},
fields: ["id", "title"],
fields: ["id", "created_at", "title"],
})
// Second loop fetching products
@@ -225,7 +225,7 @@ describe("DataSynchronizer", () => {
filters: {
id: [testProductId2],
},
fields: ["id", "title"],
fields: ["id", "created_at", "title"],
})
expect(ackMock).toHaveBeenNthCalledWith(1, {

View File

@@ -30,29 +30,34 @@ const dbUtils = TestDatabaseUtils.dbTestUtilFactory()
jest.setTimeout(300000)
const productId = "prod_1"
const productId2 = "prod_2"
const variantId = "var_1"
const variantId2 = "var_2"
const priceSetId = "price_set_1"
const priceId = "money_amount_1"
const linkId = "link_id_1"
const sendEvents = async (eventDataToEmit) => {
let a = 0
let productCounter = 0
let variantCounter = 0
queryMock.graph = jest.fn().mockImplementation((query) => {
const entity = query.entity
if (entity === "product") {
return {
data: {
id: a++ > 0 ? "aaaa" : productId,
id: productCounter++ > 0 ? productId2 : productId,
title: "Test Product " + productCounter,
},
}
} else if (entity === "product_variant") {
const counter = variantCounter++
return {
data: {
id: variantId,
id: counter > 0 ? variantId2 : variantId,
sku: "aaa test aaa",
product: {
id: productId,
id: counter > 0 ? productId2 : productId,
},
},
}
@@ -374,7 +379,16 @@ describe("IndexModuleService", function () {
{
name: "product.created",
data: {
id: "PRODUCTASDASDAS",
id: productId2,
},
},
{
name: "variant.created",
data: {
id: variantId2,
product: {
id: productId2,
},
},
},
{
@@ -426,14 +440,46 @@ describe("IndexModuleService", function () {
})
expect(productIndexEntries).toHaveLength(2)
expect(productIndexEntries[0].id).toEqual(productId)
expect(productIndexEntries).toEqual(
expect.arrayContaining([
expect.objectContaining({
id: productId,
data: expect.objectContaining({
id: productId,
title: expect.stringContaining("Test Product"),
}),
}),
expect.objectContaining({
id: productId2,
data: expect.objectContaining({
id: productId2,
title: expect.stringContaining("Test Product"),
}),
}),
])
)
const variantIndexEntries = indexEntries.filter((entry) => {
return entry.name === "ProductVariant"
})
expect(variantIndexEntries).toHaveLength(1)
expect(variantIndexEntries[0].id).toEqual(variantId)
expect(variantIndexEntries).toHaveLength(2)
expect(variantIndexEntries).toEqual(
expect.arrayContaining([
expect.objectContaining({
id: variantId,
data: expect.objectContaining({
id: variantId,
}),
}),
expect.objectContaining({
id: variantId2,
data: expect.objectContaining({
id: variantId2,
}),
}),
])
)
const priceSetIndexEntries = indexEntries.filter((entry) => {
return entry.name === "PriceSet"
@@ -461,7 +507,7 @@ describe("IndexModuleService", function () {
{}
)
expect(indexRelationEntries).toHaveLength(4)
expect(indexRelationEntries).toHaveLength(5)
const productVariantIndexRelationEntries = indexRelationEntries.filter(
(entry) => {

View File

@@ -414,7 +414,19 @@ describe("IndexModuleService query", function () {
},
})
// NULLS LAST (DESC = first)
expect(data).toEqual([
{
id: "prod_2",
title: "Product 2 title",
deep: {
a: 1,
obj: {
b: 15,
},
},
variants: [],
},
{
id: "prod_1",
variants: [
@@ -440,17 +452,6 @@ describe("IndexModuleService query", function () {
},
],
},
{
id: "prod_2",
title: "Product 2 title",
deep: {
a: 1,
obj: {
b: 15,
},
},
variants: [],
},
])
const { data: dataAsc } = await module.query({
@@ -469,17 +470,6 @@ describe("IndexModuleService query", function () {
})
expect(dataAsc).toEqual([
{
id: "prod_2",
title: "Product 2 title",
deep: {
a: 1,
obj: {
b: 15,
},
},
variants: [],
},
{
id: "prod_1",
variants: [
@@ -505,6 +495,17 @@ describe("IndexModuleService query", function () {
},
],
},
{
id: "prod_2",
title: "Product 2 title",
deep: {
a: 1,
obj: {
b: 15,
},
},
variants: [],
},
])
})
@@ -565,6 +566,11 @@ describe("IndexModuleService query", function () {
pagination: {
take: 100,
skip: 0,
order: {
product: {
created_at: "ASC",
},
},
},
})
@@ -596,7 +602,7 @@ describe("IndexModuleService query", function () {
product: {
variants: {
prices: {
amount: "DESC",
amount: "ASC",
},
},
},
@@ -608,14 +614,14 @@ describe("IndexModuleService query", function () {
{
id: "prod_1",
variants: [
{
id: "var_1",
sku: "aaa test aaa",
},
{
id: "var_2",
sku: "sku 123",
},
{
id: "var_1",
sku: "aaa test aaa",
},
],
},
{

View File

@@ -0,0 +1,53 @@
import { Migration } from "@mikro-orm/migrations"
// Migration: add full-text search support to the `index_data` table.
// Introduces a `document_tsv` tsvector column — kept in sync by a trigger —
// built from every value of the `data` JSONB payload, plus a GIN index so
// free-text (`q`) searches can use `document_tsv @@ plainto_tsquery(...)`.
export class Migration20250218132404 extends Migration {
override async up(): Promise<void> {
// 1) Add the tsvector column that will hold the searchable document.
this.addSql(
`
ALTER TABLE index_data
ADD COLUMN document_tsv tsvector;
`
)
// 2) Backfill existing rows: concatenate every JSONB value of `data`
//    and vectorize it with the locale-agnostic 'simple' configuration.
this.addSql(
`
UPDATE index_data
SET document_tsv = to_tsvector('simple', (
SELECT string_agg(value, ' ')
FROM jsonb_each_text(data)
));
`
)
// 3) GIN index so `@@ plainto_tsquery(...)` lookups are fast.
this.addSql(
`
CREATE INDEX idx_documents_document_tsv
ON index_data
USING gin(document_tsv);
`
)
// 4) Trigger keeps the column current on every INSERT/UPDATE,
//    mirroring the backfill expression above on NEW.data.
this.addSql(
`
CREATE OR REPLACE FUNCTION update_document_tsv() RETURNS trigger AS $$
BEGIN
NEW.document_tsv := to_tsvector('simple', (
SELECT string_agg(value, ' ')
FROM jsonb_each_text(NEW.data)
));
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trg_update_document_tsv
BEFORE INSERT OR UPDATE ON index_data
FOR EACH ROW
EXECUTE FUNCTION update_document_tsv();
`
)
}
// Reverts everything added in up(), in reverse order:
// trigger, function, index, then the column itself.
override async down(): Promise<void> {
this.addSql(`DROP TRIGGER IF EXISTS trg_update_document_tsv ON index_data;`)
this.addSql(`DROP FUNCTION IF EXISTS update_document_tsv;`)
this.addSql(`DROP INDEX IF EXISTS idx_documents_document_tsv;`)
this.addSql(`ALTER TABLE index_data DROP COLUMN IF EXISTS document_tsv;`)
}
}

View File

@@ -5,6 +5,7 @@ const IndexData = model.define("IndexData", {
name: model.text().primaryKey(),
data: model.json().default({}),
staled_at: model.dateTime().nullable(),
// document_tsv: model.tsvector(), NOTE: This is not supported and it is here for reference of its counterpart in the migration
})
export default IndexData

View File

@@ -178,8 +178,14 @@ export default class IndexModuleService
return this.schemaObjectRepresentation_
}
const baseSchema = `
scalar DateTime
scalar Date
scalar Time
scalar JSON
`
const [objectRepresentation, entityMap] = buildSchemaObjectRepresentation(
this.moduleOptions_.schema ?? defaultSchema
baseSchema + (this.moduleOptions_.schema ?? defaultSchema)
)
this.schemaObjectRepresentation_ = objectRepresentation

View File

@@ -13,8 +13,8 @@ import {
InjectTransactionManager,
isDefined,
MedusaContext,
promiseAll,
toMikroORMEntity,
unflattenObjectKeys,
} from "@medusajs/framework/utils"
import {
EntityManager,
@@ -250,10 +250,11 @@ export class PostgresProvider implements IndexTypes.StorageProvider {
const { take, skip, order: inputOrderBy = {} } = config.pagination ?? {}
const select = normalizeFieldsSelection(fields)
const where = flattenObjectKeys(filters)
const where = flattenObjectKeys(unflattenObjectKeys(filters))
const joinWhere = flattenObjectKeys(joinFilters)
const orderBy = flattenObjectKeys(inputOrderBy)
const inputOrderByObj = unflattenObjectKeys(inputOrderBy)
const joinWhere = flattenObjectKeys(unflattenObjectKeys(joinFilters))
const orderBy = flattenObjectKeys(inputOrderByObj)
const { manager } = sharedContext as { manager: SqlEntityManager }
let hasPagination = false
@@ -266,7 +267,10 @@ export class PostgresProvider implements IndexTypes.StorageProvider {
}
}
const requestedFields = deepMerge(deepMerge(select, filters), inputOrderBy)
const requestedFields = deepMerge(
deepMerge(select, filters),
inputOrderByObj
)
const connection = manager.getConnection()
const qb = new QueryBuilder({
@@ -288,26 +292,20 @@ export class PostgresProvider implements IndexTypes.StorageProvider {
requestedFields,
})
const [sql, sqlCount] = qb.buildQuery({
const sql = qb.buildQuery({
hasPagination,
returnIdOnly: !!keepFilteredEntities,
hasCount,
})
const promises: Promise<any>[] = []
promises.push(manager.execute(sql))
if (hasCount && sqlCount) {
promises.push(manager.execute(sqlCount))
}
let [resultSet, count] = await promiseAll(promises)
const resultSet = await manager.execute(sql)
const resultMetadata: IndexTypes.QueryFunctionReturnPagination | undefined =
hasPagination
? {
count: hasCount ? parseInt(count[0].count) : undefined,
count: hasCount
? parseInt(resultSet[0]?.count_total ?? 0)
: undefined,
skip,
take,
}
@@ -436,7 +434,7 @@ export class PostgresProvider implements IndexTypes.StorageProvider {
{
onConflictAction: "merge",
onConflictFields: ["id", "name"],
onConflictMergeFields: ["data", "staled_at"],
onConflictMergeFields: ["staled_at"],
}
)

View File

@@ -1,6 +1,10 @@
export const schemaObjectRepresentationPropertiesToOmit = [
"_schemaPropertiesMap",
"_serviceNameModuleConfigMap",
"JSON",
"DateTime",
"Date",
"Time",
]
export type Select = {

View File

@@ -21,7 +21,13 @@ export const CustomDirectives = {
export function makeSchemaExecutable(inputSchema: string) {
const { schema: cleanedSchema } = GraphQLUtils.cleanGraphQLSchema(inputSchema)
return GraphQLUtils.makeExecutableSchema({ typeDefs: cleanedSchema })
if (!cleanedSchema) {
return
}
return GraphQLUtils.makeExecutableSchema({
typeDefs: cleanedSchema,
})
}
function extractNameFromAlias(
@@ -68,9 +74,9 @@ function retrieveModuleAndAlias(entityName, moduleJoinerConfigs) {
if (moduleSchema) {
const executableSchema = makeSchemaExecutable(moduleSchema)
const entitiesMap = executableSchema.getTypeMap()
const entitiesMap = executableSchema?.getTypeMap()
if (entitiesMap[entityName]) {
if (entitiesMap?.[entityName]) {
relatedModule = moduleJoinerConfig
}
}
@@ -191,6 +197,10 @@ function retrieveLinkModuleAndAlias({
const executableSchema = makeSchemaExecutable(
foreignModuleConfig.schema
)
if (!executableSchema) {
continue
}
const entitiesMap = executableSchema.getTypeMap()
let intermediateEntities: string[] = []
@@ -704,7 +714,7 @@ export function buildSchemaObjectRepresentation(
): [IndexTypes.SchemaObjectRepresentation, Record<string, any>] {
const moduleJoinerConfigs = MedusaModule.getAllJoinerConfigs()
const augmentedSchema = CustomDirectives.Listeners.definition + schema
const executableSchema = makeSchemaExecutable(augmentedSchema)
const executableSchema = makeSchemaExecutable(augmentedSchema)!
const entitiesMap = executableSchema.getTypeMap()
const objectRepresentation = {

View File

@@ -4,14 +4,24 @@ export const defaultSchema = `
type Product @Listeners(values: ["${Modules.PRODUCT}.product.created", "${Modules.PRODUCT}.product.updated", "${Modules.PRODUCT}.product.deleted"]) {
id: String
title: String
handle: String
status: String
type_id: String
collection_id: String
is_giftcard: String
external_id: String
created_at: DateTime
updated_at: DateTime
variants: [ProductVariant]
sales_channels: [SalesChannel]
}
type ProductVariant @Listeners(values: ["${Modules.PRODUCT}.product-variant.created", "${Modules.PRODUCT}.product-variant.updated", "${Modules.PRODUCT}.product-variant.deleted"]) {
id: String
product_id: String
sku: String
prices: [Price]
}

View File

@@ -1,15 +1,15 @@
import { join } from "path"
import { CustomDirectives, makeSchemaExecutable } from "./build-config"
import { MedusaModule } from "@medusajs/framework/modules-sdk"
import {
FileSystem,
gqlSchemaToTypes as ModulesSdkGqlSchemaToTypes,
} from "@medusajs/framework/utils"
import { join } from "path"
import * as process from "process"
import { CustomDirectives, makeSchemaExecutable } from "./build-config"
export async function gqlSchemaToTypes(schema: string) {
const augmentedSchema = CustomDirectives.Listeners.definition + schema
const executableSchema = makeSchemaExecutable(augmentedSchema)
const executableSchema = makeSchemaExecutable(augmentedSchema)!
const filename = "index-service-entry-points"
const filenameWithExt = filename + ".d.ts"
const dir = join(process.cwd(), ".medusa")

View File

@@ -4,6 +4,7 @@ import {
isObject,
isPresent,
isString,
unflattenObjectKeys,
} from "@medusajs/framework/utils"
import { Knex } from "@mikro-orm/knex"
import { OrderBy, QueryFormat, QueryOptions, Select } from "@types"
@@ -22,6 +23,8 @@ export const OPERATOR_MAP = {
}
export class QueryBuilder {
#searchVectorColumnName = "document_tsv"
private readonly structure: Select
private readonly entityMap: Record<string, any>
private readonly knex: Knex
@@ -82,6 +85,7 @@ export class QueryBuilder {
private getGraphQLType(path, field) {
const entity = this.getEntity(path)?.ref?.entity!
const fieldRef = this.entityMap[entity]._fields[field]
if (!fieldRef) {
throw new Error(`Field ${field} is not indexed.`)
}
@@ -111,6 +115,7 @@ export class QueryBuilder {
Boolean: (val) => Boolean(val),
ID: (val) => String(val),
Date: (val) => new Date(val).toISOString(),
DateTime: (val) => new Date(val).toISOString(),
Time: (val) => new Date(`1970-01-01T${val}Z`).toISOString(),
}
@@ -132,6 +137,7 @@ export class QueryBuilder {
Float: "::double precision",
Boolean: "::boolean",
Date: "::timestamp",
DateTime: "::timestamp",
Time: "::time",
"": "",
}
@@ -141,6 +147,7 @@ export class QueryBuilder {
Float: "0",
Boolean: "false",
Date: "1970-01-01 00:00:00",
DateTime: "1970-01-01 00:00:00",
Time: "00:00:00",
"": "",
}
@@ -560,9 +567,10 @@ export class QueryBuilder {
hasPagination?: boolean
hasCount?: boolean
returnIdOnly?: boolean
}): [string, string | null] {
}): string {
const queryBuilder = this.knex.queryBuilder()
const selectOnlyStructure = this.selector.select
const structure = this.requestedFields
const filter = this.selector.where ?? {}
@@ -579,6 +587,16 @@ export class QueryBuilder {
const rootEntity = entity.toLowerCase()
const aliasMapping: { [path: string]: string } = {}
let hasTextSearch: boolean = false
let textSearchQuery: string | null = null
const searchQueryFilterProp = `${rootEntity}.q`
if (filter[searchQueryFilterProp]) {
hasTextSearch = true
textSearchQuery = filter[searchQueryFilterProp]
delete filter[searchQueryFilterProp]
}
const joinParts = this.buildQueryParts(
rootStructure,
"",
@@ -591,7 +609,11 @@ export class QueryBuilder {
const rootAlias = aliasMapping[rootKey]
const selectParts = !returnIdOnly
? this.buildSelectParts(rootStructure, rootKey, aliasMapping)
? this.buildSelectParts(
selectOnlyStructure[rootKey] as Select,
rootKey,
aliasMapping
)
: { [rootKey + ".id"]: `${rootAlias}.id` }
queryBuilder.select(selectParts)
@@ -604,6 +626,36 @@ export class QueryBuilder {
queryBuilder.joinRaw(joinPart)
})
let searchWhereParts: string[] = []
if (hasTextSearch) {
/**
* Build the search where parts for the query.
* Apply the search query to the search vector column for every joined table except
* the pivot joined table.
*/
searchWhereParts = [
`${this.getShortAlias(aliasMapping, rootEntity)}.${
this.#searchVectorColumnName
} @@ plainto_tsquery('simple', '${textSearchQuery}')`,
...joinParts.flatMap((part) => {
const aliases = part
.split(" as ")
.flatMap((chunk) => chunk.split(" on "))
.filter(
(alias) => alias.startsWith('"t_') && !alias.includes("_ref")
)
return aliases.map(
(alias) =>
`${alias}.${
this.#searchVectorColumnName
} @@ plainto_tsquery('simple', '${textSearchQuery}')`
)
}),
]
queryBuilder.whereRaw(`(${searchWhereParts.join(" OR ")})`)
}
// WHERE clause
this.parseWhere(aliasMapping, filter, queryBuilder)
@@ -618,49 +670,60 @@ export class QueryBuilder {
const direction = orderBy[aliasPath]
queryBuilder.orderByRaw(
pgType.coalesce(`${alias}.data->>'${field}'`) + " " + direction
`(${alias}.data->>'${field}')${pgType.cast}` + " " + direction
)
}
let distinctQueryBuilder = queryBuilder.clone()
let take_ = !isNaN(+take!) ? +take! : 15
let skip_ = !isNaN(+skip!) ? +skip! : 0
let sql = ""
let cte = ""
if (hasPagination) {
const idColumn = `${this.getShortAlias(aliasMapping, rootEntity)}.id`
distinctQueryBuilder.clearSelect()
distinctQueryBuilder.select(
this.knex.raw(`DISTINCT ON (${idColumn}) ${idColumn} as "id"`)
)
distinctQueryBuilder.limit(take_)
distinctQueryBuilder.offset(skip_)
cte = this.buildCTEData({
hasCount,
searchWhereParts,
take: take_,
skip: skip_,
orderBy,
})
sql += `WITH paginated_data AS (${distinctQueryBuilder.toQuery()}),`
if (hasCount) {
queryBuilder.select(this.knex.raw("pd.count_total"))
}
queryBuilder.andWhere(
this.knex.raw(`${idColumn} IN (SELECT id FROM "paginated_data")`)
queryBuilder.joinRaw(
`JOIN paginated_data AS pd ON ${rootAlias}.id = pd.id`
)
}
sql += `${hasPagination ? " " : "WITH"} data AS (${queryBuilder.toQuery()})
SELECT *
FROM data`
let sqlCount = ""
if (hasCount) {
sqlCount = this.buildQueryCount()
}
return [sql, hasCount ? sqlCount : null]
return cte + queryBuilder.toQuery()
}
public buildQueryCount(): string {
public buildCTEData({
hasCount,
searchWhereParts = [],
skip,
take,
orderBy,
}: {
hasCount: boolean
searchWhereParts: string[]
skip?: number
take: number
orderBy: OrderBy
}): string {
const queryBuilder = this.knex.queryBuilder()
const hasWhere = isPresent(this.rawConfig?.filters)
const structure = hasWhere ? this.rawConfig?.filters! : this.requestedFields
const hasWhere = isPresent(this.rawConfig?.filters) || isPresent(orderBy)
const structure =
hasWhere && !searchWhereParts.length
? unflattenObjectKeys({
...(this.rawConfig?.filters
? unflattenObjectKeys(this.rawConfig?.filters)
: {}),
...orderBy,
})
: this.requestedFields
const rootKey = this.getStructureKeys(structure)[0]
@@ -682,9 +745,7 @@ export class QueryBuilder {
const rootAlias = aliasMapping[rootKey]
queryBuilder.select(
this.knex.raw(`COUNT(DISTINCT ${rootAlias}.id) as count`)
)
queryBuilder.select(this.knex.raw(`${rootAlias}.id as id`))
queryBuilder.from(
`cat_${rootEntity} AS ${this.getShortAlias(aliasMapping, rootEntity)}`
@@ -695,10 +756,58 @@ export class QueryBuilder {
queryBuilder.joinRaw(joinPart)
})
if (searchWhereParts.length) {
queryBuilder.whereRaw(`(${searchWhereParts.join(" OR ")})`)
}
this.parseWhere(aliasMapping, this.selector.where!, queryBuilder)
}
return queryBuilder.toQuery()
// ORDER BY clause
const orderAliases: string[] = []
for (const aliasPath in orderBy) {
const path = aliasPath.split(".")
const field = path.pop()
const attr = path.join(".")
const pgType = this.getPostgresCastType(attr, [field])
const alias = aliasMapping[attr]
const direction = orderBy[aliasPath]
const orderAlias = `"${alias}.data->>'${field}'"`
orderAliases.push(orderAlias + " " + direction)
// transform the order by clause to a select MIN/MAX
queryBuilder.select(
direction === "ASC"
? this.knex.raw(
`MIN((${alias}.data->>'${field}')${pgType.cast}) as ${orderAlias}`
)
: this.knex.raw(
`MAX((${alias}.data->>'${field}')${pgType.cast}) as ${orderAlias}`
)
)
}
queryBuilder.groupByRaw(`${rootAlias}.id`)
const countSubQuery = hasCount
? `, (SELECT count(id) FROM data_select) as count_total`
: ""
return `
WITH data_select AS (
${queryBuilder.toQuery()}
),
paginated_data AS (
SELECT id ${countSubQuery}
FROM data_select
${orderAliases.length ? "ORDER BY " + orderAliases.join(", ") : ""}
LIMIT ${take}
${skip ? `OFFSET ${skip}` : ""}
)
`
}
// NOTE: We are keeping the below code for now as reference to an alternative implementation for us. DO NOT REMOVE