diff --git a/.changeset/curvy-ravens-turn.md b/.changeset/curvy-ravens-turn.md new file mode 100644 index 0000000000..47b2ec8d5a --- /dev/null +++ b/.changeset/curvy-ravens-turn.md @@ -0,0 +1,5 @@ +--- +"@medusajs/framework": patch +--- + +fix: apply additional data validator using a global middleware diff --git a/packages/core/framework/src/http/__fixtures__/mocks/index.ts b/packages/core/framework/src/http/__fixtures__/mocks/index.ts index 0d615a7682..197a3026ad 100644 --- a/packages/core/framework/src/http/__fixtures__/mocks/index.ts +++ b/packages/core/framework/src/http/__fixtures__/mocks/index.ts @@ -2,6 +2,7 @@ import { ConfigModule } from "@medusajs/types" export const customersGlobalMiddlewareMock = jest.fn() export const customersCreateMiddlewareMock = jest.fn() +export const customersCreateMiddlewareValidatorMock = jest.fn() export const storeGlobalMiddlewareMock = jest.fn() export const config = { diff --git a/packages/core/framework/src/http/__fixtures__/routers-middleware/middlewares.ts b/packages/core/framework/src/http/__fixtures__/routers-middleware/middlewares.ts index 47d3632fd4..43bb62e219 100644 --- a/packages/core/framework/src/http/__fixtures__/routers-middleware/middlewares.ts +++ b/packages/core/framework/src/http/__fixtures__/routers-middleware/middlewares.ts @@ -1,35 +1,45 @@ -import { NextFunction, raw, Request, Response } from "express" +import { raw } from "express" +import { MedusaRequest, MedusaResponse, MedusaNextFunction } from "../../types" import { customersCreateMiddlewareMock, customersGlobalMiddlewareMock, + customersCreateMiddlewareValidatorMock, storeGlobalMiddlewareMock, } from "../mocks" +import z from "zod" import { defineMiddlewares } from "../../utils/define-middlewares" const customersGlobalMiddleware = ( - req: Request, - res: Response, - next: NextFunction + req: MedusaRequest, + res: MedusaResponse, + next: MedusaNextFunction ) => { customersGlobalMiddlewareMock() next() } const customersCreateMiddleware = ( - req: Request, - res: Response, - next: NextFunction + req: MedusaRequest, + res: MedusaResponse, + next: MedusaNextFunction ) => { + if (req.additionalDataValidator) { + customersCreateMiddlewareValidatorMock() + } customersCreateMiddlewareMock() next() } -const storeGlobal = (req: Request, res: Response, next: NextFunction) => { +const storeGlobal = ( + req: MedusaRequest, + res: MedusaResponse, + next: MedusaNextFunction +) => { storeGlobalMiddlewareMock() next() } -export default defineMiddlewares([ +const middlewares = defineMiddlewares([ { matcher: "/customers", middlewares: [customersGlobalMiddleware], @@ -37,6 +47,9 @@ export default defineMiddlewares([ { method: "POST", matcher: "/customers", + additionalDataValidator: { + group_id: z.string(), + }, middlewares: [customersCreateMiddleware], }, { @@ -56,3 +69,5 @@ export default defineMiddlewares([ middlewares: [raw({ type: "application/json" })], }, ]) + +export default middlewares diff --git a/packages/core/framework/src/http/__tests__/index.spec.ts b/packages/core/framework/src/http/__tests__/index.spec.ts index ec7c835e7b..5aed855a91 100644 --- a/packages/core/framework/src/http/__tests__/index.spec.ts +++ b/packages/core/framework/src/http/__tests__/index.spec.ts @@ -2,6 +2,7 @@ import express from "express" import { resolve } from "path" import { customersCreateMiddlewareMock, + customersCreateMiddlewareValidatorMock, customersGlobalMiddlewareMock, storeGlobalMiddlewareMock, } from "../__fixtures__/mocks" @@ -203,6 +204,16 @@ describe("RoutesLoader", function () { expect(customersCreateMiddlewareMock).toHaveBeenCalled() }) + it("should assign the req.additionalDataValidator when the method and route matches", async function () { + const res = await request("POST", "/customers") + + expect(res.status).toBe(200) + expect(res.text).toBe("create customer") + expect(customersGlobalMiddlewareMock).toHaveBeenCalled() + expect(customersCreateMiddlewareMock).toHaveBeenCalled() + expect(customersCreateMiddlewareValidatorMock).toHaveBeenCalled() + }) + it("should call store global middleware on `/store/*` routes", async function () { const res = await request("POST", "/store/products/1000/sync") diff --git a/packages/core/framework/src/http/middleware-file-loader.ts b/packages/core/framework/src/http/middleware-file-loader.ts index 8f139f4b87..f0ad17d3d1 100644 --- a/packages/core/framework/src/http/middleware-file-loader.ts +++ b/packages/core/framework/src/http/middleware-file-loader.ts @@ -1,3 +1,4 @@ +import zod from "zod" import { join } from "path" import { dynamicImport, FileSystem } from "@medusajs/utils" @@ -7,6 +8,7 @@ import { type BodyParserConfigRoute, type MiddlewareDescriptor, type MedusaErrorHandlerFunction, + type AdditionalDataValidatorRoute, HTTP_METHODS, } from "./types" @@ -31,6 +33,12 @@ export class MiddlewareFileLoader { */ #middleware: MiddlewareDescriptor[] = [] + /** + * Route matchers on which a custom additional data validator is + * defined + */ + #additionalDataValidatorRoutes: AdditionalDataValidatorRoute[] = [] + /** * Route matchers on which a custom body parser config is used */ @@ -61,6 +69,7 @@ export class MiddlewareFileLoader { const result = routes.reduce<{ bodyParserConfigRoutes: BodyParserConfigRoute[] + additionalDataValidatorRoutes: AdditionalDataValidatorRoute[] middleware: MiddlewareDescriptor[] }>( (result, route) => { @@ -76,7 +85,7 @@ export class MiddlewareFileLoader { const matcher = String(route.matcher) - if ("bodyParser" in route && route.bodyParser !== undefined) { + if (route.bodyParser !== undefined) { const methods = route.methods || [...HTTP_METHODS] logger.debug( @@ -90,6 +99,21 @@ export class MiddlewareFileLoader { }) } + if (route.additionalDataValidator !== undefined) { + const methods = route.methods || [...HTTP_METHODS] + + logger.debug( + `assigning additionalData validator on matcher ${methods}:${route.matcher}` + ) + + result.additionalDataValidatorRoutes.push({ + matcher: matcher, + methods, + schema: route.additionalDataValidator, + validator: zod.object(route.additionalDataValidator).nullish(), + }) + } + if (route.middlewares) { route.middlewares.forEach((middleware) => { result.middleware.push({ @@ -103,6 +127,7 @@ export class MiddlewareFileLoader { }, { bodyParserConfigRoutes: [], + additionalDataValidatorRoutes: [], middleware: [], } ) @@ -117,6 +142,10 @@ export class MiddlewareFileLoader { this.#bodyParserConfigRoutes = this.#bodyParserConfigRoutes.concat( result.bodyParserConfigRoutes ) + this.#additionalDataValidatorRoutes = + this.#additionalDataValidatorRoutes.concat( + result.additionalDataValidatorRoutes + ) } /** @@ -157,4 +186,12 @@ export class MiddlewareFileLoader { getBodyParserConfigRoutes() { return this.#bodyParserConfigRoutes } + + /** + * Returns routes that have additional validator configured + * on them + */ + getAdditionalDataValidatorRoutes() { + return this.#additionalDataValidatorRoutes + } } diff --git a/packages/core/framework/src/http/router.ts b/packages/core/framework/src/http/router.ts index 02b57ff1b8..0f714eb068 100644 --- a/packages/core/framework/src/http/router.ts +++ b/packages/core/framework/src/http/router.ts @@ -12,6 +12,7 @@ import type { MiddlewareDescriptor, BodyParserConfigRoute, RouteHandler, + AdditionalDataValidatorRoute, } from "./types" import { RoutesLoader } from "./routes-loader" @@ -92,6 +93,8 @@ export class ApiLoader { | ErrorRequestHandler | undefined, bodyParserConfigRoutes: middlewareLoader.getBodyParserConfigRoutes(), + additionalDataValidatorRoutes: + middlewareLoader.getAdditionalDataValidatorRoutes(), } } @@ -281,6 +284,46 @@ export class ApiLoader { ) } + /** + * Applies the route middleware on a route. Encapsulates the logic + * needed to pass the middleware via the trace calls + */ + #assignAdditionalDataValidator( + namespace: string, + routesFinder: RoutesFinder + ) { + logger.debug( + `Registering assignAdditionalDataValidator middleware for prefix ${namespace}` + ) + + const additionalDataValidator = function additionalDataValidator( + req: MedusaRequest, + _: MedusaResponse, + next: MedusaNextFunction + ) { + const matchingRoute = routesFinder.find( + req.path, + req.method as MiddlewareVerb + ) + if (matchingRoute && matchingRoute.validator) { + logger.debug( + `Using validator to validate additional data on ${req.method} ${req.path}` + ) + req.additionalDataValidator = matchingRoute.validator + } + return next() + } + + this.#app.use( + namespace, + ApiLoader.traceMiddleware + ? (ApiLoader.traceMiddleware(additionalDataValidator, { + route: namespace, + }) as RequestHandler) + : (additionalDataValidator as RequestHandler) + ) + } + /** * Applies the middleware to authenticate the headers to contain * a `x-publishable-key` header @@ -305,6 +348,7 @@ export class ApiLoader { routes, routesFinder, bodyParserConfigRoutes, + additionalDataValidatorRoutes, } = await this.#loadHttpResources() /** @@ -322,6 +366,27 @@ export class ApiLoader { ) this.#applyBodyParserMiddleware("/", bodyParserRoutesFinder) + /** + * Use the routes finder to pick the additional data validator + * to be applied on the current request + */ + if (additionalDataValidatorRoutes.length) { + const additionalDataValidatorRoutesFinder = + new RoutesFinder( + new RoutesSorter(additionalDataValidatorRoutes).sort([ + "static", + "params", + "regex", + "wildcard", + "global", + ]) + ) + this.#assignAdditionalDataValidator( + "/", + additionalDataValidatorRoutesFinder + ) + } + /** * CORS and Auth setup for admin routes */ diff --git a/packages/core/framework/src/http/types.ts b/packages/core/framework/src/http/types.ts index 9480322aeb..990c81812a 100644 --- a/packages/core/framework/src/http/types.ts +++ b/packages/core/framework/src/http/types.ts @@ -1,5 +1,5 @@ import type { NextFunction, Request, Response } from "express" -import { ZodNullable, ZodObject, ZodOptional } from "zod" +import type { ZodNullable, ZodObject, ZodOptional, ZodRawShape } from "zod" import { FindConfig, @@ -60,6 +60,7 @@ export type MiddlewareRoute = { methods?: MiddlewareVerb[] matcher: string | RegExp bodyParser?: ParserConfig + additionalDataValidator?: ZodRawShape middlewares?: MiddlewareFunction[] } @@ -102,6 +103,13 @@ export type BodyParserConfigRoute = { config: ParserConfig } +export type AdditionalDataValidatorRoute = { + matcher: string + methods: MiddlewareVerb | MiddlewareVerb[] + schema: ZodRawShape + validator: ZodOptional>> +} + export type GlobalMiddlewareDescriptor = { config?: MiddlewaresConfig } diff --git a/packages/core/framework/src/http/utils/define-middlewares.ts b/packages/core/framework/src/http/utils/define-middlewares.ts index ed2a2c230d..851157c405 100644 --- a/packages/core/framework/src/http/utils/define-middlewares.ts +++ b/packages/core/framework/src/http/utils/define-middlewares.ts @@ -1,13 +1,12 @@ import { MedusaNextFunction, MedusaRequest, - MedusaRequestHandler, MedusaResponse, MiddlewaresConfig, MiddlewareVerb, ParserConfig, } from "../types" -import zod, { ZodRawShape } from "zod" +import { ZodRawShape } from "zod" /** * A helper function to configure the routes by defining custom middleware, @@ -42,23 +41,7 @@ export function defineMiddlewares< return { errorHandler, routes: routes.map((route) => { - let { middlewares, method, methods, additionalDataValidator, ...rest } = - route - const customMiddleware: MedusaRequestHandler[] = [] - - /** - * Define a custom validator when a zod schema is provided via - * "additionalDataValidator" property - */ - if (additionalDataValidator) { - customMiddleware.push((req, _, next) => { - req.additionalDataValidator = zod - .object(additionalDataValidator) - .nullish() - next() - }) - } - + let { middlewares, method, methods, ...rest } = route if (!methods) { methods = Array.isArray(method) ? method : method ? [method] : method } @@ -66,7 +49,7 @@ export function defineMiddlewares< return { ...rest, methods, - middlewares: customMiddleware.concat(middlewares || []), + middlewares: [...(middlewares ?? [])], } }), } diff --git a/packages/medusa/src/utils/__tests__/define-routes-config.spec.ts b/packages/medusa/src/utils/__tests__/define-routes-config.spec.ts index f94318c4f9..6ec22b59ff 100644 --- a/packages/medusa/src/utils/__tests__/define-routes-config.spec.ts +++ b/packages/medusa/src/utils/__tests__/define-routes-config.spec.ts @@ -1,6 +1,4 @@ -import zod from "zod" import { defineMiddlewares } from "../define-middlewares" -import { MedusaRequest, MedusaResponse } from "@medusajs/framework/http" describe("defineMiddlewares", function () { test("define custom middleware for a route", () => { @@ -20,38 +18,4 @@ describe("defineMiddlewares", function () { ], }) }) - - test("should wrap additionalDataValidator to middleware", () => { - const req = { - body: {}, - } as MedusaRequest - const res = {} as MedusaResponse - const nextFn = jest.fn() - const schema = { - brand_id: zod.string(), - } - - const config = defineMiddlewares([ - { - matcher: "/admin/products", - additionalDataValidator: schema, - }, - ]) - - expect(config).toMatchObject({ - routes: [ - { - matcher: "/admin/products", - middlewares: [expect.any(Function)], - }, - ], - }) - - config.routes?.[0].middlewares?.[0](req, res, nextFn) - expect(req.additionalDataValidator!.parse({ brand_id: "1" })).toMatchObject( - { - brand_id: "1", - } - ) - }) })