diff --git a/packages/orm/package.json b/packages/orm/package.json index 170062927..ad6f1c57e 100644 --- a/packages/orm/package.json +++ b/packages/orm/package.json @@ -95,6 +95,7 @@ "@paralleldrive/cuid2": "^2.2.2", "@zenstackhq/common-helpers": "workspace:*", "@zenstackhq/schema": "workspace:*", + "@zenstackhq/zod": "workspace:*", "cuid": "^3.0.0", "decimal.js": "catalog:", "json-stable-stringify": "^1.3.0", diff --git a/packages/orm/src/client/client-impl.ts b/packages/orm/src/client/client-impl.ts index 8eec17a13..1aa289aa4 100644 --- a/packages/orm/src/client/client-impl.ts +++ b/packages/orm/src/client/client-impl.ts @@ -133,6 +133,10 @@ export class ClientImpl { return this.kyselyRaw; } + get $zod() { + return this.inputValidator.zodFactory; + } + get isTransaction() { return this.kysely.isTransaction; } diff --git a/packages/orm/src/client/constants.ts b/packages/orm/src/client/constants.ts index bf62faff6..a945b7da2 100644 --- a/packages/orm/src/client/constants.ts +++ b/packages/orm/src/client/constants.ts @@ -1,7 +1,7 @@ -/** - * The comment prefix for annotation generated Kysely queries with context information. - */ -export const CONTEXT_COMMENT_PREFIX = '-- $$context:'; +// /** +// * The comment prefix for annotation generated Kysely queries with context information. +// */ +// export const CONTEXT_COMMENT_PREFIX = '-- $$context:'; /** * The types of fields that are numeric. diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index 9b038722f..c6b772aa4 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -39,45 +39,15 @@ import type { UpdateManyArgs, UpsertArgs, } from './crud-types'; -import type { - CoreCreateOperations, - CoreCrudOperations, - CoreDeleteOperations, - CoreReadOperations, - CoreUpdateOperations, -} from './crud/operations/base'; import type { ClientOptions, QueryOptions } from './options'; import type { ExtClientMembersBase, ExtQueryArgsBase, RuntimePlugin } from './plugin'; import type { ZenStackPromise } from './promise'; import type { ToKysely } from './query-builder'; import type { GetSlicedModels, GetSlicedOperations, GetSlicedProcedures } from './type-utils'; +import type { ZodSchemaFactory } from './zod/factory'; type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[number]; -/** - * Extracts extended query args for a specific operation. - */ -type ExtractExtQueryArgs = (Operation extends keyof ExtQueryArgs - ? ExtQueryArgs[Operation] - : {}) & - ('$create' extends keyof ExtQueryArgs - ? Operation extends CoreCreateOperations - ? ExtQueryArgs['$create'] - : {} - : {}) & - ('$read' extends keyof ExtQueryArgs ? (Operation extends CoreReadOperations ? ExtQueryArgs['$read'] : {}) : {}) & - ('$update' extends keyof ExtQueryArgs - ? Operation extends CoreUpdateOperations - ? ExtQueryArgs['$update'] - : {} - : {}) & - ('$delete' extends keyof ExtQueryArgs - ? Operation extends CoreDeleteOperations - ? ExtQueryArgs['$delete'] - : {} - : {}) & - ('$all' extends keyof ExtQueryArgs ? ExtQueryArgs['$all'] : {}); - /** * Transaction isolation levels. */ @@ -232,6 +202,11 @@ export type ClientContract< */ $disconnect(): Promise; + /** + * Factory for creating zod schemas to validate query args. + */ + get $zod(): ZodSchemaFactory; + /** * Pushes the schema to the database. For testing purposes only. * @private @@ -317,7 +292,7 @@ export type AllModelOperations< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, - ExtQueryArgs, + ExtQueryArgs extends ExtQueryArgsBase, > = CommonModelOperations & // provider-specific operations (Schema['provider']['type'] extends 'mysql' @@ -341,15 +316,8 @@ export type AllModelOperations< * }); * ``` */ - createManyAndReturn< - T extends CreateManyAndReturnArgs & - ExtractExtQueryArgs, - >( - args?: SelectSubset< - T, - CreateManyAndReturnArgs & - ExtractExtQueryArgs - >, + createManyAndReturn>( + args?: SelectSubset>, ): ZenStackPromise[]>; /** @@ -374,15 +342,8 @@ export type AllModelOperations< * }); * ``` */ - updateManyAndReturn< - T extends UpdateManyAndReturnArgs & - ExtractExtQueryArgs, - >( - args: Subset< - T, - UpdateManyAndReturnArgs & - ExtractExtQueryArgs - >, + updateManyAndReturn>( + args: Subset>, ): ZenStackPromise[]>; }); @@ -390,7 +351,7 @@ type CommonModelOperations< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, - ExtQueryArgs, + ExtQueryArgs extends ExtQueryArgsBase, > = { /** * Returns a list of entities. @@ -473,8 +434,8 @@ type CommonModelOperations< * }); // result: `{ _count: { posts: number } }` * ``` */ - findMany & ExtractExtQueryArgs>( - args?: SelectSubset & ExtractExtQueryArgs>, + findMany>( + args?: SelectSubset>, ): ZenStackPromise[]>; /** @@ -483,8 +444,8 @@ type CommonModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findUnique & ExtractExtQueryArgs>( - args: SelectSubset & ExtractExtQueryArgs>, + findUnique>( + args: SelectSubset>, ): ZenStackPromise | null>; /** @@ -493,10 +454,8 @@ type CommonModelOperations< * @returns a single entity * @see {@link findMany} */ - findUniqueOrThrow< - T extends FindUniqueArgs & ExtractExtQueryArgs, - >( - args: SelectSubset & ExtractExtQueryArgs>, + findUniqueOrThrow>( + args: SelectSubset>, ): ZenStackPromise>; /** @@ -505,8 +464,8 @@ type CommonModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findFirst & ExtractExtQueryArgs>( - args?: SelectSubset & ExtractExtQueryArgs>, + findFirst>( + args?: SelectSubset>, ): ZenStackPromise | null>; /** @@ -515,8 +474,8 @@ type CommonModelOperations< * @returns a single entity * @see {@link findMany} */ - findFirstOrThrow & ExtractExtQueryArgs>( - args?: SelectSubset & ExtractExtQueryArgs>, + findFirstOrThrow>( + args?: SelectSubset>, ): ZenStackPromise>; /** @@ -571,8 +530,8 @@ type CommonModelOperations< * }); * ``` */ - create & ExtractExtQueryArgs>( - args: SelectSubset & ExtractExtQueryArgs>, + create>( + args: SelectSubset>, ): ZenStackPromise>; /** @@ -600,8 +559,8 @@ type CommonModelOperations< * }); * ``` */ - createMany & ExtractExtQueryArgs>( - args?: SelectSubset & ExtractExtQueryArgs>, + createMany>( + args?: SelectSubset>, ): ZenStackPromise; /** @@ -721,8 +680,8 @@ type CommonModelOperations< * }); * ``` */ - update & ExtractExtQueryArgs>( - args: SelectSubset & ExtractExtQueryArgs>, + update>( + args: SelectSubset>, ): ZenStackPromise>; /** @@ -745,8 +704,8 @@ type CommonModelOperations< * limit: 10 * }); */ - updateMany & ExtractExtQueryArgs>( - args: Subset & ExtractExtQueryArgs>, + updateMany>( + args: Subset>, ): ZenStackPromise; /** @@ -769,8 +728,8 @@ type CommonModelOperations< * }); * ``` */ - upsert & ExtractExtQueryArgs>( - args: SelectSubset & ExtractExtQueryArgs>, + upsert>( + args: SelectSubset>, ): ZenStackPromise>; /** @@ -792,8 +751,8 @@ type CommonModelOperations< * }); // result: `{ id: string; email: string }` * ``` */ - delete & ExtractExtQueryArgs>( - args: SelectSubset & ExtractExtQueryArgs>, + delete>( + args: SelectSubset>, ): ZenStackPromise>; /** @@ -815,8 +774,8 @@ type CommonModelOperations< * }); * ``` */ - deleteMany & ExtractExtQueryArgs>( - args?: Subset & ExtractExtQueryArgs>, + deleteMany>( + args?: Subset>, ): ZenStackPromise; /** @@ -837,8 +796,8 @@ type CommonModelOperations< * select: { _all: true, email: true } * }); // result: `{ _all: number, email: number }` */ - count & ExtractExtQueryArgs>( - args?: Subset & ExtractExtQueryArgs>, + count>( + args?: Subset>, ): ZenStackPromise>>; /** @@ -858,8 +817,8 @@ type CommonModelOperations< * _max: { age: true } * }); // result: `{ _count: number, _avg: { age: number }, ... }` */ - aggregate & ExtractExtQueryArgs>( - args: Subset & ExtractExtQueryArgs>, + aggregate>( + args: Subset>, ): ZenStackPromise>>; /** @@ -895,8 +854,8 @@ type CommonModelOperations< * having: { country: 'US', age: { _avg: { gte: 18 } } } * }); */ - groupBy & ExtractExtQueryArgs>( - args: Subset & ExtractExtQueryArgs>, + groupBy>( + args: Subset>, ): ZenStackPromise>>; /** @@ -916,8 +875,8 @@ type CommonModelOperations< * where: { posts: { some: { published: true } } }, * }); // result: `boolean` */ - exists & ExtractExtQueryArgs>( - args?: Subset & ExtractExtQueryArgs>, + exists>( + args?: Subset>, ): ZenStackPromise; }; @@ -927,7 +886,7 @@ export type ModelOperations< Schema extends SchemaDef, Model extends GetModels, Options extends ClientOptions = ClientOptions, - ExtQueryArgs = {}, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = SliceOperations, Schema, Model, Options>; //#endregion diff --git a/packages/orm/src/client/crud-types.ts b/packages/orm/src/client/crud-types.ts index 5b050f56a..329c55b8e 100644 --- a/packages/orm/src/client/crud-types.ts +++ b/packages/orm/src/client/crud-types.ts @@ -50,7 +50,15 @@ import type { XOR, } from '../utils/type-utils'; import type { ClientContract } from './contract'; +import type { + CoreCreateOperations, + CoreCrudOperations, + CoreDeleteOperations, + CoreReadOperations, + CoreUpdateOperations, +} from './crud/operations/base'; import type { FilterKind, QueryOptions } from './options'; +import type { ExtQueryArgsBase } from './plugin'; import type { ToKyselySchema } from './query-builder'; import type { GetSlicedFilterKindsForField, GetSlicedModels } from './type-utils'; @@ -1213,27 +1221,32 @@ export type FindManyArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = FindArgs; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = FindArgs & ExtractExtQueryArgs; export type FindFirstArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = FindArgs; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = FindArgs & ExtractExtQueryArgs; export type ExistsArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = FilterArgs; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = FilterArgs & ExtractExtQueryArgs; export type FindUniqueArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { where: WhereUniqueInput; -} & SelectIncludeOmit; +} & SelectIncludeOmit & + ExtractExtQueryArgs; //#endregion @@ -1243,17 +1256,27 @@ export type CreateArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { data: CreateInput; -} & SelectIncludeOmit; +} & SelectIncludeOmit & + ExtractExtQueryArgs; -export type CreateManyArgs> = CreateManyInput; +export type CreateManyArgs< + Schema extends SchemaDef, + Model extends GetModels, + _Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = CreateManyInput & ExtractExtQueryArgs; export type CreateManyAndReturnArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = CreateManyInput & SelectIncludeOmit; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = CreateManyInput & + SelectIncludeOmit & + ExtractExtQueryArgs; type OptionalWrap, T extends object> = Optional< T, @@ -1460,6 +1483,7 @@ export type UpdateArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * The data to update the record with. @@ -1470,19 +1494,24 @@ export type UpdateArgs< * The unique filter to find the record to update. */ where: WhereUniqueInput; -} & SelectIncludeOmit; +} & SelectIncludeOmit & + ExtractExtQueryArgs; export type UpdateManyArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = UpdateManyPayload; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = UpdateManyPayload & ExtractExtQueryArgs; export type UpdateManyAndReturnArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, -> = UpdateManyPayload & SelectIncludeOmit; + ExtQueryArgs extends ExtQueryArgsBase = {}, +> = UpdateManyPayload & + SelectIncludeOmit & + ExtractExtQueryArgs; type UpdateManyPayload< Schema extends SchemaDef, @@ -1510,6 +1539,7 @@ export type UpsertArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * The data to create the record if it doesn't exist. @@ -1525,7 +1555,8 @@ export type UpsertArgs< * The unique filter to find the record to update. */ where: WhereUniqueInput; -} & SelectIncludeOmit; +} & SelectIncludeOmit & + ExtractExtQueryArgs; type UpdateScalarInput< Schema extends SchemaDef, @@ -1745,17 +1776,20 @@ export type DeleteArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * The unique filter to find the record to delete. */ where: WhereUniqueInput; -} & SelectIncludeOmit; +} & SelectIncludeOmit & + ExtractExtQueryArgs; export type DeleteManyArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * Filter to select records to delete. @@ -1766,7 +1800,7 @@ export type DeleteManyArgs< * Limits the number of records to delete. */ limit?: number; -}; +} & ExtractExtQueryArgs; // #endregion @@ -1776,12 +1810,13 @@ export type CountArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = Omit, 'select' | 'include' | 'distinct' | 'omit'> & { /** * Selects fields to count */ select?: CountAggregateInput | true; -}; +} & ExtractExtQueryArgs; type CountAggregateInput> = { [Key in NonRelationFields]?: true; @@ -1805,6 +1840,7 @@ export type AggregateArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * Filter conditions @@ -1851,7 +1887,8 @@ export type AggregateArgs< * Performs sum value aggregation. */ _sum?: SumAvgInput; - }); + }) & + ExtractExtQueryArgs; type NumericFields> = keyof { [Key in GetModelFields as GetModelFieldType extends @@ -1942,6 +1979,7 @@ export type GroupByArgs< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, > = { /** * Filter conditions @@ -1999,7 +2037,8 @@ export type GroupByArgs< * Performs sum value aggregation. */ _sum?: SumAvgInput; - }); + }) & + ExtractExtQueryArgs; export type GroupByResult< Schema extends SchemaDef, @@ -2346,4 +2385,28 @@ type ProviderSupportsDistinct = Schema['provider']['ty ? true : false; +/** + * Extracts extended query args for a specific operation. + */ +type ExtractExtQueryArgs = (Operation extends keyof ExtQueryArgs + ? ExtQueryArgs[Operation] + : {}) & + ('$create' extends keyof ExtQueryArgs + ? Operation extends CoreCreateOperations + ? ExtQueryArgs['$create'] + : {} + : {}) & + ('$read' extends keyof ExtQueryArgs ? (Operation extends CoreReadOperations ? ExtQueryArgs['$read'] : {}) : {}) & + ('$update' extends keyof ExtQueryArgs + ? Operation extends CoreUpdateOperations + ? ExtQueryArgs['$update'] + : {} + : {}) & + ('$delete' extends keyof ExtQueryArgs + ? Operation extends CoreDeleteOperations + ? ExtQueryArgs['$delete'] + : {} + : {}) & + ('$all' extends keyof ExtQueryArgs ? ExtQueryArgs['$all'] : {}); + // #endregion diff --git a/packages/orm/src/client/crud/operations/find.ts b/packages/orm/src/client/crud/operations/find.ts index 7bf56b8f5..197c1c643 100644 --- a/packages/orm/src/client/crud/operations/find.ts +++ b/packages/orm/src/client/crud/operations/find.ts @@ -10,7 +10,11 @@ export class FindOperationHandler extends BaseOperatio // parse args let parsedArgs = validateArgs - ? this.inputValidator.validateFindArgs(this.model, normalizedArgs, operation) + ? this.inputValidator.validateFindArgs( + this.model, + normalizedArgs, + operation as 'findFirst' | 'findUnique' | 'findMany', + ) : (normalizedArgs as any); if (findOne) { diff --git a/packages/orm/src/client/crud/validator/index.ts b/packages/orm/src/client/crud/validator/index.ts index 88bf0cfd1..cfd78eece 100644 --- a/packages/orm/src/client/crud/validator/index.ts +++ b/packages/orm/src/client/crud/validator/index.ts @@ -1,2254 +1 @@ -import { enumerate, invariant, lowerCaseFirst } from '@zenstackhq/common-helpers'; -import Decimal from 'decimal.js'; -import { match, P } from 'ts-pattern'; -import { z, ZodObject, ZodType } from 'zod'; -import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types'; -import { - type AttributeApplication, - type BuiltinType, - type FieldDef, - type GetModels, - type ProcedureDef, - type SchemaDef, -} from '../../../schema'; -import { extractFields } from '../../../utils/object-utils'; -import { formatError } from '../../../utils/zod-utils'; -import { AggregateOperators, FILTER_PROPERTY_TO_KIND, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../../constants'; -import type { ClientContract } from '../../contract'; -import { - type AggregateArgs, - type CountArgs, - type CreateArgs, - type CreateManyAndReturnArgs, - type CreateManyArgs, - type DeleteArgs, - type DeleteManyArgs, - type ExistsArgs, - type FindArgs, - type GroupByArgs, - type UpdateArgs, - type UpdateManyAndReturnArgs, - type UpdateManyArgs, - type UpsertArgs, -} from '../../crud-types'; -import { createInternalError, createInvalidInputError } from '../../errors'; -import type { AnyPlugin } from '../../plugin'; -import { - fieldHasDefaultValue, - getDiscriminatorField, - getEnum, - getTypeDef, - getUniqueFields, - isEnum, - isTypeDef, - requireField, - requireModel, -} from '../../query-utils'; -import { - CoreCreateOperations, - CoreDeleteOperations, - CoreReadOperations, - CoreUpdateOperations, - type CoreCrudOperations, -} from '../operations/base'; -import { cache } from './cache-decorator'; -import { - addBigIntValidation, - addCustomValidation, - addDecimalValidation, - addListValidation, - addNumberValidation, - addStringValidation, -} from './utils'; - -type GetSchemaFunc = (model: GetModels) => ZodType; - -/** - * Minimal field information needed for filter schema generation. - */ -type FieldInfo = { - name: string; - type: string; - optional?: boolean; - array?: boolean; -}; - -export class InputValidator { - private readonly schemaCache = new Map(); - private readonly allFilterKinds = [...new Set(Object.values(FILTER_PROPERTY_TO_KIND))]; - - constructor(private readonly client: ClientContract) {} - - private get schema() { - return this.client.$schema; - } - - private get options() { - return this.client.$options; - } - - private get extraValidationsEnabled() { - return this.client.$options.validateInput !== false; - } - - // #region Entry points - - validateFindArgs( - model: GetModels, - args: unknown, - operation: CoreCrudOperations, - ): FindArgs, any, true> | undefined { - return this.validate, any, true> | undefined>( - model, - operation, - (model) => this.makeFindSchema(model, operation), - args, - ); - } - - validateExistsArgs( - model: GetModels, - args: unknown, - ): ExistsArgs, any> | undefined { - return this.validate, any> | undefined>( - model, - 'exists', - (model) => this.makeExistsSchema(model), - args, - ); - } - - validateCreateArgs(model: GetModels, args: unknown): CreateArgs, any> { - return this.validate, any>>( - model, - 'create', - (model) => this.makeCreateSchema(model), - args, - ); - } - - validateCreateManyArgs(model: GetModels, args: unknown): CreateManyArgs> { - return this.validate>>( - model, - 'createMany', - (model) => this.makeCreateManySchema(model), - args, - ); - } - - validateCreateManyAndReturnArgs( - model: GetModels, - args: unknown, - ): CreateManyAndReturnArgs, any> | undefined { - return this.validate, any> | undefined>( - model, - 'createManyAndReturn', - (model) => this.makeCreateManyAndReturnSchema(model), - args, - ); - } - - validateUpdateArgs(model: GetModels, args: unknown): UpdateArgs, any> { - return this.validate, any>>( - model, - 'update', - (model) => this.makeUpdateSchema(model), - args, - ); - } - - validateUpdateManyArgs(model: GetModels, args: unknown): UpdateManyArgs, any> { - return this.validate, any>>( - model, - 'updateMany', - (model) => this.makeUpdateManySchema(model), - args, - ); - } - - validateUpdateManyAndReturnArgs( - model: GetModels, - args: unknown, - ): UpdateManyAndReturnArgs, any> { - return this.validate, any>>( - model, - 'updateManyAndReturn', - (model) => this.makeUpdateManyAndReturnSchema(model), - args, - ); - } - - validateUpsertArgs(model: GetModels, args: unknown): UpsertArgs, any> { - return this.validate, any>>( - model, - 'upsert', - (model) => this.makeUpsertSchema(model), - args, - ); - } - - validateDeleteArgs(model: GetModels, args: unknown): DeleteArgs, any> { - return this.validate, any>>( - model, - 'delete', - (model) => this.makeDeleteSchema(model), - args, - ); - } - - validateDeleteManyArgs( - model: GetModels, - args: unknown, - ): DeleteManyArgs, any> | undefined { - return this.validate, any> | undefined>( - model, - 'deleteMany', - (model) => this.makeDeleteManySchema(model), - args, - ); - } - - validateCountArgs(model: GetModels, args: unknown): CountArgs, any> | undefined { - return this.validate, any> | undefined>( - model, - 'count', - (model) => this.makeCountSchema(model), - args, - ); - } - - validateAggregateArgs(model: GetModels, args: unknown): AggregateArgs, any> { - return this.validate, any>>( - model, - 'aggregate', - (model) => this.makeAggregateSchema(model), - args, - ); - } - - validateGroupByArgs(model: GetModels, args: unknown): GroupByArgs, any> { - return this.validate, any>>( - model, - 'groupBy', - (model) => this.makeGroupBySchema(model), - args, - ); - } - - // TODO: turn it into a Zod schema and cache - validateProcedureInput(proc: string, input: unknown): unknown { - const procDef = (this.schema.procedures ?? {})[proc] as ProcedureDef | undefined; - invariant(procDef, `Procedure "${proc}" not found in schema`); - - const params = Object.values(procDef.params ?? {}); - - // For procedures where every parameter is optional, allow omitting the input entirely. - if (typeof input === 'undefined') { - if (params.length === 0) { - return undefined; - } - if (params.every((p) => p.optional)) { - return undefined; - } - throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); - } - - if (typeof input !== 'object' || input === null || Array.isArray(input)) { - throw createInvalidInputError('Procedure input must be an object', `$procs.${proc}`); - } - - const envelope = input as Record; - const argsPayload = Object.prototype.hasOwnProperty.call(envelope, 'args') ? (envelope as any).args : undefined; - - if (params.length === 0) { - if (typeof argsPayload === 'undefined') { - return input; - } - if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { - throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); - } - if (Object.keys(argsPayload as any).length === 0) { - return input; - } - throw createInvalidInputError('Procedure does not accept arguments', `$procs.${proc}`); - } - - if (typeof argsPayload === 'undefined') { - if (params.every((p) => p.optional)) { - return input; - } - throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); - } - - if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { - throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); - } - - const obj = argsPayload as Record; - - for (const param of params) { - const value = (obj as any)[param.name]; - - if (!Object.prototype.hasOwnProperty.call(obj, param.name)) { - if (param.optional) { - continue; - } - throw createInvalidInputError(`Missing procedure argument: ${param.name}`, `$procs.${proc}`); - } - - if (typeof value === 'undefined') { - if (param.optional) { - continue; - } - throw createInvalidInputError( - `Invalid procedure argument: ${param.name} is required`, - `$procs.${proc}`, - ); - } - - const schema = this.makeProcedureParamSchema(param); - const parsed = schema.safeParse(value); - if (!parsed.success) { - throw createInvalidInputError( - `Invalid procedure argument: ${param.name}: ${formatError(parsed.error)}`, - `$procs.${proc}`, - ); - } - } - - return input; - } - - // #endregion - - // #region Validation helpers - - private validate(model: GetModels, operation: string, getSchema: GetSchemaFunc, args: unknown) { - const schema = getSchema(model); - const { error, data } = schema.safeParse(args); - if (error) { - throw createInvalidInputError( - `Invalid ${operation} args for model "${model}": ${formatError(error)}`, - model, - { - cause: error, - }, - ); - } - return data as T; - } - - private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) { - let result = schema; - for (const plugin of this.options.plugins ?? []) { - if (plugin.queryArgs) { - const pluginSchema = this.getPluginExtQueryArgsSchema(plugin, operation); - if (pluginSchema) { - result = result.extend(pluginSchema.shape); - } - } - } - return result.strict(); - } - - private getPluginExtQueryArgsSchema(plugin: AnyPlugin, operation: string): ZodObject | undefined { - if (!plugin.queryArgs) { - return undefined; - } - - let result: ZodType | undefined; - - if (operation in plugin.queryArgs && plugin.queryArgs[operation]) { - // most specific operation takes highest precedence - result = plugin.queryArgs[operation]; - } else if (operation === 'upsert') { - // upsert is special: it's in both CoreCreateOperations and CoreUpdateOperations - // so we need to merge both $create and $update schemas to match the type system - const createSchema = - '$create' in plugin.queryArgs && plugin.queryArgs['$create'] ? plugin.queryArgs['$create'] : undefined; - const updateSchema = - '$update' in plugin.queryArgs && plugin.queryArgs['$update'] ? plugin.queryArgs['$update'] : undefined; - - if (createSchema && updateSchema) { - invariant( - createSchema instanceof z.ZodObject, - 'Plugin extended query args schema must be a Zod object', - ); - invariant( - updateSchema instanceof z.ZodObject, - 'Plugin extended query args schema must be a Zod object', - ); - // merge both schemas (combines their properties) - result = createSchema.extend(updateSchema.shape); - } else if (createSchema) { - result = createSchema; - } else if (updateSchema) { - result = updateSchema; - } - } else if ( - // then comes grouped operations: $create, $read, $update, $delete - CoreCreateOperations.includes(operation as CoreCreateOperations) && - '$create' in plugin.queryArgs && - plugin.queryArgs['$create'] - ) { - result = plugin.queryArgs['$create']; - } else if ( - CoreReadOperations.includes(operation as CoreReadOperations) && - '$read' in plugin.queryArgs && - plugin.queryArgs['$read'] - ) { - result = plugin.queryArgs['$read']; - } else if ( - CoreUpdateOperations.includes(operation as CoreUpdateOperations) && - '$update' in plugin.queryArgs && - plugin.queryArgs['$update'] - ) { - result = plugin.queryArgs['$update']; - } else if ( - CoreDeleteOperations.includes(operation as CoreDeleteOperations) && - '$delete' in plugin.queryArgs && - plugin.queryArgs['$delete'] - ) { - result = plugin.queryArgs['$delete']; - } else if ('$all' in plugin.queryArgs && plugin.queryArgs['$all']) { - // finally comes $all - result = plugin.queryArgs['$all']; - } - - invariant( - result === undefined || result instanceof z.ZodObject, - 'Plugin extended query args schema must be a Zod object', - ); - return result; - } - - // #endregion - - // #region Find - - @cache() - private makeFindSchema(model: string, operation: CoreCrudOperations) { - const fields: Record = {}; - const unique = operation === 'findUnique'; - const findOne = operation === 'findUnique' || operation === 'findFirst'; - const where = this.makeWhereSchema(model, unique); - if (unique) { - fields['where'] = where; - } else { - fields['where'] = where.optional(); - } - - fields['select'] = this.makeSelectSchema(model).optional().nullable(); - fields['include'] = this.makeIncludeSchema(model).optional().nullable(); - fields['omit'] = this.makeOmitSchema(model).optional().nullable(); - - if (!unique) { - fields['skip'] = this.makeSkipSchema().optional(); - if (findOne) { - fields['take'] = z.literal(1).optional(); - } else { - fields['take'] = this.makeTakeSchema().optional(); - } - fields['orderBy'] = this.orArray(this.makeOrderBySchema(model, true, false), true).optional(); - fields['cursor'] = this.makeCursorSchema(model).optional(); - fields['distinct'] = this.makeDistinctSchema(model).optional(); - } - - const baseSchema = z.strictObject(fields); - let result: ZodType = this.mergePluginArgsSchema(baseSchema, operation); - result = this.refineForSelectIncludeMutuallyExclusive(result); - result = this.refineForSelectOmitMutuallyExclusive(result); - - if (!unique) { - result = result.optional(); - } - return result; - } - - @cache() - private makeExistsSchema(model: string) { - const baseSchema = z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - }); - return this.mergePluginArgsSchema(baseSchema, 'exists').optional(); - } - - private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) { - if (this.schema.typeDefs && type in this.schema.typeDefs) { - return this.makeTypeDefSchema(type); - } else if (this.schema.enums && type in this.schema.enums) { - return this.makeEnumSchema(type); - } else { - return match(type) - .with('String', () => - this.extraValidationsEnabled ? addStringValidation(z.string(), attributes) : z.string(), - ) - .with('Int', () => - this.extraValidationsEnabled ? addNumberValidation(z.number().int(), attributes) : z.number().int(), - ) - .with('Float', () => - this.extraValidationsEnabled ? addNumberValidation(z.number(), attributes) : z.number(), - ) - .with('Boolean', () => z.boolean()) - .with('BigInt', () => - z.union([ - this.extraValidationsEnabled - ? addNumberValidation(z.number().int(), attributes) - : z.number().int(), - this.extraValidationsEnabled ? addBigIntValidation(z.bigint(), attributes) : z.bigint(), - ]), - ) - .with('Decimal', () => { - return z.union([ - this.extraValidationsEnabled ? addNumberValidation(z.number(), attributes) : z.number(), - addDecimalValidation(z.instanceof(Decimal), attributes, this.extraValidationsEnabled), - addDecimalValidation(z.string(), attributes, this.extraValidationsEnabled), - ]); - }) - .with('DateTime', () => z.union([z.date(), z.iso.datetime()])) - .with('Bytes', () => z.instanceof(Uint8Array)) - .with('Json', () => this.makeJsonValueSchema(false, false)) - .otherwise(() => z.unknown()); - } - } - - @cache() - private makeEnumSchema(type: string) { - const enumDef = getEnum(this.schema, type); - invariant(enumDef, `Enum "${type}" not found in schema`); - return z.enum(Object.keys(enumDef.values) as [string, ...string[]]); - } - - @cache() - private makeTypeDefSchema(type: string): z.ZodType { - const typeDef = getTypeDef(this.schema, type); - invariant(typeDef, `Type definition "${type}" not found in schema`); - const schema = z.looseObject( - Object.fromEntries( - Object.entries(typeDef.fields).map(([field, def]) => { - let fieldSchema = this.makeScalarSchema(def.type); - if (def.array) { - fieldSchema = fieldSchema.array(); - } - if (def.optional) { - fieldSchema = fieldSchema.nullish(); - } - return [field, fieldSchema]; - }), - ), - ); - - // zod doesn't preserve object field order after parsing, here we use a - // validation-only custom schema and use the original data if parsing - // is successful - const finalSchema = z.any().superRefine((value, ctx) => { - const parseResult = schema.safeParse(value); - if (!parseResult.success) { - parseResult.error.issues.forEach((issue) => ctx.addIssue(issue as any)); - } - }); - - return finalSchema; - } - - @cache() - private makeWhereSchema( - model: string, - unique: boolean, - withoutRelationFields = false, - withAggregations = false, - ): ZodType { - const modelDef = requireModel(this.schema, model); - - // unique field used in unique filters bypass filter slicing - const uniqueFieldNames = unique - ? getUniqueFields(this.schema, model) - .filter( - (uf): uf is { name: string; def: FieldDef } => - // single-field unique - 'def' in uf, - ) - .map((uf) => uf.name) - : undefined; - - const fields: Record = {}; - for (const field of Object.keys(modelDef.fields)) { - const fieldDef = requireField(this.schema, model, field); - let fieldSchema: ZodType | undefined; - - if (fieldDef.relation) { - if (withoutRelationFields) { - continue; - } - - // Check if Relation filter kind is allowed - const allowedFilterKinds = this.getEffectiveFilterKinds(model, field); - if (allowedFilterKinds && !allowedFilterKinds.includes('Relation')) { - // Relation filters are not allowed for this field - use z.never() - fieldSchema = z.never(); - } else { - fieldSchema = z.lazy(() => this.makeWhereSchema(fieldDef.type, false).optional()); - - // optional to-one relation allows null - fieldSchema = this.nullableIf(fieldSchema, !fieldDef.array && !!fieldDef.optional); - - if (fieldDef.array) { - // to-many relation - fieldSchema = z.union([ - fieldSchema, - z.strictObject({ - some: fieldSchema.optional(), - every: fieldSchema.optional(), - none: fieldSchema.optional(), - }), - ]); - } else { - // to-one relation - fieldSchema = z.union([ - fieldSchema, - z.strictObject({ - is: fieldSchema.optional(), - isNot: fieldSchema.optional(), - }), - ]); - } - } - } else { - const ignoreSlicing = !!uniqueFieldNames?.includes(field); - - const enumDef = getEnum(this.schema, fieldDef.type); - if (enumDef) { - // enum - if (Object.keys(enumDef.values).length > 0) { - fieldSchema = this.makeEnumFilterSchema(model, fieldDef, withAggregations, ignoreSlicing); - } - } else if (fieldDef.array) { - // array field - fieldSchema = this.makeArrayFilterSchema(model, fieldDef); - } else if (this.isTypeDefType(fieldDef.type)) { - fieldSchema = this.makeTypedJsonFilterSchema(model, fieldDef); - } else { - // primitive field - fieldSchema = this.makePrimitiveFilterSchema(model, fieldDef, withAggregations, ignoreSlicing); - } - } - - if (fieldSchema) { - fields[field] = fieldSchema.optional(); - } - } - - if (unique) { - // add compound unique fields, e.g. `{ id1_id2: { id1: 1, id2: 1 } }` - // compound-field filters are not affected by slicing - const uniqueFields = getUniqueFields(this.schema, model); - for (const uniqueField of uniqueFields) { - if ('defs' in uniqueField) { - fields[uniqueField.name] = z - .object( - Object.fromEntries( - Object.entries(uniqueField.defs).map(([key, def]) => { - invariant(!def.relation, 'unique field cannot be a relation'); - let fieldSchema: ZodType; - const enumDef = getEnum(this.schema, def.type); - if (enumDef) { - // enum - if (Object.keys(enumDef.values).length > 0) { - fieldSchema = this.makeEnumFilterSchema(model, def, false, true); - } else { - fieldSchema = z.never(); - } - } else { - fieldSchema = this.makePrimitiveFilterSchema(model, def, false, true); - } - return [key, fieldSchema]; - }), - ), - ) - .optional(); - } - } - } - - // expression builder - fields['$expr'] = z.custom((v) => typeof v === 'function', { error: '"$expr" must be a function' }).optional(); - - // logical operators - fields['AND'] = this.orArray( - z.lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)), - true, - ).optional(); - fields['OR'] = z - .lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)) - .array() - .optional(); - fields['NOT'] = this.orArray( - z.lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)), - true, - ).optional(); - - const baseWhere = z.strictObject(fields); - let result: ZodType = baseWhere; - - if (unique) { - // requires at least one unique field (field set) is required - const uniqueFields = getUniqueFields(this.schema, model); - if (uniqueFields.length === 0) { - throw createInternalError(`Model "${model}" has no unique fields`); - } - - if (uniqueFields.length === 1) { - // only one unique field (set), mark the field(s) required - result = baseWhere.required({ - [uniqueFields[0]!.name]: true, - } as any); - } else { - result = baseWhere.refine((value) => { - // check that at least one unique field is set - return uniqueFields.some(({ name }) => value[name] !== undefined); - }, `At least one unique field or field set must be set`); - } - } - - return result; - } - - @cache() - private makeTypedJsonFilterSchema(contextModel: string | undefined, fieldInfo: FieldInfo) { - const field = fieldInfo.name; - const type = fieldInfo.type; - const optional = !!fieldInfo.optional; - const array = !!fieldInfo.array; - - const typeDef = getTypeDef(this.schema, type); - invariant(typeDef, `Type definition "${type}" not found in schema`); - - const candidates: z.ZodType[] = []; - - if (!array) { - // fields filter - const fieldSchemas: Record = {}; - for (const [fieldName, fieldDef] of Object.entries(typeDef.fields)) { - if (this.isTypeDefType(fieldDef.type)) { - // recursive typed JSON - use same model/field for nested typed JSON - fieldSchemas[fieldName] = this.makeTypedJsonFilterSchema(contextModel, fieldDef).optional(); - } else { - // enum, array, primitives - const enumDef = getEnum(this.schema, fieldDef.type); - if (enumDef) { - fieldSchemas[fieldName] = this.makeEnumFilterSchema(contextModel, fieldDef, false).optional(); - } else if (fieldDef.array) { - fieldSchemas[fieldName] = this.makeArrayFilterSchema(contextModel, fieldDef).optional(); - } else { - fieldSchemas[fieldName] = this.makePrimitiveFilterSchema( - contextModel, - fieldDef, - false, - ).optional(); - } - } - } - - candidates.push(z.strictObject(fieldSchemas)); - } - - const recursiveSchema = z - .lazy(() => this.makeTypedJsonFilterSchema(contextModel, { name: field, type, optional, array: false })) - .optional(); - if (array) { - // array filter - candidates.push( - z.strictObject({ - some: recursiveSchema, - every: recursiveSchema, - none: recursiveSchema, - }), - ); - } else { - // is / isNot filter - candidates.push( - z.strictObject({ - is: recursiveSchema, - isNot: recursiveSchema, - }), - ); - } - - // plain json filter - candidates.push(this.makeJsonFilterSchema(contextModel, field, optional)); - - if (optional) { - // allow null as well - candidates.push(z.null()); - } - - // either plain json filter or field filters - return z.union(candidates); - } - - private isTypeDefType(type: string) { - return this.schema.typeDefs && type in this.schema.typeDefs; - } - - @cache() - private makeEnumFilterSchema( - model: string | undefined, - fieldInfo: FieldInfo, - withAggregations: boolean, - ignoreSlicing: boolean = false, - ) { - const enumName = fieldInfo.type; - const optional = !!fieldInfo.optional; - const array = !!fieldInfo.array; - - const enumDef = getEnum(this.schema, enumName); - invariant(enumDef, `Enum "${enumName}" not found in schema`); - const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); - if (array) { - return this.internalMakeArrayFilterSchema(model, fieldInfo.name, baseSchema); - } - const allowedFilterKinds = ignoreSlicing ? undefined : this.getEffectiveFilterKinds(model, fieldInfo.name); - const components = this.makeCommonPrimitiveFilterComponents( - baseSchema, - optional, - () => z.lazy(() => this.makeEnumFilterSchema(model, fieldInfo, withAggregations)), - ['equals', 'in', 'notIn', 'not'], - withAggregations ? ['_count', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - - return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); - } - - @cache() - private makeArrayFilterSchema(model: string | undefined, fieldInfo: FieldInfo) { - return this.internalMakeArrayFilterSchema( - model, - fieldInfo.name, - this.makeScalarSchema(fieldInfo.type as BuiltinType), - ); - } - - private internalMakeArrayFilterSchema(contextModel: string | undefined, field: string, elementSchema: ZodType) { - const allowedFilterKinds = this.getEffectiveFilterKinds(contextModel, field); - const operators = { - equals: elementSchema.array().optional(), - has: elementSchema.optional(), - hasEvery: elementSchema.array().optional(), - hasSome: elementSchema.array().optional(), - isEmpty: z.boolean().optional(), - }; - - // Filter operators based on allowed filter kinds - const filteredOperators = this.trimFilterOperators(operators, allowedFilterKinds); - - return z.strictObject(filteredOperators); - } - - @cache() - private makePrimitiveFilterSchema( - contextModel: string | undefined, - fieldInfo: FieldInfo, - withAggregations: boolean, - ignoreSlicing = false, - ) { - const allowedFilterKinds = ignoreSlicing - ? undefined - : this.getEffectiveFilterKinds(contextModel, fieldInfo.name); - const type = fieldInfo.type as BuiltinType; - const optional = !!fieldInfo.optional; - return match(type) - .with('String', () => this.makeStringFilterSchema(optional, withAggregations, allowedFilterKinds)) - .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => - this.makeNumberFilterSchema( - this.makeScalarSchema(type), - optional, - withAggregations, - allowedFilterKinds, - ), - ) - .with('Boolean', () => this.makeBooleanFilterSchema(optional, withAggregations, allowedFilterKinds)) - .with('DateTime', () => this.makeDateTimeFilterSchema(optional, withAggregations, allowedFilterKinds)) - .with('Bytes', () => this.makeBytesFilterSchema(optional, withAggregations, allowedFilterKinds)) - .with('Json', () => this.makeJsonFilterSchema(contextModel, fieldInfo.name, optional)) - .with('Unsupported', () => z.never()) - .exhaustive(); - } - - private makeJsonValueSchema(nullable: boolean, forFilter: boolean): z.ZodType { - const options: z.ZodType[] = [z.string(), z.number(), z.boolean(), z.instanceof(JsonNullClass)]; - - if (forFilter) { - options.push(z.instanceof(DbNullClass)); - } else { - if (nullable) { - // for mutation, allow DbNull only if nullable - options.push(z.instanceof(DbNullClass)); - } - } - - if (forFilter) { - options.push(z.instanceof(AnyNullClass)); - } - - const schema = z.union([ - ...options, - z.lazy(() => z.union([this.makeJsonValueSchema(false, false), z.null()]).array()), - z.record( - z.string(), - z.lazy(() => z.union([this.makeJsonValueSchema(false, false), z.null()])), - ), - ]); - return this.nullableIf(schema, nullable); - } - - @cache() - private makeJsonFilterSchema(contextModel: string | undefined, field: string, optional: boolean) { - const allowedFilterKinds = this.getEffectiveFilterKinds(contextModel, field); - - // Check if Json filter kind is allowed - if (allowedFilterKinds && !allowedFilterKinds.includes('Json')) { - // Return a never schema if Json filters are not allowed - return z.never(); - } - - const valueSchema = this.makeJsonValueSchema(optional, true); - return z.strictObject({ - path: z.string().optional(), - equals: valueSchema.optional(), - not: valueSchema.optional(), - string_contains: z.string().optional(), - string_starts_with: z.string().optional(), - string_ends_with: z.string().optional(), - mode: this.makeStringModeSchema().optional(), - array_contains: valueSchema.optional(), - array_starts_with: valueSchema.optional(), - array_ends_with: valueSchema.optional(), - }); - } - - @cache() - private makeDateTimeFilterSchema( - optional: boolean, - withAggregations: boolean, - allowedFilterKinds: string[] | undefined, - ): ZodType { - return this.makeCommonPrimitiveFilterSchema( - z.union([z.iso.datetime(), z.date()]), - optional, - () => z.lazy(() => this.makeDateTimeFilterSchema(optional, withAggregations, allowedFilterKinds)), - withAggregations ? ['_count', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - } - - @cache() - private makeBooleanFilterSchema( - optional: boolean, - withAggregations: boolean, - allowedFilterKinds: string[] | undefined, - ): ZodType { - const components = this.makeCommonPrimitiveFilterComponents( - z.boolean(), - optional, - () => z.lazy(() => this.makeBooleanFilterSchema(optional, withAggregations, allowedFilterKinds)), - ['equals', 'not'], - withAggregations ? ['_count', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - - return this.createUnionFilterSchema(z.boolean(), optional, components, allowedFilterKinds); - } - - @cache() - private makeBytesFilterSchema( - optional: boolean, - withAggregations: boolean, - allowedFilterKinds: string[] | undefined, - ): ZodType { - const baseSchema = z.instanceof(Uint8Array); - const components = this.makeCommonPrimitiveFilterComponents( - baseSchema, - optional, - () => z.instanceof(Uint8Array), - ['equals', 'in', 'notIn', 'not'], - withAggregations ? ['_count', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - - return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); - } - - private makeCommonPrimitiveFilterComponents( - baseSchema: ZodType, - optional: boolean, - makeThis: () => ZodType, - supportedOperators: string[] | undefined = undefined, - withAggregations: Array<'_count' | '_avg' | '_sum' | '_min' | '_max'> | undefined = undefined, - allowedFilterKinds: string[] | undefined = undefined, - ) { - const commonAggSchema = () => - this.makeCommonPrimitiveFilterSchema(baseSchema, false, makeThis, undefined, allowedFilterKinds).optional(); - let result = { - equals: this.nullableIf(baseSchema.optional(), optional), - in: baseSchema.array().optional(), - notIn: baseSchema.array().optional(), - lt: baseSchema.optional(), - lte: baseSchema.optional(), - gt: baseSchema.optional(), - gte: baseSchema.optional(), - between: baseSchema.array().length(2).optional(), - not: makeThis().optional(), - ...(withAggregations?.includes('_count') - ? { _count: this.makeNumberFilterSchema(z.number().int(), false, false, undefined).optional() } - : {}), - ...(withAggregations?.includes('_avg') ? { _avg: commonAggSchema() } : {}), - ...(withAggregations?.includes('_sum') ? { _sum: commonAggSchema() } : {}), - ...(withAggregations?.includes('_min') ? { _min: commonAggSchema() } : {}), - ...(withAggregations?.includes('_max') ? { _max: commonAggSchema() } : {}), - }; - if (supportedOperators) { - const keys = [...supportedOperators, ...(withAggregations ?? [])]; - result = extractFields(result, keys) as typeof result; - } - - // Filter operators based on allowed filter kinds - result = this.trimFilterOperators(result, allowedFilterKinds) as typeof result; - - return result; - } - - private makeCommonPrimitiveFilterSchema( - baseSchema: ZodType, - optional: boolean, - makeThis: () => ZodType, - withAggregations: Array | undefined = undefined, - allowedFilterKinds: string[] | undefined = undefined, - ): z.ZodType { - const components = this.makeCommonPrimitiveFilterComponents( - baseSchema, - optional, - makeThis, - undefined, - withAggregations, - allowedFilterKinds, - ); - - return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); - } - - private makeNumberFilterSchema( - baseSchema: ZodType, - optional: boolean, - withAggregations: boolean, - allowedFilterKinds: string[] | undefined, - ): ZodType { - return this.makeCommonPrimitiveFilterSchema( - baseSchema, - optional, - () => z.lazy(() => this.makeNumberFilterSchema(baseSchema, optional, withAggregations, allowedFilterKinds)), - withAggregations ? ['_count', '_avg', '_sum', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - } - - private makeStringFilterSchema( - optional: boolean, - withAggregations: boolean, - allowedFilterKinds: string[] | undefined, - ): ZodType { - const baseComponents = this.makeCommonPrimitiveFilterComponents( - z.string(), - optional, - () => z.lazy(() => this.makeStringFilterSchema(optional, withAggregations, allowedFilterKinds)), - undefined, - withAggregations ? ['_count', '_min', '_max'] : undefined, - allowedFilterKinds, - ); - - const stringSpecificOperators = { - startsWith: z.string().optional(), - endsWith: z.string().optional(), - contains: z.string().optional(), - ...(this.providerSupportsCaseSensitivity - ? { - mode: this.makeStringModeSchema().optional(), - } - : {}), - }; - - // Filter string-specific operators based on allowed filter kinds - const filteredStringOperators = this.trimFilterOperators(stringSpecificOperators, allowedFilterKinds); - - const allComponents = { - ...baseComponents, - ...filteredStringOperators, - }; - - return this.createUnionFilterSchema(z.string(), optional, allComponents, allowedFilterKinds); - } - - private makeStringModeSchema() { - return z.union([z.literal('default'), z.literal('insensitive')]); - } - - @cache() - private makeSelectSchema(model: string) { - const modelDef = requireModel(this.schema, model); - const fields: Record = {}; - for (const field of Object.keys(modelDef.fields)) { - const fieldDef = requireField(this.schema, model, field); - if (fieldDef.relation) { - // Check if the target model is allowed by slicing configuration - if (this.isModelAllowed(fieldDef.type)) { - fields[field] = this.makeRelationSelectIncludeSchema(model, field).optional(); - } - } else { - fields[field] = z.boolean().optional(); - } - } - - const _countSchema = this.makeCountSelectionSchema(model); - if (!(_countSchema instanceof z.ZodNever)) { - fields['_count'] = _countSchema; - } - - return z.strictObject(fields); - } - - @cache() - private makeCountSelectionSchema(model: string) { - const modelDef = requireModel(this.schema, model); - const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); - if (toManyRelations.length > 0) { - return z - .union([ - z.literal(true), - z.strictObject({ - select: z.strictObject( - toManyRelations.reduce( - (acc, fieldDef) => ({ - ...acc, - [fieldDef.name]: z - .union([ - z.boolean(), - z.strictObject({ - where: this.makeWhereSchema(fieldDef.type, false, false), - }), - ]) - .optional(), - }), - {} as Record, - ), - ), - }), - ]) - .optional(); - } else { - return z.never(); - } - } - - @cache() - private makeRelationSelectIncludeSchema(model: string, field: string) { - const fieldDef = requireField(this.schema, model, field); - let objSchema: z.ZodType = z.strictObject({ - ...(fieldDef.array || fieldDef.optional - ? { - // to-many relations and optional to-one relations are filterable - where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(), - } - : {}), - select: z - .lazy(() => this.makeSelectSchema(fieldDef.type)) - .optional() - .nullable(), - include: z - .lazy(() => this.makeIncludeSchema(fieldDef.type)) - .optional() - .nullable(), - omit: z - .lazy(() => this.makeOmitSchema(fieldDef.type)) - .optional() - .nullable(), - ...(fieldDef.array - ? { - // to-many relations can be ordered, skipped, taken, and cursor-located - orderBy: z - .lazy(() => this.orArray(this.makeOrderBySchema(fieldDef.type, true, false), true)) - .optional(), - skip: this.makeSkipSchema().optional(), - take: this.makeTakeSchema().optional(), - cursor: this.makeCursorSchema(fieldDef.type).optional(), - distinct: this.makeDistinctSchema(fieldDef.type).optional(), - } - : {}), - }); - - objSchema = this.refineForSelectIncludeMutuallyExclusive(objSchema); - objSchema = this.refineForSelectOmitMutuallyExclusive(objSchema); - - return z.union([z.boolean(), objSchema]); - } - - @cache() - private makeOmitSchema(model: string) { - const modelDef = requireModel(this.schema, model); - const fields: Record = {}; - for (const field of Object.keys(modelDef.fields)) { - const fieldDef = requireField(this.schema, model, field); - if (!fieldDef.relation) { - if (this.options.allowQueryTimeOmitOverride !== false) { - // if override is allowed, use boolean - fields[field] = z.boolean().optional(); - } else { - // otherwise only allow true - fields[field] = z.literal(true).optional(); - } - } - } - return z.strictObject(fields); - } - - @cache() - private makeIncludeSchema(model: string) { - const modelDef = requireModel(this.schema, model); - const fields: Record = {}; - for (const field of Object.keys(modelDef.fields)) { - const fieldDef = requireField(this.schema, model, field); - if (fieldDef.relation) { - // Check if the target model is allowed by slicing configuration - if (this.isModelAllowed(fieldDef.type)) { - fields[field] = this.makeRelationSelectIncludeSchema(model, field).optional(); - } - } - } - - const _countSchema = this.makeCountSelectionSchema(model); - if (!(_countSchema instanceof z.ZodNever)) { - fields['_count'] = _countSchema; - } - - return z.strictObject(fields); - } - - @cache() - private makeOrderBySchema(model: string, withRelation: boolean, WithAggregation: boolean) { - const modelDef = requireModel(this.schema, model); - const fields: Record = {}; - const sort = z.union([z.literal('asc'), z.literal('desc')]); - for (const field of Object.keys(modelDef.fields)) { - const fieldDef = requireField(this.schema, model, field); - if (fieldDef.relation) { - // relations - if (withRelation) { - fields[field] = z.lazy(() => { - let relationOrderBy = this.makeOrderBySchema(fieldDef.type, withRelation, WithAggregation); - if (fieldDef.array) { - relationOrderBy = relationOrderBy.extend({ - _count: sort, - }); - } - return relationOrderBy.optional(); - }); - } - } else { - // scalars - if (fieldDef.optional) { - fields[field] = z - .union([ - sort, - z.strictObject({ - sort, - nulls: z.union([z.literal('first'), z.literal('last')]), - }), - ]) - .optional(); - } else { - fields[field] = sort.optional(); - } - } - } - - // aggregations - if (WithAggregation) { - const aggregationFields = ['_count', '_avg', '_sum', '_min', '_max']; - for (const agg of aggregationFields) { - fields[agg] = z.lazy(() => this.makeOrderBySchema(model, true, false).optional()); - } - } - - return z.strictObject(fields); - } - - @cache() - private makeDistinctSchema(model: string) { - const modelDef = requireModel(this.schema, model); - const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); - return this.orArray(z.enum(nonRelationFields as any), true); - } - - private makeCursorSchema(model: string) { - // `makeWhereSchema` is already cached - return this.makeWhereSchema(model, true, true).optional(); - } - - // #endregion - - // #region Create - - @cache() - private makeCreateSchema(model: string) { - const dataSchema = this.makeCreateDataSchema(model, false); - const baseSchema = z.strictObject({ - data: dataSchema, - select: this.makeSelectSchema(model).optional().nullable(), - include: this.makeIncludeSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'create'); - schema = this.refineForSelectIncludeMutuallyExclusive(schema); - schema = this.refineForSelectOmitMutuallyExclusive(schema); - return schema; - } - - @cache() - private makeCreateManySchema(model: string) { - return this.mergePluginArgsSchema(this.makeCreateManyDataSchema(model, []), 'createMany').optional(); - } - - @cache() - private makeCreateManyAndReturnSchema(model: string) { - const base = this.makeCreateManyDataSchema(model, []); - let result: ZodObject = base.extend({ - select: this.makeSelectSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - result = this.mergePluginArgsSchema(result, 'createManyAndReturn'); - return this.refineForSelectOmitMutuallyExclusive(result).optional(); - } - - @cache() - private makeCreateDataSchema( - model: string, - canBeArray: boolean, - withoutFields: string[] = [], - withoutRelationFields = false, - ) { - const uncheckedVariantFields: Record = {}; - const checkedVariantFields: Record = {}; - const modelDef = requireModel(this.schema, model); - const hasRelation = - !withoutRelationFields && - Object.entries(modelDef.fields).some(([f, def]) => !withoutFields.includes(f) && def.relation); - - Object.keys(modelDef.fields).forEach((field) => { - if (withoutFields.includes(field)) { - return; - } - const fieldDef = requireField(this.schema, model, field); - if (fieldDef.computed) { - return; - } - - if (this.isDelegateDiscriminator(fieldDef)) { - // discriminator field is auto-assigned - return; - } - - if (fieldDef.relation) { - if (withoutRelationFields) { - return; - } - // Check if the target model is allowed by slicing configuration - if (!this.isModelAllowed(fieldDef.type)) { - return; - } - const excludeFields: string[] = []; - const oppositeField = fieldDef.relation.opposite; - if (oppositeField) { - excludeFields.push(oppositeField); - const oppositeFieldDef = requireField(this.schema, fieldDef.type, oppositeField); - if (oppositeFieldDef.relation?.fields) { - excludeFields.push(...oppositeFieldDef.relation.fields); - } - } - - let fieldSchema: ZodType = z.lazy(() => - this.makeRelationManipulationSchema(model, field, excludeFields, 'create'), - ); - - if (fieldDef.optional || fieldDef.array) { - // optional or array relations are optional - fieldSchema = fieldSchema.optional(); - } else { - // if all fk fields are optional, the relation is optional - let allFksOptional = false; - if (fieldDef.relation.fields) { - allFksOptional = fieldDef.relation.fields.every((f) => { - const fkDef = requireField(this.schema, model, f); - return fkDef.optional || fieldHasDefaultValue(fkDef); - }); - } - if (allFksOptional) { - fieldSchema = fieldSchema.optional(); - } - } - - // optional to-one relation can be null - if (fieldDef.optional && !fieldDef.array) { - fieldSchema = fieldSchema.nullable(); - } - checkedVariantFields[field] = fieldSchema; - if (fieldDef.array || !fieldDef.relation.references) { - // non-owned relation - uncheckedVariantFields[field] = fieldSchema; - } - } else { - let fieldSchema = this.makeScalarSchema(fieldDef.type, fieldDef.attributes); - - if (fieldDef.array) { - fieldSchema = addListValidation(fieldSchema.array(), fieldDef.attributes); - fieldSchema = z - .union([ - fieldSchema, - z.strictObject({ - set: fieldSchema, - }), - ]) - .optional(); - } - - if (fieldDef.optional || fieldHasDefaultValue(fieldDef)) { - fieldSchema = fieldSchema.optional(); - } - - if (fieldDef.optional) { - if (fieldDef.type === 'Json') { - // DbNull for Json fields - fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]); - } else { - fieldSchema = fieldSchema.nullable(); - } - } - - uncheckedVariantFields[field] = fieldSchema; - if (!fieldDef.foreignKeyFor) { - // non-fk field - checkedVariantFields[field] = fieldSchema; - } - } - }); - - const uncheckedCreateSchema = this.extraValidationsEnabled - ? addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes) - : z.strictObject(uncheckedVariantFields); - const checkedCreateSchema = this.extraValidationsEnabled - ? addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes) - : z.strictObject(checkedVariantFields); - - if (!hasRelation) { - return this.orArray(uncheckedCreateSchema, canBeArray); - } else { - return z.union([ - uncheckedCreateSchema, - checkedCreateSchema, - ...(canBeArray ? [z.array(uncheckedCreateSchema)] : []), - ...(canBeArray ? [z.array(checkedCreateSchema)] : []), - ]); - } - } - - private isDelegateDiscriminator(fieldDef: FieldDef) { - if (!fieldDef.originModel) { - // not inherited from a delegate - return false; - } - const discriminatorField = getDiscriminatorField(this.schema, fieldDef.originModel); - return discriminatorField === fieldDef.name; - } - - @cache() - private makeRelationManipulationSchema( - model: string, - field: string, - withoutFields: string[], - mode: 'create' | 'update', - ) { - const fieldDef = requireField(this.schema, model, field); - const fieldType = fieldDef.type; - const array = !!fieldDef.array; - const fields: Record = { - create: this.makeCreateDataSchema(fieldDef.type, !!fieldDef.array, withoutFields).optional(), - - connect: this.makeConnectDataSchema(fieldType, array).optional(), - - connectOrCreate: this.makeConnectOrCreateDataSchema(fieldType, array, withoutFields).optional(), - }; - - if (array) { - fields['createMany'] = this.makeCreateManyDataSchema(fieldType, withoutFields).optional(); - } - - if (mode === 'update') { - if (fieldDef.optional || fieldDef.array) { - // disconnect and delete are only available for optional/to-many relations - fields['disconnect'] = this.makeDisconnectDataSchema(fieldType, array).optional(); - - fields['delete'] = this.makeDeleteRelationDataSchema(fieldType, array, true).optional(); - } - - fields['update'] = array - ? this.orArray( - z.strictObject({ - where: this.makeWhereSchema(fieldType, true), - data: this.makeUpdateDataSchema(fieldType, withoutFields), - }), - true, - ).optional() - : z - .union([ - z.strictObject({ - where: this.makeWhereSchema(fieldType, false).optional(), - data: this.makeUpdateDataSchema(fieldType, withoutFields), - }), - this.makeUpdateDataSchema(fieldType, withoutFields), - ]) - .optional(); - - let upsertWhere = this.makeWhereSchema(fieldType, true); - if (!fieldDef.array) { - // to-one relation, can upsert without where clause - upsertWhere = upsertWhere.optional(); - } - fields['upsert'] = this.orArray( - z.strictObject({ - where: upsertWhere, - create: this.makeCreateDataSchema(fieldType, false, withoutFields), - update: this.makeUpdateDataSchema(fieldType, withoutFields), - }), - true, - ).optional(); - - if (array) { - // to-many relation specifics - fields['set'] = this.makeSetDataSchema(fieldType, true).optional(); - - fields['updateMany'] = this.orArray( - z.strictObject({ - where: this.makeWhereSchema(fieldType, false, true), - data: this.makeUpdateDataSchema(fieldType, withoutFields), - }), - true, - ).optional(); - - fields['deleteMany'] = this.makeDeleteRelationDataSchema(fieldType, true, false).optional(); - } - } - - return z.strictObject(fields); - } - - @cache() - private makeSetDataSchema(model: string, canBeArray: boolean) { - return this.orArray(this.makeWhereSchema(model, true), canBeArray); - } - - @cache() - private makeConnectDataSchema(model: string, canBeArray: boolean) { - return this.orArray(this.makeWhereSchema(model, true), canBeArray); - } - - @cache() - private makeDisconnectDataSchema(model: string, canBeArray: boolean) { - if (canBeArray) { - // to-many relation, must be unique filters - return this.orArray(this.makeWhereSchema(model, true), canBeArray); - } else { - // to-one relation, can be boolean or a regular filter - the entity - // being disconnected is already uniquely identified by its parent - return z.union([z.boolean(), this.makeWhereSchema(model, false)]); - } - } - - @cache() - private makeDeleteRelationDataSchema(model: string, toManyRelation: boolean, uniqueFilter: boolean) { - return toManyRelation - ? this.orArray(this.makeWhereSchema(model, uniqueFilter), true) - : z.union([z.boolean(), this.makeWhereSchema(model, uniqueFilter)]); - } - - @cache() - private makeConnectOrCreateDataSchema(model: string, canBeArray: boolean, withoutFields: string[]) { - const whereSchema = this.makeWhereSchema(model, true); - const createSchema = this.makeCreateDataSchema(model, false, withoutFields); - return this.orArray( - z.strictObject({ - where: whereSchema, - create: createSchema, - }), - canBeArray, - ); - } - - @cache() - private makeCreateManyDataSchema(model: string, withoutFields: string[]) { - return z.strictObject({ - data: this.makeCreateDataSchema(model, true, withoutFields, true), - skipDuplicates: z.boolean().optional(), - }); - } - - // #endregion - - // #region Update - - @cache() - private makeUpdateSchema(model: string) { - const baseSchema = z.strictObject({ - where: this.makeWhereSchema(model, true), - data: this.makeUpdateDataSchema(model), - select: this.makeSelectSchema(model).optional().nullable(), - include: this.makeIncludeSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'update'); - schema = this.refineForSelectIncludeMutuallyExclusive(schema); - schema = this.refineForSelectOmitMutuallyExclusive(schema); - return schema; - } - - @cache() - private makeUpdateManySchema(model: string) { - return this.mergePluginArgsSchema( - z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - data: this.makeUpdateDataSchema(model, [], true), - limit: z.number().int().nonnegative().optional(), - }), - 'updateMany', - ); - } - - @cache() - private makeUpdateManyAndReturnSchema(model: string) { - // plugin extended args schema is merged in `makeUpdateManySchema` - const baseSchema: ZodObject = this.makeUpdateManySchema(model); - let schema: ZodType = baseSchema.extend({ - select: this.makeSelectSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - schema = this.refineForSelectOmitMutuallyExclusive(schema); - return schema; - } - - @cache() - private makeUpsertSchema(model: string) { - const baseSchema = z.strictObject({ - where: this.makeWhereSchema(model, true), - create: this.makeCreateDataSchema(model, false), - update: this.makeUpdateDataSchema(model), - select: this.makeSelectSchema(model).optional().nullable(), - include: this.makeIncludeSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'upsert'); - schema = this.refineForSelectIncludeMutuallyExclusive(schema); - schema = this.refineForSelectOmitMutuallyExclusive(schema); - return schema; - } - - @cache() - private makeUpdateDataSchema(model: string, withoutFields: string[] = [], withoutRelationFields = false) { - const uncheckedVariantFields: Record = {}; - const checkedVariantFields: Record = {}; - const modelDef = requireModel(this.schema, model); - const hasRelation = Object.entries(modelDef.fields).some( - ([key, value]) => value.relation && !withoutFields.includes(key), - ); - - Object.keys(modelDef.fields).forEach((field) => { - if (withoutFields.includes(field)) { - return; - } - const fieldDef = requireField(this.schema, model, field); - - if (fieldDef.relation) { - if (withoutRelationFields) { - return; - } - // Check if the target model is allowed by slicing configuration - if (!this.isModelAllowed(fieldDef.type)) { - return; - } - const excludeFields: string[] = []; - const oppositeField = fieldDef.relation.opposite; - if (oppositeField) { - excludeFields.push(oppositeField); - const oppositeFieldDef = requireField(this.schema, fieldDef.type, oppositeField); - if (oppositeFieldDef.relation?.fields) { - excludeFields.push(...oppositeFieldDef.relation.fields); - } - } - let fieldSchema: ZodType = z - .lazy(() => this.makeRelationManipulationSchema(model, field, excludeFields, 'update')) - .optional(); - // optional to-one relation can be null - if (fieldDef.optional && !fieldDef.array) { - fieldSchema = fieldSchema.nullable(); - } - checkedVariantFields[field] = fieldSchema; - if (fieldDef.array || !fieldDef.relation.references) { - // non-owned relation - uncheckedVariantFields[field] = fieldSchema; - } - } else { - let fieldSchema = this.makeScalarSchema(fieldDef.type, fieldDef.attributes); - - if (this.isNumericField(fieldDef)) { - fieldSchema = z.union([ - fieldSchema, - z - .object({ - set: this.nullableIf(z.number().optional(), !!fieldDef.optional).optional(), - increment: z.number().optional(), - decrement: z.number().optional(), - multiply: z.number().optional(), - divide: z.number().optional(), - }) - .refine( - (v) => Object.keys(v).length === 1, - 'Only one of "set", "increment", "decrement", "multiply", or "divide" can be provided', - ), - ]); - } - - if (fieldDef.array) { - const arraySchema = addListValidation(fieldSchema.array(), fieldDef.attributes); - fieldSchema = z.union([ - arraySchema, - z - .object({ - set: arraySchema.optional(), - push: z.union([fieldSchema, fieldSchema.array()]).optional(), - }) - .refine((v) => Object.keys(v).length === 1, 'Only one of "set", "push" can be provided'), - ]); - } - - if (fieldDef.optional) { - if (fieldDef.type === 'Json') { - // DbNull for Json fields - fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]); - } else { - fieldSchema = fieldSchema.nullable(); - } - } - - // all fields are optional in update - fieldSchema = fieldSchema.optional(); - - uncheckedVariantFields[field] = fieldSchema; - if (!fieldDef.foreignKeyFor) { - // non-fk field - checkedVariantFields[field] = fieldSchema; - } - } - }); - - const uncheckedUpdateSchema = this.extraValidationsEnabled - ? addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes) - : z.strictObject(uncheckedVariantFields); - const checkedUpdateSchema = this.extraValidationsEnabled - ? addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes) - : z.strictObject(checkedVariantFields); - if (!hasRelation) { - return uncheckedUpdateSchema; - } else { - return z.union([uncheckedUpdateSchema, checkedUpdateSchema]); - } - } - - // #endregion - - // #region Delete - - @cache() - private makeDeleteSchema(model: string) { - const baseSchema = z.strictObject({ - where: this.makeWhereSchema(model, true), - select: this.makeSelectSchema(model).optional().nullable(), - include: this.makeIncludeSchema(model).optional().nullable(), - omit: this.makeOmitSchema(model).optional().nullable(), - }); - let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'delete'); - schema = this.refineForSelectIncludeMutuallyExclusive(schema); - schema = this.refineForSelectOmitMutuallyExclusive(schema); - return schema; - } - - @cache() - private makeDeleteManySchema(model: string) { - return this.mergePluginArgsSchema( - z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - limit: z.number().int().nonnegative().optional(), - }), - 'deleteMany', - ).optional(); - } - - // #endregion - - // #region Count - - @cache() - makeCountSchema(model: string) { - return this.mergePluginArgsSchema( - z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - skip: this.makeSkipSchema().optional(), - take: this.makeTakeSchema().optional(), - orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), - select: this.makeCountAggregateInputSchema(model).optional(), - }), - 'count', - ).optional(); - } - - @cache() - private makeCountAggregateInputSchema(model: string) { - const modelDef = requireModel(this.schema, model); - return z.union([ - z.literal(true), - z.strictObject({ - _all: z.literal(true).optional(), - ...Object.keys(modelDef.fields).reduce( - (acc, field) => { - acc[field] = z.literal(true).optional(); - return acc; - }, - {} as Record, - ), - }), - ]); - } - - // #endregion - - // #region Aggregate - - @cache() - makeAggregateSchema(model: string) { - return this.mergePluginArgsSchema( - z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - skip: this.makeSkipSchema().optional(), - take: this.makeTakeSchema().optional(), - orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), - _count: this.makeCountAggregateInputSchema(model).optional(), - _avg: this.makeSumAvgInputSchema(model).optional(), - _sum: this.makeSumAvgInputSchema(model).optional(), - _min: this.makeMinMaxInputSchema(model).optional(), - _max: this.makeMinMaxInputSchema(model).optional(), - }), - 'aggregate', - ).optional(); - } - - @cache() - makeSumAvgInputSchema(model: string) { - const modelDef = requireModel(this.schema, model); - return z.strictObject( - Object.keys(modelDef.fields).reduce( - (acc, field) => { - const fieldDef = requireField(this.schema, model, field); - if (this.isNumericField(fieldDef)) { - acc[field] = z.literal(true).optional(); - } - return acc; - }, - {} as Record, - ), - ); - } - - @cache() - makeMinMaxInputSchema(model: string) { - const modelDef = requireModel(this.schema, model); - return z.strictObject( - Object.keys(modelDef.fields).reduce( - (acc, field) => { - const fieldDef = requireField(this.schema, model, field); - if (!fieldDef.relation && !fieldDef.array) { - acc[field] = z.literal(true).optional(); - } - return acc; - }, - {} as Record, - ), - ); - } - - @cache() - private makeGroupBySchema(model: string) { - const modelDef = requireModel(this.schema, model); - const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); - const bySchema = - nonRelationFields.length > 0 - ? this.orArray(z.enum(nonRelationFields as [string, ...string[]]), true) - : z.never(); - - const baseSchema = z.strictObject({ - where: this.makeWhereSchema(model, false).optional(), - orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), - by: bySchema, - having: this.makeHavingSchema(model).optional(), - skip: this.makeSkipSchema().optional(), - take: this.makeTakeSchema().optional(), - _count: this.makeCountAggregateInputSchema(model).optional(), - _avg: this.makeSumAvgInputSchema(model).optional(), - _sum: this.makeSumAvgInputSchema(model).optional(), - _min: this.makeMinMaxInputSchema(model).optional(), - _max: this.makeMinMaxInputSchema(model).optional(), - }); - - let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'groupBy'); - - // fields used in `having` must be either in the `by` list, or aggregations - schema = schema.refine((value: any) => { - const bys = typeof value.by === 'string' ? [value.by] : value.by; - if (value.having && typeof value.having === 'object') { - for (const [key, val] of Object.entries(value.having)) { - if (AggregateOperators.includes(key as any)) { - continue; - } - if (bys.includes(key)) { - continue; - } - // we have a key not mentioned in `by`, in this case it must only use - // aggregations in the condition - - // 1. payload must be an object - if (!val || typeof val !== 'object') { - return false; - } - // 2. payload must only contain aggregations - if (!this.onlyAggregationFields(val)) { - return false; - } - } - } - return true; - }, 'fields in "having" must be in "by"'); - - // fields used in `orderBy` must be either in the `by` list, or aggregations - schema = schema.refine((value: any) => { - const bys = typeof value.by === 'string' ? [value.by] : value.by; - if ( - value.orderBy && - Object.keys(value.orderBy) - .filter((f) => !AggregateOperators.includes(f as AggregateOperators)) - .some((key) => !bys.includes(key)) - ) { - return false; - } else { - return true; - } - }, 'fields in "orderBy" must be in "by"'); - - return schema; - } - - private onlyAggregationFields(val: object) { - for (const [key, value] of Object.entries(val)) { - if (AggregateOperators.includes(key as any)) { - // aggregation field - continue; - } - if (LOGICAL_COMBINATORS.includes(key as any)) { - // logical operators - if (enumerate(value).every((v) => this.onlyAggregationFields(v))) { - continue; - } - } - return false; - } - return true; - } - - private makeHavingSchema(model: string) { - // `makeWhereSchema` is cached - return this.makeWhereSchema(model, false, true, true); - } - - // #endregion - - // #region Procedures - - @cache() - private makeProcedureParamSchema(param: { type: string; array?: boolean; optional?: boolean }): z.ZodType { - let schema: z.ZodType; - - if (isTypeDef(this.schema, param.type)) { - schema = this.makeTypeDefSchema(param.type); - } else if (isEnum(this.schema, param.type)) { - schema = this.makeEnumSchema(param.type); - } else if (param.type in (this.schema.models ?? {})) { - // For model-typed values, accept any object (no deep shape validation). - schema = z.record(z.string(), z.unknown()); - } else { - // Builtin scalar types. - schema = this.makeScalarSchema(param.type as BuiltinType); - - // If a type isn't recognized by any of the above branches, `makeScalarSchema` returns `unknown`. - // Treat it as configuration/schema error. - if (schema instanceof z.ZodUnknown) { - throw createInternalError(`Unsupported procedure parameter type: ${param.type}`); - } - } - - if (param.array) { - schema = schema.array(); - } - if (param.optional) { - schema = schema.optional(); - } - - return schema; - } - - // #endregion - - // #region Cache Management - - getCache(cacheKey: string) { - return this.schemaCache.get(cacheKey); - } - - setCache(cacheKey: string, schema: ZodType) { - return this.schemaCache.set(cacheKey, schema); - } - - // @ts-ignore - private printCacheStats(detailed = false) { - console.log('Schema cache size:', this.schemaCache.size); - if (detailed) { - for (const key of this.schemaCache.keys()) { - console.log(`\t${key}`); - } - } - } - - // #endregion - - // #region Helpers - - @cache() - private makeSkipSchema() { - return z.number().int().nonnegative(); - } - - @cache() - private makeTakeSchema() { - return z.number().int(); - } - - private refineForSelectIncludeMutuallyExclusive(schema: ZodType) { - return schema.refine( - (value: any) => !(value['select'] && value['include']), - '"select" and "include" cannot be used together', - ); - } - - private refineForSelectOmitMutuallyExclusive(schema: ZodType) { - return schema.refine( - (value: any) => !(value['select'] && value['omit']), - '"select" and "omit" cannot be used together', - ); - } - - private nullableIf(schema: ZodType, nullable: boolean) { - return nullable ? schema.nullable() : schema; - } - - private orArray(schema: T, canBeArray: boolean) { - return canBeArray ? z.union([schema, z.array(schema)]) : schema; - } - - private isNumericField(fieldDef: FieldDef) { - return NUMERIC_FIELD_TYPES.includes(fieldDef.type) && !fieldDef.array; - } - - private get providerSupportsCaseSensitivity() { - return this.schema.provider.type === 'postgresql'; - } - - /** - * Gets the effective set of allowed FilterKind values for a specific model and field. - * Respects the precedence: model[field] > model.$all > $all[field] > $all.$all. - */ - private getEffectiveFilterKinds(model: string | undefined, field: string): string[] | undefined { - if (!model) { - // no restrictions - return undefined; - } - - const slicing = this.options.slicing; - if (!slicing?.models) { - // no slicing or no model-specific slicing, no restrictions - return undefined; - } - - // A string-indexed view of slicing.models that avoids unsafe 'as any' while still - // allowing runtime access by model name. The value shape matches FieldSlicingOptions. - type FieldConfig = { includedFilterKinds?: readonly string[]; excludedFilterKinds?: readonly string[] }; - type FieldsRecord = { $all?: FieldConfig } & Record; - type ModelConfig = { fields?: FieldsRecord }; - const modelsRecord = slicing.models as Record; - - // Check field-level settings for the specific model - const modelConfig = modelsRecord[lowerCaseFirst(model)]; - if (modelConfig?.fields) { - const fieldConfig = modelConfig.fields[field]; - if (fieldConfig) { - return this.computeFilterKinds(fieldConfig.includedFilterKinds, fieldConfig.excludedFilterKinds); - } - - // Fallback to field-level $all for the specific model - const allFieldsConfig = modelConfig.fields['$all']; - if (allFieldsConfig) { - return this.computeFilterKinds( - allFieldsConfig.includedFilterKinds, - allFieldsConfig.excludedFilterKinds, - ); - } - } - - // Fallback to model-level $all - const allModelsConfig = modelsRecord['$all']; - if (allModelsConfig?.fields) { - // Check specific field in $all model config before falling back to $all.$all - const allModelsFieldConfig = allModelsConfig.fields[field]; - if (allModelsFieldConfig) { - return this.computeFilterKinds( - allModelsFieldConfig.includedFilterKinds, - allModelsFieldConfig.excludedFilterKinds, - ); - } - - // Fallback to $all.$all - const allModelsAllFieldsConfig = allModelsConfig.fields['$all']; - if (allModelsAllFieldsConfig) { - return this.computeFilterKinds( - allModelsAllFieldsConfig.includedFilterKinds, - allModelsAllFieldsConfig.excludedFilterKinds, - ); - } - } - - return undefined; // No restrictions - } - - /** - * Computes the effective set of filter kinds based on inclusion and exclusion lists. - */ - private computeFilterKinds(included: readonly string[] | undefined, excluded: readonly string[] | undefined) { - let result: string[] | undefined; - - if (included !== undefined) { - // Start with the included set - result = [...included]; - } - - if (excluded !== undefined) { - if (!result) { - // If no inclusion list, start with all filter kinds - result = [...this.allFilterKinds]; - } - // Remove excluded kinds - for (const kind of excluded) { - result = result.filter((k) => k !== kind); - } - } - - return result; - } - - /** - * Filters operators based on allowed filter kinds. - */ - private trimFilterOperators>( - operators: T, - allowedKinds: string[] | undefined, - ): Partial { - if (!allowedKinds) { - return operators; // No restrictions - } - - return Object.fromEntries( - Object.entries(operators).filter(([key, _]) => { - return ( - !(key in FILTER_PROPERTY_TO_KIND) || - allowedKinds.includes(FILTER_PROPERTY_TO_KIND[key as keyof typeof FILTER_PROPERTY_TO_KIND]) - ); - }), - ) as Partial; - } - - private createUnionFilterSchema( - valueSchema: ZodType, - optional: boolean, - components: Record, - allowedFilterKinds: string[] | undefined, - ) { - // If all filter operators are excluded - if (Object.keys(components).length === 0) { - // if equality filters are allowed, allow direct value - if (!allowedFilterKinds || allowedFilterKinds.includes('Equality')) { - return this.nullableIf(valueSchema, optional); - } - // otherwise nothing is allowed - return z.never(); - } - - if (!allowedFilterKinds || allowedFilterKinds.includes('Equality')) { - // direct value or filter operators - return z.union([this.nullableIf(valueSchema, optional), z.strictObject(components)]); - } else { - // filter operators - return z.strictObject(components); - } - } - - /** - * Checks if a model is included in the slicing configuration. - * Returns true if the model is allowed, false if it's excluded. - */ - private isModelAllowed(targetModel: string): boolean { - const slicing = this.options.slicing; - if (!slicing) { - return true; // No slicing, all models allowed - } - - const { includedModels, excludedModels } = slicing; - - // If includedModels is specified, only those models are allowed - if (includedModels !== undefined) { - if (!includedModels.includes(targetModel as any)) { - return false; - } - } - - // If excludedModels is specified, those models are not allowed - if (excludedModels !== undefined) { - if (excludedModels.includes(targetModel as any)) { - return false; - } - } - - return true; - } - - // #endregion -} +export { InputValidator } from './validator'; diff --git a/packages/orm/src/client/crud/validator/validator.ts b/packages/orm/src/client/crud/validator/validator.ts new file mode 100644 index 000000000..1acba6f15 --- /dev/null +++ b/packages/orm/src/client/crud/validator/validator.ts @@ -0,0 +1,288 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import { match } from 'ts-pattern'; +import { ZodType } from 'zod'; +import { type GetModels, type ProcedureDef, type SchemaDef } from '../../../schema'; +import { formatError } from '../../../utils/zod-utils'; +import type { ClientContract } from '../../contract'; +import { + type AggregateArgs, + type CountArgs, + type CreateArgs, + type CreateManyAndReturnArgs, + type CreateManyArgs, + type DeleteArgs, + type DeleteManyArgs, + type ExistsArgs, + type FindArgs, + type GroupByArgs, + type UpdateArgs, + type UpdateManyAndReturnArgs, + type UpdateManyArgs, + type UpsertArgs, +} from '../../crud-types'; +import { createInvalidInputError } from '../../errors'; +import { ZodSchemaFactory } from '../../zod/factory'; + +type GetSchemaFunc = (model: GetModels) => ZodType; + +export class InputValidator { + readonly zodFactory: ZodSchemaFactory; + + constructor(private readonly client: ClientContract) { + this.zodFactory = new ZodSchemaFactory(client); + } + + // #region Entry points + + validateFindArgs( + model: GetModels, + args: unknown, + operation: 'findFirst' | 'findUnique' | 'findMany', + ): FindArgs, any, true> | undefined { + return this.validate, any, true> | undefined>( + model, + operation, + (model) => + match(operation) + .with('findFirst', () => this.zodFactory.makeFindFirstSchema(model)) + .with('findUnique', () => this.zodFactory.makeFindUniqueSchema(model)) + .with('findMany', () => this.zodFactory.makeFindManySchema(model)) + .exhaustive(), + args, + ); + } + + validateExistsArgs( + model: GetModels, + args: unknown, + ): ExistsArgs, any> | undefined { + return this.validate, any> | undefined>( + model, + 'exists', + (model) => this.zodFactory.makeExistsSchema(model), + args, + ); + } + + validateCreateArgs(model: GetModels, args: unknown): CreateArgs, any> { + return this.validate, any>>( + model, + 'create', + (model) => this.zodFactory.makeCreateSchema(model), + args, + ); + } + + validateCreateManyArgs(model: GetModels, args: unknown): CreateManyArgs> { + return this.validate>>( + model, + 'createMany', + (model) => this.zodFactory.makeCreateManySchema(model), + args, + ); + } + + validateCreateManyAndReturnArgs( + model: GetModels, + args: unknown, + ): CreateManyAndReturnArgs, any> | undefined { + return this.validate, any> | undefined>( + model, + 'createManyAndReturn', + (model) => this.zodFactory.makeCreateManyAndReturnSchema(model), + args, + ); + } + + validateUpdateArgs(model: GetModels, args: unknown): UpdateArgs, any> { + return this.validate, any>>( + model, + 'update', + (model) => this.zodFactory.makeUpdateSchema(model), + args, + ); + } + + validateUpdateManyArgs(model: GetModels, args: unknown): UpdateManyArgs, any> { + return this.validate, any>>( + model, + 'updateMany', + (model) => this.zodFactory.makeUpdateManySchema(model), + args, + ); + } + + validateUpdateManyAndReturnArgs( + model: GetModels, + args: unknown, + ): UpdateManyAndReturnArgs, any> { + return this.validate, any>>( + model, + 'updateManyAndReturn', + (model) => this.zodFactory.makeUpdateManyAndReturnSchema(model), + args, + ); + } + + validateUpsertArgs(model: GetModels, args: unknown): UpsertArgs, any> { + return this.validate, any>>( + model, + 'upsert', + (model) => this.zodFactory.makeUpsertSchema(model), + args, + ); + } + + validateDeleteArgs(model: GetModels, args: unknown): DeleteArgs, any> { + return this.validate, any>>( + model, + 'delete', + (model) => this.zodFactory.makeDeleteSchema(model), + args, + ); + } + + validateDeleteManyArgs( + model: GetModels, + args: unknown, + ): DeleteManyArgs, any> | undefined { + return this.validate, any> | undefined>( + model, + 'deleteMany', + (model) => this.zodFactory.makeDeleteManySchema(model), + args, + ); + } + + validateCountArgs(model: GetModels, args: unknown): CountArgs, any> | undefined { + return this.validate, any> | undefined>( + model, + 'count', + (model) => this.zodFactory.makeCountSchema(model), + args, + ); + } + + validateAggregateArgs(model: GetModels, args: unknown): AggregateArgs, any> { + return this.validate, any>>( + model, + 'aggregate', + (model) => this.zodFactory.makeAggregateSchema(model), + args, + ); + } + + validateGroupByArgs(model: GetModels, args: unknown): GroupByArgs, any> { + return this.validate, any>>( + model, + 'groupBy', + (model) => this.zodFactory.makeGroupBySchema(model), + args, + ); + } + + // TODO: turn it into a Zod schema and cache + validateProcedureInput(proc: string, input: unknown): unknown { + const procDef = (this.client.$schema.procedures ?? {})[proc] as ProcedureDef | undefined; + invariant(procDef, `Procedure "${proc}" not found in schema`); + + const params = Object.values(procDef.params ?? {}); + + // For procedures where every parameter is optional, allow omitting the input entirely. + if (typeof input === 'undefined') { + if (params.length === 0) { + return undefined; + } + if (params.every((p) => p.optional)) { + return undefined; + } + throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); + } + + if (typeof input !== 'object' || input === null || Array.isArray(input)) { + throw createInvalidInputError('Procedure input must be an object', `$procs.${proc}`); + } + + const envelope = input as Record; + const argsPayload = Object.prototype.hasOwnProperty.call(envelope, 'args') ? (envelope as any).args : undefined; + + if (params.length === 0) { + if (typeof argsPayload === 'undefined') { + return input; + } + if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { + throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); + } + if (Object.keys(argsPayload as any).length === 0) { + return input; + } + throw createInvalidInputError('Procedure does not accept arguments', `$procs.${proc}`); + } + + if (typeof argsPayload === 'undefined') { + if (params.every((p) => p.optional)) { + return input; + } + throw createInvalidInputError('Missing procedure arguments', `$procs.${proc}`); + } + + if (!argsPayload || typeof argsPayload !== 'object' || Array.isArray(argsPayload)) { + throw createInvalidInputError('Procedure `args` must be an object', `$procs.${proc}`); + } + + const obj = argsPayload as Record; + + for (const param of params) { + const value = (obj as any)[param.name]; + + if (!Object.prototype.hasOwnProperty.call(obj, param.name)) { + if (param.optional) { + continue; + } + throw createInvalidInputError(`Missing procedure argument: ${param.name}`, `$procs.${proc}`); + } + + if (typeof value === 'undefined') { + if (param.optional) { + continue; + } + throw createInvalidInputError( + `Invalid procedure argument: ${param.name} is required`, + `$procs.${proc}`, + ); + } + + const schema = this.zodFactory.makeProcedureParamSchema(param); + const parsed = schema.safeParse(value); + if (!parsed.success) { + throw createInvalidInputError( + `Invalid procedure argument: ${param.name}: ${formatError(parsed.error)}`, + `$procs.${proc}`, + ); + } + } + + return input; + } + + // #endregion + + // #region Validation helpers + + private validate(model: GetModels, operation: string, getSchema: GetSchemaFunc, args: unknown) { + const schema = getSchema(model); + const { error, data } = schema.safeParse(args); + if (error) { + throw createInvalidInputError( + `Invalid ${operation} args for model "${model}": ${formatError(error)}`, + model, + { + cause: error, + }, + ); + } + return data as T; + } + + // #endregion +} diff --git a/packages/orm/src/client/index.ts b/packages/orm/src/client/index.ts index 25015c353..41e313faa 100644 --- a/packages/orm/src/client/index.ts +++ b/packages/orm/src/client/index.ts @@ -21,3 +21,4 @@ export type { ZenStackPromise } from './promise'; export type { ToKysely } from './query-builder'; export * as QueryUtils from './query-utils'; export type * from './type-utils'; +export * from './zod'; diff --git a/packages/orm/src/client/crud/validator/cache-decorator.ts b/packages/orm/src/client/zod/cache-decorator.ts similarity index 100% rename from packages/orm/src/client/crud/validator/cache-decorator.ts rename to packages/orm/src/client/zod/cache-decorator.ts diff --git a/packages/orm/src/client/zod/factory.ts b/packages/orm/src/client/zod/factory.ts new file mode 100644 index 000000000..8ad67de6b --- /dev/null +++ b/packages/orm/src/client/zod/factory.ts @@ -0,0 +1,2106 @@ +import { enumerate, invariant, lowerCaseFirst } from '@zenstackhq/common-helpers'; +import { ZodUtils } from '@zenstackhq/zod'; +import Decimal from 'decimal.js'; +import { match, P } from 'ts-pattern'; +import { z, ZodObject, ZodType } from 'zod'; +import { AnyNullClass, DbNullClass, JsonNullClass } from '../../common-types'; +import { + type AttributeApplication, + type BuiltinType, + type FieldDef, + type GetModels, + type SchemaDef, +} from '../../schema'; +import { extractFields } from '../../utils/object-utils'; +import { AggregateOperators, FILTER_PROPERTY_TO_KIND, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../constants'; +import type { ClientContract } from '../contract'; +import type { + AggregateArgs, + CountArgs, + CreateArgs, + CreateManyAndReturnArgs, + CreateManyArgs, + DeleteArgs, + DeleteManyArgs, + ExistsArgs, + FindFirstArgs, + FindManyArgs, + FindUniqueArgs, + GroupByArgs, + UpdateArgs, + UpdateManyAndReturnArgs, + UpdateManyArgs, + UpsertArgs, +} from '../crud-types'; +import { + CoreCreateOperations, + CoreDeleteOperations, + CoreReadOperations, + CoreUpdateOperations, + type CoreCrudOperations, +} from '../crud/operations/base'; +import { createInternalError } from '../errors'; +import type { ClientOptions } from '../options'; +import type { AnyPlugin, ExtQueryArgsBase, RuntimePlugin } from '../plugin'; +import { + fieldHasDefaultValue, + getDiscriminatorField, + getEnum, + getTypeDef, + getUniqueFields, + isEnum, + isTypeDef, + requireField, + requireModel, +} from '../query-utils'; +import { cache } from './cache-decorator'; + +/** + * Minimal field information needed for filter schema generation. + */ +type FieldInfo = { + name: string; + type: string; + optional?: boolean; + array?: boolean; +}; + +/** + * Create a factory for generating Zod schemas to validate ORM query inputs. + */ +export function createQuerySchemaFactory< + Schema extends SchemaDef, + Options extends ClientOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, +>(client: ClientContract): ZodSchemaFactory; + +/** + * Create a factory for generating Zod schemas to validate ORM query inputs. + */ +export function createQuerySchemaFactory< + Schema extends SchemaDef, + Options extends ClientOptions = ClientOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, +>(schema: Schema, options?: Options): ZodSchemaFactory; + +export function createQuerySchemaFactory(clientOrSchema: any, options?: any) { + return new ZodSchemaFactory(clientOrSchema, options); +} + +/** + * Factory class responsible for creating and caching Zod schemas for ORM input validation. + */ +export class ZodSchemaFactory< + Schema extends SchemaDef, + Options extends ClientOptions = ClientOptions, + ExtQueryArgs extends ExtQueryArgsBase = {}, +> { + private readonly schemaCache = new Map(); + private readonly allFilterKinds = [...new Set(Object.values(FILTER_PROPERTY_TO_KIND))]; + private readonly schema: Schema; + private readonly options: Options; + + constructor(client: ClientContract); + constructor(schema: Schema, options?: Options); + constructor(clientOrSchema: any, options?: Options) { + if ('$schema' in clientOrSchema) { + this.schema = clientOrSchema.$schema; + this.options = clientOrSchema.$options; + } else { + this.schema = clientOrSchema; + this.options = options || ({} as Options); + } + } + + private get plugins(): RuntimePlugin[] { + return this.options.plugins ?? []; + } + + private get extraValidationsEnabled() { + return this.options.validateInput !== false; + } + + // #region Cache Management + + // @ts-ignore + private getCache(cacheKey: string) { + return this.schemaCache.get(cacheKey); + } + + // @ts-ignore + private setCache(cacheKey: string, schema: ZodType) { + return this.schemaCache.set(cacheKey, schema); + } + + // @ts-ignore + private printCacheStats(detailed = false) { + console.log('Schema cache size:', this.schemaCache.size); + if (detailed) { + for (const key of this.schemaCache.keys()) { + console.log(`\t${key}`); + } + } + } + + // #endregion + + // #region Find + + makeFindUniqueSchema>( + model: Model, + ): ZodType> { + return this.makeFindSchema(model, 'findUnique') as ZodType< + FindUniqueArgs + >; + } + + makeFindFirstSchema>( + model: Model, + ): ZodType | undefined> { + return this.makeFindSchema(model, 'findFirst') as ZodType< + FindFirstArgs | undefined + >; + } + + makeFindManySchema>( + model: Model, + ): ZodType | undefined> { + return this.makeFindSchema(model, 'findMany') as ZodType< + FindManyArgs | undefined + >; + } + + @cache() + private makeFindSchema(model: string, operation: CoreCrudOperations) { + const fields: Record = {}; + const unique = operation === 'findUnique'; + const findOne = operation === 'findUnique' || operation === 'findFirst'; + const where = this.makeWhereSchema(model, unique); + if (unique) { + fields['where'] = where; + } else { + fields['where'] = where.optional(); + } + + fields['select'] = this.makeSelectSchema(model).optional().nullable(); + fields['include'] = this.makeIncludeSchema(model).optional().nullable(); + fields['omit'] = this.makeOmitSchema(model).optional().nullable(); + + if (!unique) { + fields['skip'] = this.makeSkipSchema().optional(); + if (findOne) { + fields['take'] = z.literal(1).optional(); + } else { + fields['take'] = this.makeTakeSchema().optional(); + } + fields['orderBy'] = this.orArray(this.makeOrderBySchema(model, true, false), true).optional(); + fields['cursor'] = this.makeCursorSchema(model).optional(); + fields['distinct'] = this.makeDistinctSchema(model).optional(); + } + + const baseSchema = z.strictObject(fields); + let result: ZodType = this.mergePluginArgsSchema(baseSchema, operation); + result = this.refineForSelectIncludeMutuallyExclusive(result); + result = this.refineForSelectOmitMutuallyExclusive(result); + + if (!unique) { + result = result.optional(); + } + return result; + } + + @cache() + makeExistsSchema>( + model: Model, + ): ZodType | undefined> { + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + }); + return this.mergePluginArgsSchema(baseSchema, 'exists').optional() as ZodType< + ExistsArgs | undefined + >; + } + + private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) { + if (this.schema.typeDefs && type in this.schema.typeDefs) { + return this.makeTypeDefSchema(type); + } else if (this.schema.enums && type in this.schema.enums) { + return this.makeEnumSchema(type); + } else { + return match(type) + .with('String', () => + this.extraValidationsEnabled ? ZodUtils.addStringValidation(z.string(), attributes) : z.string(), + ) + .with('Int', () => + this.extraValidationsEnabled + ? ZodUtils.addNumberValidation(z.number().int(), attributes) + : z.number().int(), + ) + .with('Float', () => + this.extraValidationsEnabled ? ZodUtils.addNumberValidation(z.number(), attributes) : z.number(), + ) + .with('Boolean', () => z.boolean()) + .with('BigInt', () => + z.union([ + this.extraValidationsEnabled + ? ZodUtils.addNumberValidation(z.number().int(), attributes) + : z.number().int(), + this.extraValidationsEnabled + ? ZodUtils.addBigIntValidation(z.bigint(), attributes) + : z.bigint(), + ]), + ) + .with('Decimal', () => { + return z.union([ + this.extraValidationsEnabled + ? ZodUtils.addNumberValidation(z.number(), attributes) + : z.number(), + ZodUtils.addDecimalValidation(z.instanceof(Decimal), attributes, this.extraValidationsEnabled), + ZodUtils.addDecimalValidation(z.string(), attributes, this.extraValidationsEnabled), + ]); + }) + .with('DateTime', () => z.union([z.date(), z.iso.datetime()])) + .with('Bytes', () => z.instanceof(Uint8Array)) + .with('Json', () => this.makeJsonValueSchema(false, false)) + .otherwise(() => z.unknown()); + } + } + + @cache() + private makeEnumSchema(_enum: string) { + const enumDef = getEnum(this.schema, _enum); + invariant(enumDef, `Enum "${_enum}" not found in schema`); + return z.enum(Object.keys(enumDef.values) as [string, ...string[]]); + } + + @cache() + private makeTypeDefSchema(type: string): ZodType { + const typeDef = getTypeDef(this.schema, type); + invariant(typeDef, `Type definition "${type}" not found in schema`); + const schema = z.looseObject( + Object.fromEntries( + Object.entries(typeDef.fields).map(([field, def]) => { + let fieldSchema = this.makeScalarSchema(def.type); + if (def.array) { + fieldSchema = fieldSchema.array(); + } + if (def.optional) { + fieldSchema = fieldSchema.nullish(); + } + return [field, fieldSchema]; + }), + ), + ); + + // zod doesn't preserve object field order after parsing, here we use a + // validation-only custom schema and use the original data if parsing + // is successful + const finalSchema = z.any().superRefine((value, ctx) => { + const parseResult = schema.safeParse(value); + if (!parseResult.success) { + parseResult.error.issues.forEach((issue) => ctx.addIssue(issue as any)); + } + }); + + return finalSchema; + } + + @cache() + makeWhereSchema(model: string, unique: boolean, withoutRelationFields = false, withAggregations = false): ZodType { + const modelDef = requireModel(this.schema, model); + + // unique field used in unique filters bypass filter slicing + const uniqueFieldNames = unique + ? getUniqueFields(this.schema, model) + .filter( + (uf): uf is { name: string; def: FieldDef } => + // single-field unique + 'def' in uf, + ) + .map((uf) => uf.name) + : undefined; + + const fields: Record = {}; + for (const field of Object.keys(modelDef.fields)) { + const fieldDef = requireField(this.schema, model, field); + let fieldSchema: ZodType | undefined; + + if (fieldDef.relation) { + if (withoutRelationFields) { + continue; + } + + // Check if Relation filter kind is allowed + const allowedFilterKinds = this.getEffectiveFilterKinds(model, field); + if (allowedFilterKinds && !allowedFilterKinds.includes('Relation')) { + // Relation filters are not allowed for this field - use z.never() + fieldSchema = z.never(); + } else { + fieldSchema = z.lazy(() => this.makeWhereSchema(fieldDef.type, false).optional()); + + // optional to-one relation allows null + fieldSchema = this.nullableIf(fieldSchema, !fieldDef.array && !!fieldDef.optional); + + if (fieldDef.array) { + // to-many relation + fieldSchema = z.union([ + fieldSchema, + z.strictObject({ + some: fieldSchema.optional(), + every: fieldSchema.optional(), + none: fieldSchema.optional(), + }), + ]); + } else { + // to-one relation + fieldSchema = z.union([ + fieldSchema, + z.strictObject({ + is: fieldSchema.optional(), + isNot: fieldSchema.optional(), + }), + ]); + } + } + } else { + const ignoreSlicing = !!uniqueFieldNames?.includes(field); + + const enumDef = getEnum(this.schema, fieldDef.type); + if (enumDef) { + // enum + if (Object.keys(enumDef.values).length > 0) { + fieldSchema = this.makeEnumFilterSchema(model, fieldDef, withAggregations, ignoreSlicing); + } + } else if (fieldDef.array) { + // array field + fieldSchema = this.makeArrayFilterSchema(model, fieldDef); + } else if (this.isTypeDefType(fieldDef.type)) { + fieldSchema = this.makeTypedJsonFilterSchema(model, fieldDef); + } else { + // primitive field + fieldSchema = this.makePrimitiveFilterSchema(model, fieldDef, withAggregations, ignoreSlicing); + } + } + + if (fieldSchema) { + fields[field] = fieldSchema.optional(); + } + } + + if (unique) { + // add compound unique fields, e.g. `{ id1_id2: { id1: 1, id2: 1 } }` + // compound-field filters are not affected by slicing + const uniqueFields = getUniqueFields(this.schema, model); + for (const uniqueField of uniqueFields) { + if ('defs' in uniqueField) { + fields[uniqueField.name] = z + .object( + Object.fromEntries( + Object.entries(uniqueField.defs).map(([key, def]) => { + invariant(!def.relation, 'unique field cannot be a relation'); + let fieldSchema: ZodType; + const enumDef = getEnum(this.schema, def.type); + if (enumDef) { + // enum + if (Object.keys(enumDef.values).length > 0) { + fieldSchema = this.makeEnumFilterSchema(model, def, false, true); + } else { + fieldSchema = z.never(); + } + } else { + fieldSchema = this.makePrimitiveFilterSchema(model, def, false, true); + } + return [key, fieldSchema]; + }), + ), + ) + .optional(); + } + } + } + + // expression builder + fields['$expr'] = z.custom((v) => typeof v === 'function', { error: '"$expr" must be a function' }).optional(); + + // logical operators + fields['AND'] = this.orArray( + z.lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)), + true, + ).optional(); + fields['OR'] = z + .lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)) + .array() + .optional(); + fields['NOT'] = this.orArray( + z.lazy(() => this.makeWhereSchema(model, false, withoutRelationFields)), + true, + ).optional(); + + const baseWhere = z.strictObject(fields); + let result: ZodType = baseWhere; + + if (unique) { + // requires at least one unique field (field set) is required + const uniqueFields = getUniqueFields(this.schema, model); + if (uniqueFields.length === 0) { + throw createInternalError(`Model "${model}" has no unique fields`); + } + + if (uniqueFields.length === 1) { + // only one unique field (set), mark the field(s) required + result = baseWhere.required({ + [uniqueFields[0]!.name]: true, + } as any); + } else { + result = baseWhere.refine((value) => { + // check that at least one unique field is set + return uniqueFields.some(({ name }) => value[name] !== undefined); + }, `At least one unique field or field set must be set`); + } + } + + return result; + } + + @cache() + private makeTypedJsonFilterSchema(contextModel: string | undefined, fieldInfo: FieldInfo) { + const field = fieldInfo.name; + const type = fieldInfo.type; + const optional = !!fieldInfo.optional; + const array = !!fieldInfo.array; + + const typeDef = getTypeDef(this.schema, type); + invariant(typeDef, `Type definition "${type}" not found in schema`); + + const candidates: ZodType[] = []; + + if (!array) { + // fields filter + const fieldSchemas: Record = {}; + for (const [fieldName, fieldDef] of Object.entries(typeDef.fields)) { + if (this.isTypeDefType(fieldDef.type)) { + // recursive typed JSON - use same model/field for nested typed JSON + fieldSchemas[fieldName] = this.makeTypedJsonFilterSchema(contextModel, fieldDef).optional(); + } else { + // enum, array, primitives + const enumDef = getEnum(this.schema, fieldDef.type); + if (enumDef) { + fieldSchemas[fieldName] = this.makeEnumFilterSchema(contextModel, fieldDef, false).optional(); + } else if (fieldDef.array) { + fieldSchemas[fieldName] = this.makeArrayFilterSchema(contextModel, fieldDef).optional(); + } else { + fieldSchemas[fieldName] = this.makePrimitiveFilterSchema( + contextModel, + fieldDef, + false, + ).optional(); + } + } + } + + candidates.push(z.strictObject(fieldSchemas)); + } + + const recursiveSchema = z + .lazy(() => this.makeTypedJsonFilterSchema(contextModel, { name: field, type, optional, array: false })) + .optional(); + if (array) { + // array filter + candidates.push( + z.strictObject({ + some: recursiveSchema, + every: recursiveSchema, + none: recursiveSchema, + }), + ); + } else { + // is / isNot filter + candidates.push( + z.strictObject({ + is: recursiveSchema, + isNot: recursiveSchema, + }), + ); + } + + // plain json filter + candidates.push(this.makeJsonFilterSchema(contextModel, field, optional)); + + if (optional) { + // allow null as well + candidates.push(z.null()); + } + + // either plain json filter or field filters + return z.union(candidates); + } + + private isTypeDefType(type: string) { + return this.schema.typeDefs && type in this.schema.typeDefs; + } + + @cache() + private makeEnumFilterSchema( + model: string | undefined, + fieldInfo: FieldInfo, + withAggregations: boolean, + ignoreSlicing: boolean = false, + ) { + const enumName = fieldInfo.type; + const optional = !!fieldInfo.optional; + const array = !!fieldInfo.array; + + const enumDef = getEnum(this.schema, enumName); + invariant(enumDef, `Enum "${enumName}" not found in schema`); + const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]); + if (array) { + return this.internalMakeArrayFilterSchema(model, fieldInfo.name, baseSchema); + } + const allowedFilterKinds = ignoreSlicing ? undefined : this.getEffectiveFilterKinds(model, fieldInfo.name); + const components = this.makeCommonPrimitiveFilterComponents( + baseSchema, + optional, + () => z.lazy(() => this.makeEnumFilterSchema(model, fieldInfo, withAggregations)), + ['equals', 'in', 'notIn', 'not'], + withAggregations ? ['_count', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + + return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); + } + + @cache() + private makeArrayFilterSchema(model: string | undefined, fieldInfo: FieldInfo) { + return this.internalMakeArrayFilterSchema( + model, + fieldInfo.name, + this.makeScalarSchema(fieldInfo.type as BuiltinType), + ); + } + + private internalMakeArrayFilterSchema(contextModel: string | undefined, field: string, elementSchema: ZodType) { + const allowedFilterKinds = this.getEffectiveFilterKinds(contextModel, field); + const operators = { + equals: elementSchema.array().optional(), + has: elementSchema.optional(), + hasEvery: elementSchema.array().optional(), + hasSome: elementSchema.array().optional(), + isEmpty: z.boolean().optional(), + }; + + // Filter operators based on allowed filter kinds + const filteredOperators = this.trimFilterOperators(operators, allowedFilterKinds); + + return z.strictObject(filteredOperators); + } + + @cache() + private makePrimitiveFilterSchema( + contextModel: string | undefined, + fieldInfo: FieldInfo, + withAggregations: boolean, + ignoreSlicing = false, + ) { + const allowedFilterKinds = ignoreSlicing + ? undefined + : this.getEffectiveFilterKinds(contextModel, fieldInfo.name); + const type = fieldInfo.type as BuiltinType; + const optional = !!fieldInfo.optional; + return match(type) + .with('String', () => this.makeStringFilterSchema(optional, withAggregations, allowedFilterKinds)) + .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => + this.makeNumberFilterSchema( + this.makeScalarSchema(type), + optional, + withAggregations, + allowedFilterKinds, + ), + ) + .with('Boolean', () => this.makeBooleanFilterSchema(optional, withAggregations, allowedFilterKinds)) + .with('DateTime', () => this.makeDateTimeFilterSchema(optional, withAggregations, allowedFilterKinds)) + .with('Bytes', () => this.makeBytesFilterSchema(optional, withAggregations, allowedFilterKinds)) + .with('Json', () => this.makeJsonFilterSchema(contextModel, fieldInfo.name, optional)) + .with('Unsupported', () => z.never()) + .exhaustive(); + } + + private makeJsonValueSchema(nullable: boolean, forFilter: boolean): ZodType { + const options: ZodType[] = [z.string(), z.number(), z.boolean(), z.instanceof(JsonNullClass)]; + + if (forFilter) { + options.push(z.instanceof(DbNullClass)); + } else { + if (nullable) { + // for mutation, allow DbNull only if nullable + options.push(z.instanceof(DbNullClass)); + } + } + + if (forFilter) { + options.push(z.instanceof(AnyNullClass)); + } + + const schema = z.union([ + ...options, + z.lazy(() => z.union([this.makeJsonValueSchema(false, false), z.null()]).array()), + z.record( + z.string(), + z.lazy(() => z.union([this.makeJsonValueSchema(false, false), z.null()])), + ), + ]); + return this.nullableIf(schema, nullable); + } + + @cache() + private makeJsonFilterSchema(contextModel: string | undefined, field: string, optional: boolean) { + const allowedFilterKinds = this.getEffectiveFilterKinds(contextModel, field); + + // Check if Json filter kind is allowed + if (allowedFilterKinds && !allowedFilterKinds.includes('Json')) { + // Return a never schema if Json filters are not allowed + return z.never(); + } + + const valueSchema = this.makeJsonValueSchema(optional, true); + return z.strictObject({ + path: z.string().optional(), + equals: valueSchema.optional(), + not: valueSchema.optional(), + string_contains: z.string().optional(), + string_starts_with: z.string().optional(), + string_ends_with: z.string().optional(), + mode: this.makeStringModeSchema().optional(), + array_contains: valueSchema.optional(), + array_starts_with: valueSchema.optional(), + array_ends_with: valueSchema.optional(), + }); + } + + @cache() + private makeDateTimeFilterSchema( + optional: boolean, + withAggregations: boolean, + allowedFilterKinds: string[] | undefined, + ): ZodType { + return this.makeCommonPrimitiveFilterSchema( + z.union([z.iso.datetime(), z.date()]), + optional, + () => z.lazy(() => this.makeDateTimeFilterSchema(optional, withAggregations, allowedFilterKinds)), + withAggregations ? ['_count', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + } + + @cache() + private makeBooleanFilterSchema( + optional: boolean, + withAggregations: boolean, + allowedFilterKinds: string[] | undefined, + ): ZodType { + const components = this.makeCommonPrimitiveFilterComponents( + z.boolean(), + optional, + () => z.lazy(() => this.makeBooleanFilterSchema(optional, withAggregations, allowedFilterKinds)), + ['equals', 'not'], + withAggregations ? ['_count', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + + return this.createUnionFilterSchema(z.boolean(), optional, components, allowedFilterKinds); + } + + @cache() + private makeBytesFilterSchema( + optional: boolean, + withAggregations: boolean, + allowedFilterKinds: string[] | undefined, + ): ZodType { + const baseSchema = z.instanceof(Uint8Array); + const components = this.makeCommonPrimitiveFilterComponents( + baseSchema, + optional, + () => z.instanceof(Uint8Array), + ['equals', 'in', 'notIn', 'not'], + withAggregations ? ['_count', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + + return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); + } + + private makeCommonPrimitiveFilterComponents( + baseSchema: ZodType, + optional: boolean, + makeThis: () => ZodType, + supportedOperators: string[] | undefined = undefined, + withAggregations: Array<'_count' | '_avg' | '_sum' | '_min' | '_max'> | undefined = undefined, + allowedFilterKinds: string[] | undefined = undefined, + ) { + const commonAggSchema = () => + this.makeCommonPrimitiveFilterSchema(baseSchema, false, makeThis, undefined, allowedFilterKinds).optional(); + let result = { + equals: this.nullableIf(baseSchema.optional(), optional), + in: baseSchema.array().optional(), + notIn: baseSchema.array().optional(), + lt: baseSchema.optional(), + lte: baseSchema.optional(), + gt: baseSchema.optional(), + gte: baseSchema.optional(), + between: baseSchema.array().length(2).optional(), + not: makeThis().optional(), + ...(withAggregations?.includes('_count') + ? { _count: this.makeNumberFilterSchema(z.number().int(), false, false, undefined).optional() } + : {}), + ...(withAggregations?.includes('_avg') ? { _avg: commonAggSchema() } : {}), + ...(withAggregations?.includes('_sum') ? { _sum: commonAggSchema() } : {}), + ...(withAggregations?.includes('_min') ? { _min: commonAggSchema() } : {}), + ...(withAggregations?.includes('_max') ? { _max: commonAggSchema() } : {}), + }; + if (supportedOperators) { + const keys = [...supportedOperators, ...(withAggregations ?? [])]; + result = extractFields(result, keys) as typeof result; + } + + // Filter operators based on allowed filter kinds + result = this.trimFilterOperators(result, allowedFilterKinds) as typeof result; + + return result; + } + + private makeCommonPrimitiveFilterSchema( + baseSchema: ZodType, + optional: boolean, + makeThis: () => ZodType, + withAggregations: Array | undefined = undefined, + allowedFilterKinds: string[] | undefined = undefined, + ): ZodType { + const components = this.makeCommonPrimitiveFilterComponents( + baseSchema, + optional, + makeThis, + undefined, + withAggregations, + allowedFilterKinds, + ); + + return this.createUnionFilterSchema(baseSchema, optional, components, allowedFilterKinds); + } + + private makeNumberFilterSchema( + baseSchema: ZodType, + optional: boolean, + withAggregations: boolean, + allowedFilterKinds: string[] | undefined, + ): ZodType { + return this.makeCommonPrimitiveFilterSchema( + baseSchema, + optional, + () => z.lazy(() => this.makeNumberFilterSchema(baseSchema, optional, withAggregations, allowedFilterKinds)), + withAggregations ? ['_count', '_avg', '_sum', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + } + + private makeStringFilterSchema( + optional: boolean, + withAggregations: boolean, + allowedFilterKinds: string[] | undefined, + ): ZodType { + const baseComponents = this.makeCommonPrimitiveFilterComponents( + z.string(), + optional, + () => z.lazy(() => this.makeStringFilterSchema(optional, withAggregations, allowedFilterKinds)), + undefined, + withAggregations ? ['_count', '_min', '_max'] : undefined, + allowedFilterKinds, + ); + + const stringSpecificOperators = { + startsWith: z.string().optional(), + endsWith: z.string().optional(), + contains: z.string().optional(), + ...(this.providerSupportsCaseSensitivity + ? { + mode: this.makeStringModeSchema().optional(), + } + : {}), + }; + + // Filter string-specific operators based on allowed filter kinds + const filteredStringOperators = this.trimFilterOperators(stringSpecificOperators, allowedFilterKinds); + + const allComponents = { + ...baseComponents, + ...filteredStringOperators, + }; + + return this.createUnionFilterSchema(z.string(), optional, allComponents, allowedFilterKinds); + } + + private makeStringModeSchema() { + return z.union([z.literal('default'), z.literal('insensitive')]); + } + + @cache() + private makeSelectSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const fields: Record = {}; + for (const field of Object.keys(modelDef.fields)) { + const fieldDef = requireField(this.schema, model, field); + if (fieldDef.relation) { + // Check if the target model is allowed by slicing configuration + if (this.isModelAllowed(fieldDef.type)) { + fields[field] = this.makeRelationSelectIncludeSchema(model, field).optional(); + } + } else { + fields[field] = z.boolean().optional(); + } + } + + const _countSchema = this.makeCountSelectionSchema(model); + if (!(_countSchema instanceof z.ZodNever)) { + fields['_count'] = _countSchema; + } + + return z.strictObject(fields); + } + + @cache() + private makeCountSelectionSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const toManyRelations = Object.values(modelDef.fields).filter((def) => def.relation && def.array); + if (toManyRelations.length > 0) { + return z + .union([ + z.literal(true), + z.strictObject({ + select: z.strictObject( + toManyRelations.reduce( + (acc, fieldDef) => ({ + ...acc, + [fieldDef.name]: z + .union([ + z.boolean(), + z.strictObject({ + where: this.makeWhereSchema(fieldDef.type, false, false), + }), + ]) + .optional(), + }), + {} as Record, + ), + ), + }), + ]) + .optional(); + } else { + return z.never(); + } + } + + @cache() + private makeRelationSelectIncludeSchema(model: string, field: string) { + const fieldDef = requireField(this.schema, model, field); + let objSchema: ZodType = z.strictObject({ + ...(fieldDef.array || fieldDef.optional + ? { + // to-many relations and optional to-one relations are filterable + where: z.lazy(() => this.makeWhereSchema(fieldDef.type, false)).optional(), + } + : {}), + select: z + .lazy(() => this.makeSelectSchema(fieldDef.type)) + .optional() + .nullable(), + include: z + .lazy(() => this.makeIncludeSchema(fieldDef.type)) + .optional() + .nullable(), + omit: z + .lazy(() => this.makeOmitSchema(fieldDef.type)) + .optional() + .nullable(), + ...(fieldDef.array + ? { + // to-many relations can be ordered, skipped, taken, and cursor-located + orderBy: z + .lazy(() => this.orArray(this.makeOrderBySchema(fieldDef.type, true, false), true)) + .optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + cursor: this.makeCursorSchema(fieldDef.type).optional(), + distinct: this.makeDistinctSchema(fieldDef.type).optional(), + } + : {}), + }); + + objSchema = this.refineForSelectIncludeMutuallyExclusive(objSchema); + objSchema = this.refineForSelectOmitMutuallyExclusive(objSchema); + + return z.union([z.boolean(), objSchema]); + } + + @cache() + private makeOmitSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const fields: Record = {}; + for (const field of Object.keys(modelDef.fields)) { + const fieldDef = requireField(this.schema, model, field); + if (!fieldDef.relation) { + if (this.options.allowQueryTimeOmitOverride !== false) { + // if override is allowed, use boolean + fields[field] = z.boolean().optional(); + } else { + // otherwise only allow true + fields[field] = z.literal(true).optional(); + } + } + } + return z.strictObject(fields); + } + + @cache() + private makeIncludeSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const fields: Record = {}; + for (const field of Object.keys(modelDef.fields)) { + const fieldDef = requireField(this.schema, model, field); + if (fieldDef.relation) { + // Check if the target model is allowed by slicing configuration + if (this.isModelAllowed(fieldDef.type)) { + fields[field] = this.makeRelationSelectIncludeSchema(model, field).optional(); + } + } + } + + const _countSchema = this.makeCountSelectionSchema(model); + if (!(_countSchema instanceof z.ZodNever)) { + fields['_count'] = _countSchema; + } + + return z.strictObject(fields); + } + + @cache() + private makeOrderBySchema(model: string, withRelation: boolean, WithAggregation: boolean) { + const modelDef = requireModel(this.schema, model); + const fields: Record = {}; + const sort = z.union([z.literal('asc'), z.literal('desc')]); + for (const field of Object.keys(modelDef.fields)) { + const fieldDef = requireField(this.schema, model, field); + if (fieldDef.relation) { + // relations + if (withRelation) { + fields[field] = z.lazy(() => { + let relationOrderBy = this.makeOrderBySchema(fieldDef.type, withRelation, WithAggregation); + if (fieldDef.array) { + relationOrderBy = relationOrderBy.extend({ + _count: sort, + }); + } + return relationOrderBy.optional(); + }); + } + } else { + // scalars + if (fieldDef.optional) { + fields[field] = z + .union([ + sort, + z.strictObject({ + sort, + nulls: z.union([z.literal('first'), z.literal('last')]), + }), + ]) + .optional(); + } else { + fields[field] = sort.optional(); + } + } + } + + // aggregations + if (WithAggregation) { + const aggregationFields = ['_count', '_avg', '_sum', '_min', '_max']; + for (const agg of aggregationFields) { + fields[agg] = z.lazy(() => this.makeOrderBySchema(model, true, false).optional()); + } + } + + return z.strictObject(fields); + } + + @cache() + private makeDistinctSchema(model: string) { + const modelDef = requireModel(this.schema, model); + const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); + return nonRelationFields.length > 0 ? this.orArray(z.enum(nonRelationFields as any), true) : z.never(); + } + + private makeCursorSchema(model: string) { + // `makeWhereSchema` is already cached + return this.makeWhereSchema(model, true, true).optional(); + } + + // #endregion + + // #region Create + + @cache() + makeCreateSchema>( + model: Model, + ): ZodType> { + const dataSchema = this.makeCreateDataSchema(model, false); + const baseSchema = z.strictObject({ + data: dataSchema, + select: this.makeSelectSchema(model).optional().nullable(), + include: this.makeIncludeSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'create'); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema as ZodType>; + } + + @cache() + makeCreateManySchema>( + model: Model, + ): ZodType> { + return this.mergePluginArgsSchema( + this.makeCreateManyPayloadSchema(model, []), + 'createMany', + ) as unknown as ZodType>; + } + + @cache() + makeCreateManyAndReturnSchema>( + model: Model, + ): ZodType> { + const base = this.makeCreateManyPayloadSchema(model, []); + let result: ZodObject = base.extend({ + select: this.makeSelectSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + result = this.mergePluginArgsSchema(result, 'createManyAndReturn'); + return this.refineForSelectOmitMutuallyExclusive(result).optional() as ZodType< + CreateManyAndReturnArgs + >; + } + + @cache() + private makeCreateDataSchema( + model: string, + canBeArray: boolean, + withoutFields: string[] = [], + withoutRelationFields = false, + ) { + const uncheckedVariantFields: Record = {}; + const checkedVariantFields: Record = {}; + const modelDef = requireModel(this.schema, model); + const hasRelation = + !withoutRelationFields && + Object.entries(modelDef.fields).some(([f, def]) => !withoutFields.includes(f) && def.relation); + + Object.keys(modelDef.fields).forEach((field) => { + if (withoutFields.includes(field)) { + return; + } + const fieldDef = requireField(this.schema, model, field); + if (fieldDef.computed) { + return; + } + + if (this.isDelegateDiscriminator(fieldDef)) { + // discriminator field is auto-assigned + return; + } + + if (fieldDef.relation) { + if (withoutRelationFields) { + return; + } + // Check if the target model is allowed by slicing configuration + if (!this.isModelAllowed(fieldDef.type)) { + return; + } + const excludeFields: string[] = []; + const oppositeField = fieldDef.relation.opposite; + if (oppositeField) { + excludeFields.push(oppositeField); + const oppositeFieldDef = requireField(this.schema, fieldDef.type, oppositeField); + if (oppositeFieldDef.relation?.fields) { + excludeFields.push(...oppositeFieldDef.relation.fields); + } + } + + let fieldSchema: ZodType = z.lazy(() => + this.makeRelationManipulationSchema(model, field, excludeFields, 'create'), + ); + + if (fieldDef.optional || fieldDef.array) { + // optional or array relations are optional + fieldSchema = fieldSchema.optional(); + } else { + // if all fk fields are optional, the relation is optional + let allFksOptional = false; + if (fieldDef.relation.fields) { + allFksOptional = fieldDef.relation.fields.every((f) => { + const fkDef = requireField(this.schema, model, f); + return fkDef.optional || fieldHasDefaultValue(fkDef); + }); + } + if (allFksOptional) { + fieldSchema = fieldSchema.optional(); + } + } + + // optional to-one relation can be null + if (fieldDef.optional && !fieldDef.array) { + fieldSchema = fieldSchema.nullable(); + } + checkedVariantFields[field] = fieldSchema; + if (fieldDef.array || !fieldDef.relation.references) { + // non-owned relation + uncheckedVariantFields[field] = fieldSchema; + } + } else { + let fieldSchema = this.makeScalarSchema(fieldDef.type, fieldDef.attributes); + + if (fieldDef.array) { + fieldSchema = ZodUtils.addListValidation(fieldSchema.array(), fieldDef.attributes); + fieldSchema = z + .union([ + fieldSchema, + z.strictObject({ + set: fieldSchema, + }), + ]) + .optional(); + } + + if (fieldDef.optional || fieldHasDefaultValue(fieldDef)) { + fieldSchema = fieldSchema.optional(); + } + + if (fieldDef.optional) { + if (fieldDef.type === 'Json') { + // DbNull for Json fields + fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]); + } else { + fieldSchema = fieldSchema.nullable(); + } + } + + uncheckedVariantFields[field] = fieldSchema; + if (!fieldDef.foreignKeyFor) { + // non-fk field + checkedVariantFields[field] = fieldSchema; + } + } + }); + + const uncheckedCreateSchema = this.extraValidationsEnabled + ? ZodUtils.addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes) + : z.strictObject(uncheckedVariantFields); + const checkedCreateSchema = this.extraValidationsEnabled + ? ZodUtils.addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes) + : z.strictObject(checkedVariantFields); + + if (!hasRelation) { + return this.orArray(uncheckedCreateSchema, canBeArray); + } else { + return z.union([ + uncheckedCreateSchema, + checkedCreateSchema, + ...(canBeArray ? [z.array(uncheckedCreateSchema)] : []), + ...(canBeArray ? [z.array(checkedCreateSchema)] : []), + ]); + } + } + + private isDelegateDiscriminator(fieldDef: FieldDef) { + if (!fieldDef.originModel) { + // not inherited from a delegate + return false; + } + const discriminatorField = getDiscriminatorField(this.schema, fieldDef.originModel); + return discriminatorField === fieldDef.name; + } + + @cache() + private makeRelationManipulationSchema( + model: string, + field: string, + withoutFields: string[], + mode: 'create' | 'update', + ) { + const fieldDef = requireField(this.schema, model, field); + const fieldType = fieldDef.type; + const array = !!fieldDef.array; + const fields: Record = { + create: this.makeCreateDataSchema(fieldDef.type, !!fieldDef.array, withoutFields).optional(), + + connect: this.makeConnectDataSchema(fieldType, array).optional(), + + connectOrCreate: this.makeConnectOrCreateDataSchema(fieldType, array, withoutFields).optional(), + }; + + if (array) { + fields['createMany'] = this.makeCreateManyPayloadSchema(fieldType, withoutFields).optional(); + } + + if (mode === 'update') { + if (fieldDef.optional || fieldDef.array) { + // disconnect and delete are only available for optional/to-many relations + fields['disconnect'] = this.makeDisconnectDataSchema(fieldType, array).optional(); + + fields['delete'] = this.makeDeleteRelationDataSchema(fieldType, array, true).optional(); + } + + fields['update'] = array + ? this.orArray( + z.strictObject({ + where: this.makeWhereSchema(fieldType, true), + data: this.makeUpdateDataSchema(fieldType, withoutFields), + }), + true, + ).optional() + : z + .union([ + z.strictObject({ + where: this.makeWhereSchema(fieldType, false).optional(), + data: this.makeUpdateDataSchema(fieldType, withoutFields), + }), + this.makeUpdateDataSchema(fieldType, withoutFields), + ]) + .optional(); + + let upsertWhere = this.makeWhereSchema(fieldType, true); + if (!fieldDef.array) { + // to-one relation, can upsert without where clause + upsertWhere = upsertWhere.optional(); + } + fields['upsert'] = this.orArray( + z.strictObject({ + where: upsertWhere, + create: this.makeCreateDataSchema(fieldType, false, withoutFields), + update: this.makeUpdateDataSchema(fieldType, withoutFields), + }), + true, + ).optional(); + + if (array) { + // to-many relation specifics + fields['set'] = this.makeSetDataSchema(fieldType, true).optional(); + + fields['updateMany'] = this.orArray( + z.strictObject({ + where: this.makeWhereSchema(fieldType, false, true), + data: this.makeUpdateDataSchema(fieldType, withoutFields), + }), + true, + ).optional(); + + fields['deleteMany'] = this.makeDeleteRelationDataSchema(fieldType, true, false).optional(); + } + } + + return z.strictObject(fields); + } + + @cache() + private makeSetDataSchema(model: string, canBeArray: boolean) { + return this.orArray(this.makeWhereSchema(model, true), canBeArray); + } + + @cache() + private makeConnectDataSchema(model: string, canBeArray: boolean) { + return this.orArray(this.makeWhereSchema(model, true), canBeArray); + } + + @cache() + private makeDisconnectDataSchema(model: string, canBeArray: boolean) { + if (canBeArray) { + // to-many relation, must be unique filters + return this.orArray(this.makeWhereSchema(model, true), canBeArray); + } else { + // to-one relation, can be boolean or a regular filter - the entity + // being disconnected is already uniquely identified by its parent + return z.union([z.boolean(), this.makeWhereSchema(model, false)]); + } + } + + @cache() + private makeDeleteRelationDataSchema(model: string, toManyRelation: boolean, uniqueFilter: boolean) { + return toManyRelation + ? this.orArray(this.makeWhereSchema(model, uniqueFilter), true) + : z.union([z.boolean(), this.makeWhereSchema(model, uniqueFilter)]); + } + + @cache() + private makeConnectOrCreateDataSchema(model: string, canBeArray: boolean, withoutFields: string[]) { + const whereSchema = this.makeWhereSchema(model, true); + const createSchema = this.makeCreateDataSchema(model, false, withoutFields); + return this.orArray( + z.strictObject({ + where: whereSchema, + create: createSchema, + }), + canBeArray, + ); + } + + @cache() + private makeCreateManyPayloadSchema(model: string, withoutFields: string[]) { + return z.strictObject({ + data: this.makeCreateDataSchema(model, true, withoutFields, true), + skipDuplicates: z.boolean().optional(), + }); + } + + // #endregion + + // #region Update + + @cache() + makeUpdateSchema>( + model: Model, + ): ZodType> { + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, true), + data: this.makeUpdateDataSchema(model), + select: this.makeSelectSchema(model).optional().nullable(), + include: this.makeIncludeSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'update'); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema as ZodType>; + } + + @cache() + makeUpdateManySchema>( + model: Model, + ): ZodType> { + return this.mergePluginArgsSchema( + z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + data: this.makeUpdateDataSchema(model, [], true), + limit: z.number().int().nonnegative().optional(), + }), + 'updateMany', + ) as unknown as ZodType>; + } + + @cache() + makeUpdateManyAndReturnSchema>( + model: Model, + ): ZodType> { + // plugin extended args schema is merged in `makeUpdateManySchema` + const baseSchema = this.makeUpdateManySchema(model) as unknown as ZodObject; + let schema: ZodType = baseSchema.extend({ + select: this.makeSelectSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema as ZodType>; + } + + @cache() + makeUpsertSchema>( + model: Model, + ): ZodType> { + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, true), + create: this.makeCreateDataSchema(model, false), + update: this.makeUpdateDataSchema(model), + select: this.makeSelectSchema(model).optional().nullable(), + include: this.makeIncludeSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'upsert'); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema as ZodType>; + } + + @cache() + private makeUpdateDataSchema(model: string, withoutFields: string[] = [], withoutRelationFields = false) { + const uncheckedVariantFields: Record = {}; + const checkedVariantFields: Record = {}; + const modelDef = requireModel(this.schema, model); + const hasRelation = Object.entries(modelDef.fields).some( + ([key, value]) => value.relation && !withoutFields.includes(key), + ); + + Object.keys(modelDef.fields).forEach((field) => { + if (withoutFields.includes(field)) { + return; + } + const fieldDef = requireField(this.schema, model, field); + + if (fieldDef.relation) { + if (withoutRelationFields) { + return; + } + // Check if the target model is allowed by slicing configuration + if (!this.isModelAllowed(fieldDef.type)) { + return; + } + const excludeFields: string[] = []; + const oppositeField = fieldDef.relation.opposite; + if (oppositeField) { + excludeFields.push(oppositeField); + const oppositeFieldDef = requireField(this.schema, fieldDef.type, oppositeField); + if (oppositeFieldDef.relation?.fields) { + excludeFields.push(...oppositeFieldDef.relation.fields); + } + } + let fieldSchema: ZodType = z + .lazy(() => this.makeRelationManipulationSchema(model, field, excludeFields, 'update')) + .optional(); + // optional to-one relation can be null + if (fieldDef.optional && !fieldDef.array) { + fieldSchema = fieldSchema.nullable(); + } + checkedVariantFields[field] = fieldSchema; + if (fieldDef.array || !fieldDef.relation.references) { + // non-owned relation + uncheckedVariantFields[field] = fieldSchema; + } + } else { + let fieldSchema = this.makeScalarSchema(fieldDef.type, fieldDef.attributes); + + if (this.isNumericField(fieldDef)) { + fieldSchema = z.union([ + fieldSchema, + z + .object({ + // TODO: use Decimal/BigInt for incremental updates + set: this.nullableIf(z.number().optional(), !!fieldDef.optional).optional(), + increment: z.number().optional(), + decrement: z.number().optional(), + multiply: z.number().optional(), + divide: z.number().optional(), + }) + .refine( + (v) => Object.keys(v).length === 1, + 'Only one of "set", "increment", "decrement", "multiply", or "divide" can be provided', + ), + ]); + } + + if (fieldDef.array) { + const arraySchema = ZodUtils.addListValidation(fieldSchema.array(), fieldDef.attributes); + fieldSchema = z.union([ + arraySchema, + z + .object({ + set: arraySchema.optional(), + push: z.union([fieldSchema, fieldSchema.array()]).optional(), + }) + .refine((v) => Object.keys(v).length === 1, 'Only one of "set", "push" can be provided'), + ]); + } + + if (fieldDef.optional) { + if (fieldDef.type === 'Json') { + // DbNull for Json fields + fieldSchema = z.union([fieldSchema, z.instanceof(DbNullClass)]); + } else { + fieldSchema = fieldSchema.nullable(); + } + } + + // all fields are optional in update + fieldSchema = fieldSchema.optional(); + + uncheckedVariantFields[field] = fieldSchema; + if (!fieldDef.foreignKeyFor) { + // non-fk field + checkedVariantFields[field] = fieldSchema; + } + } + }); + + const uncheckedUpdateSchema = this.extraValidationsEnabled + ? ZodUtils.addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes) + : z.strictObject(uncheckedVariantFields); + const checkedUpdateSchema = this.extraValidationsEnabled + ? ZodUtils.addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes) + : z.strictObject(checkedVariantFields); + if (!hasRelation) { + return uncheckedUpdateSchema; + } else { + return z.union([uncheckedUpdateSchema, checkedUpdateSchema]); + } + } + + // #endregion + + // #region Delete + + @cache() + makeDeleteSchema>( + model: Model, + ): ZodType> { + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, true), + select: this.makeSelectSchema(model).optional().nullable(), + include: this.makeIncludeSchema(model).optional().nullable(), + omit: this.makeOmitSchema(model).optional().nullable(), + }); + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'delete'); + schema = this.refineForSelectIncludeMutuallyExclusive(schema); + schema = this.refineForSelectOmitMutuallyExclusive(schema); + return schema as ZodType>; + } + + @cache() + makeDeleteManySchema>( + model: Model, + ): ZodType | undefined> { + return this.mergePluginArgsSchema( + z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + limit: z.number().int().nonnegative().optional(), + }), + 'deleteMany', + ).optional() as unknown as ZodType | undefined>; + } + + // #endregion + + // #region Count + + @cache() + makeCountSchema>( + model: Model, + ): ZodType | undefined> { + return this.mergePluginArgsSchema( + z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), + select: this.makeCountAggregateInputSchema(model).optional(), + }), + 'count', + ).optional() as ZodType | undefined>; + } + + @cache() + private makeCountAggregateInputSchema(model: string) { + const modelDef = requireModel(this.schema, model); + return z.union([ + z.literal(true), + z.strictObject({ + _all: z.literal(true).optional(), + ...Object.keys(modelDef.fields).reduce( + (acc, field) => { + acc[field] = z.literal(true).optional(); + return acc; + }, + {} as Record, + ), + }), + ]); + } + + // #endregion + + // #region Aggregate + + @cache() + makeAggregateSchema>( + model: Model, + ): ZodType | undefined> { + return this.mergePluginArgsSchema( + z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + orderBy: this.orArray(this.makeOrderBySchema(model, true, false), true).optional(), + _count: this.makeCountAggregateInputSchema(model).optional(), + _avg: this.makeSumAvgInputSchema(model).optional(), + _sum: this.makeSumAvgInputSchema(model).optional(), + _min: this.makeMinMaxInputSchema(model).optional(), + _max: this.makeMinMaxInputSchema(model).optional(), + }), + 'aggregate', + ).optional() as ZodType | undefined>; + } + + @cache() + private makeSumAvgInputSchema(model: string) { + const modelDef = requireModel(this.schema, model); + return z.strictObject( + Object.keys(modelDef.fields).reduce( + (acc, field) => { + const fieldDef = requireField(this.schema, model, field); + if (this.isNumericField(fieldDef)) { + acc[field] = z.literal(true).optional(); + } + return acc; + }, + {} as Record, + ), + ); + } + + @cache() + private makeMinMaxInputSchema(model: string) { + const modelDef = requireModel(this.schema, model); + return z.strictObject( + Object.keys(modelDef.fields).reduce( + (acc, field) => { + const fieldDef = requireField(this.schema, model, field); + if (!fieldDef.relation && !fieldDef.array) { + acc[field] = z.literal(true).optional(); + } + return acc; + }, + {} as Record, + ), + ); + } + + // #endregion + + // #region Group By + + @cache() + makeGroupBySchema>( + model: Model, + ): ZodType> { + const modelDef = requireModel(this.schema, model); + const nonRelationFields = Object.keys(modelDef.fields).filter((field) => !modelDef.fields[field]?.relation); + const bySchema = + nonRelationFields.length > 0 + ? this.orArray(z.enum(nonRelationFields as [string, ...string[]]), true) + : z.never(); + + const baseSchema = z.strictObject({ + where: this.makeWhereSchema(model, false).optional(), + orderBy: this.orArray(this.makeOrderBySchema(model, false, true), true).optional(), + by: bySchema, + having: this.makeHavingSchema(model).optional(), + skip: this.makeSkipSchema().optional(), + take: this.makeTakeSchema().optional(), + _count: this.makeCountAggregateInputSchema(model).optional(), + _avg: this.makeSumAvgInputSchema(model).optional(), + _sum: this.makeSumAvgInputSchema(model).optional(), + _min: this.makeMinMaxInputSchema(model).optional(), + _max: this.makeMinMaxInputSchema(model).optional(), + }); + + let schema: ZodType = this.mergePluginArgsSchema(baseSchema, 'groupBy'); + + // fields used in `having` must be either in the `by` list, or aggregations + schema = schema.refine((value: any) => { + const bys = enumerate(value.by); + if (value.having && typeof value.having === 'object') { + for (const [key, val] of Object.entries(value.having)) { + if (AggregateOperators.includes(key as any)) { + continue; + } + if (bys.includes(key)) { + continue; + } + // we have a key not mentioned in `by`, in this case it must only use + // aggregations in the condition + + // 1. payload must be an object + if (!val || typeof val !== 'object') { + return false; + } + // 2. payload must only contain aggregations + if (!this.onlyAggregationFields(val)) { + return false; + } + } + } + return true; + }, 'fields in "having" must be in "by"'); + + // fields used in `orderBy` must be either in the `by` list, or aggregations + schema = schema.refine((value: any) => { + const bys = enumerate(value.by); + for (const orderBy of enumerate(value.orderBy)) { + if ( + orderBy && + Object.keys(orderBy) + .filter((f) => !AggregateOperators.includes(f as AggregateOperators)) + .some((key) => !bys.includes(key)) + ) { + return false; + } + } + return true; + }, 'fields in "orderBy" must be in "by"'); + + return schema as ZodType>; + } + + private onlyAggregationFields(val: object) { + for (const [key, value] of Object.entries(val)) { + if (AggregateOperators.includes(key as any)) { + // aggregation field + continue; + } + if (LOGICAL_COMBINATORS.includes(key as any)) { + // logical operators + if (enumerate(value).every((v) => this.onlyAggregationFields(v))) { + continue; + } + } + return false; + } + return true; + } + + private makeHavingSchema(model: string) { + // `makeWhereSchema` is cached + return this.makeWhereSchema(model, false, true, true); + } + + // #endregion + + // #region Procedures + + @cache() + makeProcedureParamSchema(param: { type: string; array?: boolean; optional?: boolean }): ZodType { + let schema: ZodType; + + if (isTypeDef(this.schema, param.type)) { + schema = this.makeTypeDefSchema(param.type); + } else if (isEnum(this.schema, param.type)) { + schema = this.makeEnumSchema(param.type); + } else if (param.type in (this.schema.models ?? {})) { + // For model-typed values, accept any object (no deep shape validation). + schema = z.record(z.string(), z.unknown()); + } else { + // Builtin scalar types. + schema = this.makeScalarSchema(param.type as BuiltinType); + + // If a type isn't recognized by any of the above branches, `makeScalarSchema` returns `unknown`. + // Treat it as configuration/schema error. + if (schema instanceof z.ZodUnknown) { + throw createInternalError(`Unsupported procedure parameter type: ${param.type}`); + } + } + + if (param.array) { + schema = schema.array(); + } + if (param.optional) { + schema = schema.optional(); + } + + return schema; + } + + // #endregion + + // #region Plugin Args + + private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) { + let result = schema; + for (const plugin of this.plugins ?? []) { + if (plugin.queryArgs) { + const pluginSchema = this.getPluginExtQueryArgsSchema(plugin, operation); + if (pluginSchema) { + result = result.extend(pluginSchema.shape); + } + } + } + return result.strict(); + } + + private getPluginExtQueryArgsSchema(plugin: AnyPlugin, operation: string): ZodObject | undefined { + if (!plugin.queryArgs) { + return undefined; + } + + let result: ZodType | undefined; + + if (operation in plugin.queryArgs && plugin.queryArgs[operation]) { + // most specific operation takes highest precedence + result = plugin.queryArgs[operation]; + } else if (operation === 'upsert') { + // upsert is special: it's in both CoreCreateOperations and CoreUpdateOperations + // so we need to merge both $create and $update schemas to match the type system + const createSchema = + '$create' in plugin.queryArgs && plugin.queryArgs['$create'] ? plugin.queryArgs['$create'] : undefined; + const updateSchema = + '$update' in plugin.queryArgs && plugin.queryArgs['$update'] ? plugin.queryArgs['$update'] : undefined; + + if (createSchema && updateSchema) { + invariant(createSchema instanceof ZodObject, 'Plugin extended query args schema must be a Zod object'); + invariant(updateSchema instanceof ZodObject, 'Plugin extended query args schema must be a Zod object'); + // merge both schemas (combines their properties) + result = createSchema.extend(updateSchema.shape); + } else if (createSchema) { + result = createSchema; + } else if (updateSchema) { + result = updateSchema; + } + } else if ( + // then comes grouped operations: $create, $read, $update, $delete + CoreCreateOperations.includes(operation as CoreCreateOperations) && + '$create' in plugin.queryArgs && + plugin.queryArgs['$create'] + ) { + result = plugin.queryArgs['$create']; + } else if ( + CoreReadOperations.includes(operation as CoreReadOperations) && + '$read' in plugin.queryArgs && + plugin.queryArgs['$read'] + ) { + result = plugin.queryArgs['$read']; + } else if ( + CoreUpdateOperations.includes(operation as CoreUpdateOperations) && + '$update' in plugin.queryArgs && + plugin.queryArgs['$update'] + ) { + result = plugin.queryArgs['$update']; + } else if ( + CoreDeleteOperations.includes(operation as CoreDeleteOperations) && + '$delete' in plugin.queryArgs && + plugin.queryArgs['$delete'] + ) { + result = plugin.queryArgs['$delete']; + } else if ('$all' in plugin.queryArgs && plugin.queryArgs['$all']) { + // finally comes $all + result = plugin.queryArgs['$all']; + } + + invariant( + result === undefined || result instanceof ZodObject, + 'Plugin extended query args schema must be a Zod object', + ); + return result; + } + + // #endregion + + // #region Helpers + + @cache() + private makeSkipSchema() { + return z.number().int().nonnegative(); + } + + @cache() + private makeTakeSchema() { + return z.number().int(); + } + + private refineForSelectIncludeMutuallyExclusive(schema: ZodType) { + return schema.refine( + (value: any) => !(value['select'] && value['include']), + '"select" and "include" cannot be used together', + ); + } + + private refineForSelectOmitMutuallyExclusive(schema: ZodType) { + return schema.refine( + (value: any) => !(value['select'] && value['omit']), + '"select" and "omit" cannot be used together', + ); + } + + private nullableIf(schema: ZodType, nullable: boolean) { + return nullable ? schema.nullable() : schema; + } + + private orArray(schema: T, canBeArray: boolean) { + return canBeArray ? z.union([schema, z.array(schema)]) : schema; + } + + private isNumericField(fieldDef: FieldDef) { + return NUMERIC_FIELD_TYPES.includes(fieldDef.type) && !fieldDef.array; + } + + private get providerSupportsCaseSensitivity() { + return this.schema.provider.type === 'postgresql'; + } + + /** + * Gets the effective set of allowed FilterKind values for a specific model and field. + * Respects the precedence: model[field] > model.$all > $all[field] > $all.$all. + */ + private getEffectiveFilterKinds(model: string | undefined, field: string): string[] | undefined { + if (!model) { + // no restrictions + return undefined; + } + + const slicing = this.options.slicing; + if (!slicing?.models) { + // no slicing or no model-specific slicing, no restrictions + return undefined; + } + + // A string-indexed view of slicing.models that avoids unsafe 'as any' while still + // allowing runtime access by model name. The value shape matches FieldSlicingOptions. + type FieldConfig = { includedFilterKinds?: readonly string[]; excludedFilterKinds?: readonly string[] }; + type FieldsRecord = { $all?: FieldConfig } & Record; + type ModelConfig = { fields?: FieldsRecord }; + const modelsRecord = slicing.models as Record; + + // Check field-level settings for the specific model + const modelConfig = modelsRecord[lowerCaseFirst(model)]; + if (modelConfig?.fields) { + const fieldConfig = modelConfig.fields[field]; + if (fieldConfig) { + return this.computeFilterKinds(fieldConfig.includedFilterKinds, fieldConfig.excludedFilterKinds); + } + + // Fallback to field-level $all for the specific model + const allFieldsConfig = modelConfig.fields['$all']; + if (allFieldsConfig) { + return this.computeFilterKinds( + allFieldsConfig.includedFilterKinds, + allFieldsConfig.excludedFilterKinds, + ); + } + } + + // Fallback to model-level $all + const allModelsConfig = modelsRecord['$all']; + if (allModelsConfig?.fields) { + // Check specific field in $all model config before falling back to $all.$all + const allModelsFieldConfig = allModelsConfig.fields[field]; + if (allModelsFieldConfig) { + return this.computeFilterKinds( + allModelsFieldConfig.includedFilterKinds, + allModelsFieldConfig.excludedFilterKinds, + ); + } + + // Fallback to $all.$all + const allModelsAllFieldsConfig = allModelsConfig.fields['$all']; + if (allModelsAllFieldsConfig) { + return this.computeFilterKinds( + allModelsAllFieldsConfig.includedFilterKinds, + allModelsAllFieldsConfig.excludedFilterKinds, + ); + } + } + + return undefined; // No restrictions + } + + /** + * Computes the effective set of filter kinds based on inclusion and exclusion lists. + */ + private computeFilterKinds(included: readonly string[] | undefined, excluded: readonly string[] | undefined) { + let result: string[] | undefined; + + if (included !== undefined) { + // Start with the included set + result = [...included]; + } + + if (excluded !== undefined) { + if (!result) { + // If no inclusion list, start with all filter kinds + result = [...this.allFilterKinds]; + } + // Remove excluded kinds + for (const kind of excluded) { + result = result.filter((k) => k !== kind); + } + } + + return result; + } + + /** + * Filters operators based on allowed filter kinds. + */ + private trimFilterOperators>( + operators: T, + allowedKinds: string[] | undefined, + ): Partial { + if (!allowedKinds) { + return operators; // No restrictions + } + + return Object.fromEntries( + Object.entries(operators).filter(([key, _]) => { + return ( + !(key in FILTER_PROPERTY_TO_KIND) || + allowedKinds.includes(FILTER_PROPERTY_TO_KIND[key as keyof typeof FILTER_PROPERTY_TO_KIND]) + ); + }), + ) as Partial; + } + + private createUnionFilterSchema( + valueSchema: ZodType, + optional: boolean, + components: Record, + allowedFilterKinds: string[] | undefined, + ) { + // If all filter operators are excluded + if (Object.keys(components).length === 0) { + // if equality filters are allowed, allow direct value + if (!allowedFilterKinds || allowedFilterKinds.includes('Equality')) { + return this.nullableIf(valueSchema, optional); + } + // otherwise nothing is allowed + return z.never(); + } + + if (!allowedFilterKinds || allowedFilterKinds.includes('Equality')) { + // direct value or filter operators + return z.union([this.nullableIf(valueSchema, optional), z.strictObject(components)]); + } else { + // filter operators + return z.strictObject(components); + } + } + + /** + * Checks if a model is included in the slicing configuration. + * Returns true if the model is allowed, false if it's excluded. + */ + private isModelAllowed(targetModel: string): boolean { + const slicing = this.options.slicing; + if (!slicing) { + return true; // No slicing, all models allowed + } + + const { includedModels, excludedModels } = slicing; + + // If includedModels is specified, only those models are allowed + if (includedModels !== undefined) { + if (!includedModels.includes(targetModel as any)) { + return false; + } + } + + // If excludedModels is specified, those models are not allowed + if (excludedModels !== undefined) { + if (excludedModels.includes(targetModel as any)) { + return false; + } + } + + return true; + } + + // #endregion +} + +export function createSchemaFactory>( + client: Client, +): Client extends ClientContract + ? ZodSchemaFactory + : never { + return new ZodSchemaFactory(client) as any; +} diff --git a/packages/orm/src/client/zod/index.ts b/packages/orm/src/client/zod/index.ts new file mode 100644 index 000000000..63e947e57 --- /dev/null +++ b/packages/orm/src/client/zod/index.ts @@ -0,0 +1 @@ +export { createQuerySchemaFactory } from './factory'; diff --git a/packages/schema/package.json b/packages/schema/package.json index 0af386bd7..2b8b09bec 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -7,7 +7,9 @@ "build": "tsc --noEmit && tsup-node", "watch": "tsup-node --watch", "lint": "eslint src --ext ts", - "pack": "pnpm pack" + "test": "vitest run", + "pack": "pnpm pack", + "test:generate": "tsx ../../scripts/test-generate.ts ." }, "keywords": [], "author": "ZenStack Team", @@ -32,6 +34,7 @@ }, "devDependencies": { "@zenstackhq/eslint-config": "workspace:*", - "@zenstackhq/typescript-config": "workspace:*" + "@zenstackhq/typescript-config": "workspace:*", + "@zenstackhq/vitest-config": "workspace:*" } } diff --git a/packages/schema/src/accessor.ts b/packages/schema/src/accessor.ts new file mode 100644 index 000000000..e732b54f1 --- /dev/null +++ b/packages/schema/src/accessor.ts @@ -0,0 +1,232 @@ +import { ExpressionUtils } from './expression-utils'; +import type { + DataSourceProviderType, + EnumDef, + FieldDef, + ModelDef, + ProcedureDef, + SchemaDef, + TypeDefDef, +} from './schema'; + +type Accessors = { + /** + * The data source provider type of the schema, e.g. "sqlite", "postgresql", etc. + */ + get providerType(): DataSourceProviderType; + + /** + * Gets a model definition by name. Returns `undefined` if the model is not found. + */ + getModel(name: string): ModelDef | undefined; + + /** + * Gets a model definition by name. Throws an error if the model is not found. + */ + requireModel(name: string): ModelDef; + + /** + * Gets a field definition by model/type and field name. Returns `undefined` if the field is not found. + */ + getField(modelOrType: string, field: string): FieldDef | undefined; + + /** + * Gets a field definition by model/type and field name. Throws an error if the field is not found. + */ + requireField(modelOrType: string, field: string): FieldDef; + + /*** + * Gets an enum definition by name. Returns `undefined` if the enum is not found. + */ + getEnum(name: string): EnumDef | undefined; + + /** + * Gets an enum definition by name. Throws an error if the enum is not found. + */ + requireEnum(name: string): EnumDef; + + /** + * Gets a type definition by name. Returns `undefined` if the type definition is not found. + * @param name + */ + getTypeDef(name: string): TypeDefDef | undefined; + + /** + * Gets a type definition by name. Throws an error if the type definition is not found. + */ + requireTypeDef(name: string): TypeDefDef; + + /** + * Gets a procedure definition by name. Returns `undefined` if the procedure is not found. + */ + getProcedure(name: string): ProcedureDef | undefined; + + /** + * Gets a procedure definition by name. Throws an error if the procedure is not found. + */ + requireProcedure(name: string): ProcedureDef; + + /** + * Gets the unique fields of a model, including both singular and compound unique fields. + */ + getUniqueFields( + model: string, + ): Array<{ name: string; def: FieldDef } | { name: string; defs: Record }>; + + /** + * Gets the delegate discriminator field for a model, if defined via `@@delegate` attribute. Returns `undefined` if not available. + */ + getDelegateDiscriminator(model: string): string | undefined; +}; + +export class InvalidSchemaError extends Error { + constructor(message: string) { + super(message); + } +} + +type AccessorTarget = { schema: SchemaDef }; + +function _requireModel(schema: SchemaDef, name: string): ModelDef { + const model = schema.models[name]; + if (!model) throw new InvalidSchemaError(`Model "${name}" not found in schema`); + return model; +} + +function _getField(schema: SchemaDef, modelOrType: string, field: string): FieldDef | undefined { + const modelDef = schema.models?.[modelOrType]; + if (modelDef) { + return modelDef.fields[field]; + } + const typeDef = schema.typeDefs?.[modelOrType]; + if (typeDef) { + return typeDef.fields[field]; + } + return undefined; +} + +function _requireField(schema: SchemaDef, modelOrType: string, field: string): FieldDef { + const fieldDef = _getField(schema, modelOrType, field); + if (!fieldDef) throw new InvalidSchemaError(`Field "${modelOrType}.${field}" not found in schema`); + return fieldDef; +} + +function _requireModelField(schema: SchemaDef, model: string, field: string) { + const modelDef = _requireModel(schema, model); + const fieldDef = modelDef.fields[field]; + if (!fieldDef) throw new InvalidSchemaError(`Field "${model}.${field}" not found in schema`); + return fieldDef; +} + +const accessors: Accessors = { + get providerType() { + return (this as unknown as AccessorTarget).schema.provider.type; + }, + + getModel(this: { schema: SchemaDef }, name: string) { + return this.schema.models[name]; + }, + + requireModel(this: { schema: SchemaDef }, name: string) { + return _requireModel(this.schema, name); + }, + + getField(this: { schema: SchemaDef }, modelOrType: string, field: string) { + return _getField(this.schema, modelOrType, field); + }, + + requireField(this: { schema: SchemaDef }, modelOrType: string, field: string) { + return _requireField(this.schema, modelOrType, field); + }, + + getEnum(this: { schema: SchemaDef }, name: string) { + return this.schema.enums?.[name]; + }, + + requireEnum(this: { schema: SchemaDef }, name: string) { + const enumDef = this.schema.enums?.[name]; + if (!enumDef) throw new InvalidSchemaError(`Enum "${name}" not found in schema`); + return enumDef; + }, + + getTypeDef(this: { schema: SchemaDef }, name: string) { + return this.schema.typeDefs?.[name]; + }, + + requireTypeDef(this: { schema: SchemaDef }, name: string) { + const typeDef = this.schema.typeDefs?.[name]; + if (!typeDef) throw new InvalidSchemaError(`TypeDef "${name}" not found in schema`); + return typeDef; + }, + + getProcedure(this: { schema: SchemaDef }, name: string) { + return this.schema.procedures?.[name]; + }, + + requireProcedure(this: { schema: SchemaDef }, name: string) { + const procedure = this.schema.procedures?.[name]; + if (!procedure) throw new InvalidSchemaError(`Procedure "${name}" not found in schema`); + return procedure; + }, + + getUniqueFields(this: { schema: SchemaDef }, model: string) { + const modelDef = _requireModel(this.schema, model); + const result: Array<{ name: string; def: FieldDef } | { name: string; defs: Record }> = []; + for (const [key, value] of Object.entries(modelDef.uniqueFields)) { + if (value === null || typeof value !== 'object') { + throw new InvalidSchemaError(`Invalid unique field definition for "${model}.${key}"`); + } + + if (typeof value.type === 'string') { + // singular unique field + result.push({ name: key, def: _requireModelField(this.schema, model, key) }); + } else { + // compound unique field + result.push({ + name: key, + defs: Object.fromEntries( + Object.keys(value).map((k) => [k, _requireModelField(this.schema, model, k)]), + ), + }); + } + } + return result; + }, + + getDelegateDiscriminator(this: { schema: SchemaDef }, model: string) { + const modelDef = _requireModel(this.schema, model); + const delegateAttr = modelDef.attributes?.find((attr) => attr.name === '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const discriminator = delegateAttr.args?.find((arg) => arg.name === 'discriminator'); + if (!discriminator || !ExpressionUtils.isField(discriminator.value)) { + throw new InvalidSchemaError(`Discriminator field not defined for model "${model}"`); + } + return discriminator.value.field; + }, +}; + +export type SchemaAccessor = Schema & Accessors; + +export interface SchemaAccessorConstructor { + new (schema: Schema): SchemaAccessor; +} + +export const SchemaAccessor = function (this: any, schema: Schema) { + return new Proxy( + { schema }, + { + get(target, prop) { + const descriptor = Object.getOwnPropertyDescriptor(accessors, prop); + if (descriptor?.get) { + return descriptor.get.call(target); + } + if (prop in accessors) { + return (accessors as any)[prop].bind(target); + } + return (schema as any)[prop]; + }, + }, + ); +} as unknown as SchemaAccessorConstructor; diff --git a/packages/schema/src/index.ts b/packages/schema/src/index.ts index 6a171e592..bb954c094 100644 --- a/packages/schema/src/index.ts +++ b/packages/schema/src/index.ts @@ -1,3 +1,4 @@ +export { InvalidSchemaError, SchemaAccessor } from './accessor'; export type * from './expression'; export * from './expression-utils'; export type * from './schema'; diff --git a/packages/schema/src/schema.ts b/packages/schema/src/schema.ts index d98b86f01..e21b5e30e 100644 --- a/packages/schema/src/schema.ts +++ b/packages/schema/src/schema.ts @@ -18,18 +18,18 @@ export type SchemaDef = { authType?: GetModels | GetTypeDefs; }; +export type UniqueFieldsInfo = + // singular unique field + | Pick + // compound unique field + | Record>; + export type ModelDef = { name: string; baseModel?: string; fields: Record; attributes?: readonly AttributeApplication[]; - uniqueFields: Record< - string, - // singular unique field - | Pick - // compound unique field - | Record> - >; + uniqueFields: Record; idFields: readonly string[]; computedFields?: Record; isDelegate?: boolean; @@ -139,7 +139,7 @@ export type GetSubModels> = Schema['models'][Model]; -export type GetEnums = keyof Schema['enums']; +export type GetEnums = Extract; export type GetEnum> = Schema['enums'][Enum] extends EnumDef ? Schema['enums'][Enum]['values'] @@ -287,7 +287,7 @@ export type FieldHasDefault< Field extends GetModelFields, > = GetModelField['default'] extends object | number | string | boolean ? true - : GetModelField['updatedAt'] extends (true | UpdatedAtInfo) + : GetModelField['updatedAt'] extends true | UpdatedAtInfo ? true : GetModelField['relation'] extends { hasDefault: true } ? true diff --git a/packages/schema/test/accessor.test.ts b/packages/schema/test/accessor.test.ts new file mode 100644 index 000000000..dd15af2f6 --- /dev/null +++ b/packages/schema/test/accessor.test.ts @@ -0,0 +1,117 @@ +import { describe, expect, it } from 'vitest'; +import { InvalidSchemaError, SchemaAccessor } from '../src/accessor'; +import { schema } from './schema/schema'; + +describe('SchemaAccessor tests', () => { + const accessor = new SchemaAccessor(schema); + + it('proxies schema properties through', () => { + expect(accessor.provider).toEqual({ type: 'sqlite' }); + expect(accessor.models).toBe(schema.models); + expect(accessor.authType).toBe('User'); + }); + + it('returns providerType', () => { + expect(accessor.providerType).toBe('sqlite'); + }); + + it('getModel returns model if found', () => { + expect(accessor.getModel('User')).toBe(schema.models.User); + expect(accessor.getModel('Post')).toBe(schema.models.Post); + }); + + it('getModel returns undefined for unknown model', () => { + expect(accessor.getModel('Unknown')).toBeUndefined(); + }); + + it('requireModel returns model if found', () => { + expect(accessor.requireModel('User')).toBe(schema.models.User); + }); + + it('requireModel throws for unknown model', () => { + expect(() => accessor.requireModel('Unknown')).toThrow(InvalidSchemaError); + expect(() => accessor.requireModel('Unknown')).toThrow('Model "Unknown" not found in schema'); + }); + + it('getEnum returns enum if found', () => { + expect(accessor.getEnum('Role')).toBe(schema.enums.Role); + }); + + it('getEnum returns undefined for unknown enum', () => { + expect(accessor.getEnum('Unknown')).toBeUndefined(); + }); + + it('requireEnum returns enum if found', () => { + const enumDef = accessor.requireEnum('Role'); + expect(enumDef.name).toBe('Role'); + expect(enumDef.values).toEqual({ ADMIN: 'ADMIN', USER: 'USER' }); + }); + + it('requireEnum throws for unknown enum', () => { + expect(() => accessor.requireEnum('Unknown')).toThrow(InvalidSchemaError); + expect(() => accessor.requireEnum('Unknown')).toThrow('Enum "Unknown" not found in schema'); + }); + + it('getTypeDef returns typeDef if found', () => { + expect(accessor.getTypeDef('Address')).toBe(schema.typeDefs.Address); + }); + + it('getTypeDef returns undefined for unknown typeDef', () => { + expect(accessor.getTypeDef('Unknown')).toBeUndefined(); + }); + + it('requireTypeDef returns typeDef if found', () => { + const typeDef = accessor.requireTypeDef('Address'); + expect(typeDef.name).toBe('Address'); + expect(typeDef.fields.street).toMatchObject({ name: 'street', type: 'String' }); + expect(typeDef.fields.city).toMatchObject({ name: 'city', type: 'String' }); + expect(typeDef.fields.zip).toMatchObject({ name: 'zip', type: 'String', optional: true }); + }); + + it('requireTypeDef throws for unknown typeDef', () => { + expect(() => accessor.requireTypeDef('Unknown')).toThrow(InvalidSchemaError); + expect(() => accessor.requireTypeDef('Unknown')).toThrow('TypeDef "Unknown" not found in schema'); + }); + + it('getProcedure returns procedure if found', () => { + expect(accessor.getProcedure('getUserPosts')).toBe(schema.procedures.getUserPosts); + }); + + it('getProcedure returns undefined for unknown procedure', () => { + expect(accessor.getProcedure('unknown')).toBeUndefined(); + }); + + it('requireProcedure returns procedure if found', () => { + const proc = accessor.requireProcedure('getUserPosts'); + expect(proc.returnType).toBe('Post'); + expect(proc.returnArray).toBe(true); + expect(proc.params.userId).toMatchObject({ name: 'userId', type: 'String' }); + }); + + it('requireProcedure throws for unknown procedure', () => { + expect(() => accessor.requireProcedure('unknown')).toThrow(InvalidSchemaError); + expect(() => accessor.requireProcedure('unknown')).toThrow('Procedure "unknown" not found in schema'); + }); + + it('getUniqueFields returns singular unique fields for User', () => { + const fields = accessor.getUniqueFields('User'); + const names = fields.map((f) => f.name); + expect(names).toContain('id'); + expect(names).toContain('email'); + // each entry should be a singular field with a `def` + for (const f of fields) { + expect('def' in f).toBe(true); + } + }); + + it('getUniqueFields returns singular unique field for Post', () => { + const fields = accessor.getUniqueFields('Post'); + expect(fields).toHaveLength(1); + expect(fields[0]!.name).toBe('id'); + expect('def' in fields[0]!).toBe(true); + }); + + it('getUniqueFields throws for unknown model', () => { + expect(() => accessor.getUniqueFields('Unknown')).toThrow(InvalidSchemaError); + }); +}); diff --git a/packages/schema/test/schema/schema.ts b/packages/schema/test/schema/schema.ts new file mode 100644 index 000000000..846028a79 --- /dev/null +++ b/packages/schema/test/schema/schema.ts @@ -0,0 +1,135 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef, ExpressionUtils } from "@zenstackhq/schema"; +export class SchemaType implements SchemaDef { + provider = { + type: "sqlite" + } as const; + models = { + User: { + name: "User", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + email: { + name: "email", + type: "String", + unique: true, + attributes: [{ name: "@unique" }] + }, + name: { + name: "name", + type: "String", + optional: true + }, + role: { + name: "role", + type: "Role" + }, + address: { + name: "address", + type: "Address", + optional: true, + attributes: [{ name: "@json" }] + }, + posts: { + name: "posts", + type: "Post", + array: true, + relation: { opposite: "owner" } + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "String" }, + email: { type: "String" } + } + }, + Post: { + name: "Post", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + title: { + name: "title", + type: "String" + }, + owner: { + name: "owner", + type: "User", + optional: true, + attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array("String", [ExpressionUtils.field("ownerId")]) }, { name: "references", value: ExpressionUtils.array("String", [ExpressionUtils.field("id")]) }] }], + relation: { opposite: "posts", fields: ["ownerId"], references: ["id"] } + }, + ownerId: { + name: "ownerId", + type: "String", + optional: true, + foreignKeyFor: [ + "owner" + ] + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "String" } + } + } + } as const; + typeDefs = { + Address: { + name: "Address", + fields: { + street: { + name: "street", + type: "String" + }, + city: { + name: "city", + type: "String" + }, + zip: { + name: "zip", + type: "String", + optional: true + } + } + } + } as const; + enums = { + Role: { + name: "Role", + values: { + ADMIN: "ADMIN", + USER: "USER" + } + } + } as const; + authType = "User" as const; + procedures = { + getUserPosts: { + params: { + userId: { name: "userId", type: "String" } + }, + returnType: "Post", + returnArray: true + } + } as const; + plugins = {}; +} +export const schema = new SchemaType(); diff --git a/packages/schema/test/schema/schema.zmodel b/packages/schema/test/schema/schema.zmodel new file mode 100644 index 000000000..b611fa663 --- /dev/null +++ b/packages/schema/test/schema/schema.zmodel @@ -0,0 +1,32 @@ +datasource db { + provider = 'sqlite' +} + +type Address { + street String + city String + zip String? +} + +enum Role { + ADMIN + USER +} + +model User { + id String @id @default(cuid()) + email String @unique + name String? + role Role + address Address? @json + posts Post[] +} + +model Post { + id String @id @default(cuid()) + title String + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? +} + +procedure getUserPosts(userId: String): Post[] diff --git a/packages/schema/vitest.config.ts b/packages/schema/vitest.config.ts new file mode 100644 index 000000000..75a9f709c --- /dev/null +++ b/packages/schema/vitest.config.ts @@ -0,0 +1,4 @@ +import base from '@zenstackhq/vitest-config/base'; +import { defineConfig, mergeConfig } from 'vitest/config'; + +export default mergeConfig(base, defineConfig({})); diff --git a/packages/zod/package.json b/packages/zod/package.json index a146e20e9..a7488a784 100644 --- a/packages/zod/package.json +++ b/packages/zod/package.json @@ -5,9 +5,16 @@ "type": "module", "scripts": { "build": "tsc --noEmit && tsup-node", - "lint": "eslint src --ext ts" + "watch": "tsup-node --watch", + "lint": "eslint src --ext ts", + "test": "vitest run", + "pack": "pnpm pack", + "test:generate": "tsx ../../scripts/test-generate.ts ." }, - "keywords": [], + "keywords": [ + "zenstack", + "zod" + ], "files": [ "dist" ], @@ -26,12 +33,16 @@ } }, "dependencies": { - "@zenstackhq/orm": "workspace:*", + "@zenstackhq/common-helpers": "workspace:*", + "@zenstackhq/schema": "workspace:*", + "decimal.js": "catalog:", + "json-stable-stringify": "^1.3.0", "ts-pattern": "catalog:" }, "devDependencies": { "@zenstackhq/eslint-config": "workspace:*", "@zenstackhq/typescript-config": "workspace:*", + "@zenstackhq/vitest-config": "workspace:*", "zod": "^4.1.0" }, "peerDependencies": { diff --git a/packages/zod/src/error.ts b/packages/zod/src/error.ts new file mode 100644 index 000000000..f55dbbba1 --- /dev/null +++ b/packages/zod/src/error.ts @@ -0,0 +1,8 @@ +/** + * Error representing failures in Zod schema building. + */ +export class ZodSchemaError extends Error { + constructor(message: string) { + super(message); + } +} diff --git a/packages/zod/src/factory.ts b/packages/zod/src/factory.ts new file mode 100644 index 000000000..02e1ab031 --- /dev/null +++ b/packages/zod/src/factory.ts @@ -0,0 +1,243 @@ +import { + SchemaAccessor, + type BuiltinType, + type FieldDef, + type FieldIsArray, + type FieldIsRelation, + type GetEnum, + type GetEnums, + type GetModelFields, + type GetModelFieldType, + type GetModels, + type GetTypeDefFields, + type GetTypeDefFieldType, + type GetTypeDefs, + type ModelFieldIsOptional, + type SchemaDef, + type TypeDefFieldIsOptional, +} from '@zenstackhq/schema'; +import Decimal from 'decimal.js'; +import { match } from 'ts-pattern'; +import z from 'zod'; +import { + addBigIntValidation, + addCustomValidation, + addDecimalValidation, + addNumberValidation, + addStringValidation, +} from './utils'; + +export function createSchemaFactory(schema: Schema) { + return new SchemaFactory(schema); +} + +class SchemaFactory { + private readonly schema: SchemaAccessor; + + constructor(_schema: Schema) { + this.schema = new SchemaAccessor(_schema); + } + + makeModelSchema>( + model: Model, + ): z.ZodObject, z.core.$strict> { + const modelDef = this.schema.models[model]; + if (!modelDef) { + throw new Error(`Model "${model}" not found in schema`); + } + const fields: Record = {}; + + for (const [fieldName, fieldDef] of Object.entries(modelDef.fields)) { + if (fieldDef.relation) { + const relatedModelName = fieldDef.type; + const lazySchema: z.ZodType = z.lazy(() => this.makeModelSchema(relatedModelName as GetModels)); + // relation fields are always optional + fields[fieldName] = this.applyCardinality(lazySchema, fieldDef).optional(); + } else { + fields[fieldName] = this.makeScalarFieldSchema(fieldDef); + } + } + + const shape = z.strictObject(fields); + return addCustomValidation(shape, modelDef.attributes) as unknown as z.ZodObject< + GetModelFieldsShape, + z.core.$strict + >; + } + + private makeScalarFieldSchema(fieldDef: FieldDef): z.ZodType { + const { type, attributes } = fieldDef; + + // enum + const enumDef = this.schema.getEnum(type); + if (enumDef) { + return this.applyCardinality(this.makeEnumSchema(type as GetEnums), fieldDef); + } + + // typedef + const typedefDef = this.schema.getTypeDef(type); + if (typedefDef) { + return this.applyCardinality(this.makeTypeSchema(type as GetTypeDefs), fieldDef); + } + + const base = match(type as BuiltinType) + .with('String', () => addStringValidation(z.string(), attributes)) + .with('Int', () => addNumberValidation(z.number().int(), attributes)) + .with('Float', () => addNumberValidation(z.number(), attributes)) + .with('Boolean', () => z.boolean()) + .with('BigInt', () => addBigIntValidation(z.bigint(), attributes)) + .with('Decimal', () => + z.union([ + addNumberValidation(z.number(), attributes) as z.ZodNumber, + addDecimalValidation(z.string(), attributes, true) as z.ZodString, + addDecimalValidation(z.instanceof(Decimal), attributes, true), + ]), + ) + .with('DateTime', () => z.union([z.date(), z.iso.datetime()])) + .with('Bytes', () => z.instanceof(Uint8Array)) + .with('Json', () => this.makeJsonSchema()) + .with('Unsupported', () => z.unknown()) + .exhaustive(); + + return this.applyCardinality(base, fieldDef); + } + + private makeJsonSchema(): z.ZodType { + return z.union([ + z.string(), + z.number(), + z.boolean(), + z.null(), + z.array(z.lazy(() => this.makeJsonSchema())), + z.object({}).catchall(z.lazy(() => this.makeJsonSchema())), + ]); + } + + private applyCardinality(schema: z.ZodType, fieldDef: FieldDef): z.ZodType { + let result = schema; + if (fieldDef.array) { + result = result.array(); + } + if (fieldDef.optional) { + result = result.nullable().optional(); + } + return result; + } + + makeTypeSchema>( + type: Type, + ): z.ZodObject, z.core.$strict> { + const typeDef = this.schema.requireTypeDef(type); + const fields: Record = {}; + + for (const [fieldName, fieldDef] of Object.entries(typeDef.fields)) { + fields[fieldName] = this.makeScalarFieldSchema(fieldDef); + } + + const shape = z.strictObject(fields); + return addCustomValidation(shape, typeDef.attributes) as unknown as z.ZodObject< + GetTypeDefFieldsShape, + z.core.$strict + >; + } + + makeEnumSchema>( + _enum: Enum, + ): z.ZodEnum<{ [Key in keyof GetEnum]: GetEnum[Key] }> { + const enumDef = this.schema.requireEnum(_enum); + return z.enum(Object.keys(enumDef.values) as [string, ...string[]]) as unknown as z.ZodEnum<{ + [Key in keyof GetEnum]: GetEnum[Key]; + }>; + } +} + +type GetModelFieldsShape> = { + // scalar fields + [Field in GetModelFields as FieldIsRelation extends true + ? never + : Field]: ZodOptionalAndNullableIf< + MapModelFieldToZod, + ModelFieldIsOptional + >; +} & { + // relation fields, always optional + [Field in GetModelFields as FieldIsRelation extends true + ? Field + : never]: ZodNullableIf< + z.ZodOptional< + ZodArrayIf< + z.ZodObject< + GetModelFieldsShape< + Schema, + GetModelFieldType extends GetModels + ? GetModelFieldType + : never + >, + z.core.$strict + >, + FieldIsArray + > + >, + ModelFieldIsOptional + >; +}; + +type GetTypeDefFieldsShape> = { + [Field in GetTypeDefFields]: ZodOptionalAndNullableIf< + MapTypeDefFieldToZod, + TypeDefFieldIsOptional + >; +}; + +type FieldTypeZodMap = { + String: z.ZodString; + Int: z.ZodNumber; + BigInt: z.ZodBigInt; + Float: z.ZodNumber; + Decimal: z.ZodType; + Boolean: z.ZodBoolean; + DateTime: z.ZodType; + Bytes: z.ZodType; + Json: JsonZodType; +}; + +type MapModelFieldToZod< + Schema extends SchemaDef, + Model extends GetModels, + Field extends GetModelFields, + FieldType = GetModelFieldType, +> = MapFieldTypeToZod; + +type MapTypeDefFieldToZod< + Schema extends SchemaDef, + Type extends GetTypeDefs, + Field extends GetTypeDefFields, + FieldType = GetTypeDefFieldType, +> = MapFieldTypeToZod; + +type MapFieldTypeToZod = FieldType extends keyof FieldTypeZodMap + ? FieldTypeZodMap[FieldType] + : FieldType extends GetEnums + ? EnumZodType + : FieldType extends GetTypeDefs + ? z.ZodObject, z.core.$strict> + : z.ZodUnknown; + +type JsonZodType = + | z.ZodObject, z.core.$loose> + | z.ZodArray + | z.ZodString + | z.ZodNumber + | z.ZodBoolean + | z.ZodNull; + +type EnumZodType> = z.ZodEnum<{ + [Key in keyof GetEnum]: GetEnum[Key]; +}>; + +type ZodOptionalAndNullableIf = Condition extends true + ? z.ZodOptional> + : T; + +type ZodNullableIf = Condition extends true ? z.ZodNullable : T; +type ZodArrayIf = Condition extends true ? z.ZodArray : T; diff --git a/packages/zod/src/index.ts b/packages/zod/src/index.ts index 8211af368..905551618 100644 --- a/packages/zod/src/index.ts +++ b/packages/zod/src/index.ts @@ -1,33 +1,2 @@ -import type { FieldDef, GetModels, SchemaDef } from '@zenstackhq/orm/schema'; -import { match, P } from 'ts-pattern'; -import { z, ZodType } from 'zod'; -import type { SelectSchema } from './types'; - -export function makeSelectSchema>( - schema: Schema, - model: Model, -) { - return z.strictObject(mapFields(schema, model)) as SelectSchema; -} - -function mapFields(schema: Schema, model: GetModels): any { - const modelDef = schema.models[model]; - if (!modelDef) { - throw new Error(`Model ${model} not found in schema`); - } - const scalarFields = Object.entries(modelDef.fields).filter(([_, fieldDef]) => !fieldDef.relation); - const result: Record = {}; - for (const [field, fieldDef] of scalarFields) { - result[field] = makeScalarSchema(fieldDef); - } - return result; -} - -function makeScalarSchema(fieldDef: FieldDef): ZodType { - return match(fieldDef.type) - .with('String', () => z.string()) - .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => z.number()) - .with('Boolean', () => z.boolean()) - .with('DateTime', () => z.string().datetime()) - .otherwise(() => z.unknown()); -} +export { createSchemaFactory as createModelSchemaFactory } from './factory'; +export * as ZodUtils from './utils'; diff --git a/packages/zod/src/types.ts b/packages/zod/src/types.ts deleted file mode 100644 index 249c6de9e..000000000 --- a/packages/zod/src/types.ts +++ /dev/null @@ -1,27 +0,0 @@ -import type { FieldType, GetModels, ScalarFields, SchemaDef } from '@zenstackhq/orm/schema'; -import type { ZodBoolean, ZodNumber, ZodObject, ZodString, ZodUnknown } from 'zod'; - -export type SelectSchema> = ZodObject<{ - [Key in ScalarFields]: MapScalarType; -}>; - -type MapScalarType< - Schema extends SchemaDef, - Model extends GetModels, - Field extends ScalarFields, - Type = FieldType, -> = Type extends 'String' - ? ZodString - : Type extends 'Int' - ? ZodNumber - : Type extends 'BigInt' - ? ZodNumber - : Type extends 'Float' - ? ZodNumber - : Type extends 'Decimal' - ? ZodNumber - : Type extends 'DateTime' - ? ZodString - : Type extends 'Boolean' - ? ZodBoolean - : ZodUnknown; diff --git a/packages/orm/src/client/crud/validator/utils.ts b/packages/zod/src/utils.ts similarity index 96% rename from packages/orm/src/client/crud/validator/utils.ts rename to packages/zod/src/utils.ts index c9909a702..b670c7316 100644 --- a/packages/orm/src/client/crud/validator/utils.ts +++ b/packages/zod/src/utils.ts @@ -1,19 +1,18 @@ import { invariant } from '@zenstackhq/common-helpers'; -import type { - AttributeApplication, - BinaryExpression, - CallExpression, - Expression, - FieldExpression, - MemberExpression, - UnaryExpression, +import { + ExpressionUtils, + type AttributeApplication, + type BinaryExpression, + type CallExpression, + type Expression, + type FieldExpression, + type MemberExpression, + type UnaryExpression, } from '@zenstackhq/schema'; import Decimal from 'decimal.js'; import { match, P } from 'ts-pattern'; import { z } from 'zod'; -import { ZodIssueCode } from 'zod/v3'; -import { ExpressionUtils } from '../../../schema'; -import { createNotSupportedError } from '../../errors'; +import { ZodSchemaError } from './error'; function getArgValue(expr: Expression | undefined): T | undefined { if (!expr || !ExpressionUtils.isLiteral(expr)) { @@ -167,7 +166,7 @@ export function addDecimalValidation( new Decimal(v); } catch (err) { ctx.addIssue({ - code: z.ZodIssueCode.custom, + code: 'custom', message: `Invalid decimal: ${err}`, }); } @@ -184,7 +183,7 @@ export function addDecimalValidation( error?.issues.forEach((issue) => { if (op === 'gt' || op === 'gte') { ctx.addIssue({ - code: ZodIssueCode.too_small, + code: 'too_small', origin: 'number', minimum: value, type: 'decimal', @@ -193,7 +192,7 @@ export function addDecimalValidation( }); } else { ctx.addIssue({ - code: ZodIssueCode.too_big, + code: 'too_big', origin: 'number', maximum: value, type: 'decimal', @@ -467,7 +466,7 @@ function evalCall(data: any, expr: CallExpression) { return fieldArg.length === 0; }) .otherwise(() => { - throw createNotSupportedError(`Unsupported function "${expr.function}"`); + throw new ZodSchemaError(`Unsupported function "${expr.function}"`); }) ); } diff --git a/packages/zod/test/factory.test.ts b/packages/zod/test/factory.test.ts new file mode 100644 index 000000000..7283e1d91 --- /dev/null +++ b/packages/zod/test/factory.test.ts @@ -0,0 +1,581 @@ +import Decimal from 'decimal.js'; +import { describe, expect, expectTypeOf, it } from 'vitest'; +import { createModelSchemaFactory } from '../src/index'; +import { schema } from './schema/schema'; +import z from 'zod'; + +const factory = createModelSchemaFactory(schema); + +// A fully valid User object (without relations) +const validUser = { + id: 'user123', + email: 'test@example.com', + username: 'johndoe', + website: null, + code: 'USR001', + age: 25, + score: 50.0, + bigNum: BigInt(100), + balance: 10.0, + active: true, + birthdate: null, + avatar: null, + metadata: null, + status: 'ACTIVE', + address: null, +}; + +// A fully valid Post object (without relations) +const validPost = { + id: 'post123', + title: 'My First Post', + published: true, + authorId: null, +}; + +describe('SchemaFactory - makeModelSchema', () => { + describe('scalar field types', () => { + it('infers correct field types for User', () => { + const _userSchema = factory.makeModelSchema('User'); + type User = z.infer; + + // required string fields + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + // optional string field (nullable + optional) + expectTypeOf().toEqualTypeOf(); + + // number fields (Int and Float both map to ZodNumber) + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + + // bigint + expectTypeOf().toEqualTypeOf(); + + // Decimal maps to ZodCustom + expectTypeOf().toEqualTypeOf(); + + // boolean + expectTypeOf().toEqualTypeOf(); + + // DateTime + expectTypeOf().toEqualTypeOf(); + + // optional Bytes + expectTypeOf().toEqualTypeOf(); + + // optional Json + expectTypeOf().toHaveProperty('metadata'); + expectTypeOf().toEqualTypeOf< + string | number | boolean | null | Record | unknown[] | undefined + >(); + + // required enum + expectTypeOf().toEqualTypeOf<'ACTIVE' | 'INACTIVE' | 'PENDING'>(); + + // optional typedef (Address): { street, city, zip? } | null | undefined + type Address = Exclude; + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf
(); + + // relation field present + expectTypeOf().toHaveProperty('posts'); + const _postSchema = factory.makeModelSchema('Post'); + type Post = z.infer; + expectTypeOf().toEqualTypeOf(); + }); + + it('infers correct field types for Post', () => { + const _postSchema = factory.makeModelSchema('Post'); + type Post = z.infer; + + // required string fields + expectTypeOf().toEqualTypeOf(); + expectTypeOf().toEqualTypeOf(); + + // required boolean + expectTypeOf().toEqualTypeOf(); + + // optional scalar (foreign key) + expectTypeOf().toEqualTypeOf(); + + // optional relation field present in type + expectTypeOf().toHaveProperty('author'); + const _userSchema = factory.makeModelSchema('User'); + type User = z.infer; + expectTypeOf().toEqualTypeOf(); + }); + + it('accepts a fully valid User', () => { + const userSchema = factory.makeModelSchema('User'); + expect(userSchema.safeParse(validUser).success).toBe(true); + }); + + it('accepts a fully valid Post', () => { + const postSchema = factory.makeModelSchema('Post'); + expect(postSchema.safeParse(validPost).success).toBe(true); + }); + + it('rejects extra fields (strict object)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, unknownField: 'value' }); + expect(result.success).toBe(false); + }); + + it('accepts DateTime as a Date object', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, birthdate: new Date() }); + expect(result.success).toBe(true); + }); + + it('accepts DateTime as an ISO datetime string', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + birthdate: '2024-01-15T10:30:00.000Z', + }); + expect(result.success).toBe(true); + }); + + it('accepts Bytes as Uint8Array', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + avatar: new Uint8Array([1, 2, 3]), + }); + expect(result.success).toBe(true); + }); + + it('accepts BigInt values', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, bigNum: BigInt(999) }); + expect(result.success).toBe(true); + }); + + it('accepts Decimal as a number', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: 42.5 }); + expect(result.success).toBe(true); + }); + + it('accepts Decimal as a numeric string', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: '42.5' }); + expect(result.success).toBe(true); + }); + + it('accepts Decimal as a Decimal instance', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: new Decimal('42.5') }); + expect(result.success).toBe(true); + }); + + it('accepts Json values', () => { + const userSchema = factory.makeModelSchema('User'); + expect(userSchema.safeParse({ ...validUser, metadata: { key: 'value' } }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, metadata: [1, 2, 3] }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, metadata: 42 }).success).toBe(true); + }); + + it('rejects invalid Json values', () => { + const userSchema = factory.makeModelSchema('User'); + // BigInt is not a JSON primitive + expect(userSchema.safeParse({ ...validUser, metadata: BigInt(1) }).success).toBe(false); + // Symbol is not a JSON value + expect(userSchema.safeParse({ ...validUser, metadata: Symbol('s') }).success).toBe(false); + // Functions are not JSON values + expect(userSchema.safeParse({ ...validUser, metadata: () => {} }).success).toBe(false); + // Nested non-JSON values are also rejected + expect(userSchema.safeParse({ ...validUser, metadata: { key: BigInt(1) } }).success).toBe(false); + expect(userSchema.safeParse({ ...validUser, metadata: [BigInt(1)] }).success).toBe(false); + }); + }); + + describe('string validation attributes', () => { + it('rejects invalid email for @email field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, email: 'not-an-email' }); + expect(result.success).toBe(false); + }); + + it('accepts valid email for @email field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, email: 'valid@domain.com' }); + expect(result.success).toBe(true); + }); + + it('rejects username too short for @length(3, 50)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, username: 'ab' }); + expect(result.success).toBe(false); + }); + + it('rejects username too long for @length(3, 50)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, username: 'a'.repeat(51) }); + expect(result.success).toBe(false); + }); + + it('accepts username within @length bounds', () => { + const userSchema = factory.makeModelSchema('User'); + expect(userSchema.safeParse({ ...validUser, username: 'abc' }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, username: 'a'.repeat(50) }).success).toBe(true); + }); + + it('rejects invalid URL for @url field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, website: 'not-a-url' }); + expect(result.success).toBe(false); + }); + + it('accepts valid URL for @url field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, website: 'https://example.com' }); + expect(result.success).toBe(true); + }); + + it('accepts null for optional @url field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, website: null }); + expect(result.success).toBe(true); + }); + + it('rejects code that does not start with "USR" for @startsWith', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, code: 'ABC001' }); + expect(result.success).toBe(false); + }); + + it('accepts code starting with "USR" for @startsWith', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, code: 'USR_ANYTHING' }); + expect(result.success).toBe(true); + }); + }); + + describe('number validation attributes', () => { + it('rejects age = 0 for @gt(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, age: 0 }); + expect(result.success).toBe(false); + }); + + it('rejects age = 151 for @lte(150)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, age: 151 }); + expect(result.success).toBe(false); + }); + + it('accepts age within @gt(0) and @lte(150) bounds', () => { + const userSchema = factory.makeModelSchema('User'); + // Note: @@validate(age >= 18) also applies, so the minimum valid age is 18 + expect(userSchema.safeParse({ ...validUser, age: 18 }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, age: 150 }).success).toBe(true); + }); + + it('rejects score < 0 for @gte(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, score: -0.1 }); + expect(result.success).toBe(false); + }); + + it('rejects score = 100 for @lt(100)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, score: 100.0 }); + expect(result.success).toBe(false); + }); + + it('accepts score within @gte(0) and @lt(100) bounds', () => { + const userSchema = factory.makeModelSchema('User'); + expect(userSchema.safeParse({ ...validUser, score: 0 }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, score: 99.9 }).success).toBe(true); + }); + }); + + describe('bigint validation attributes', () => { + it('rejects bigNum < 0 for @gte(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, bigNum: BigInt(-1) }); + expect(result.success).toBe(false); + }); + + it('accepts bigNum = 0 for @gte(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, bigNum: BigInt(0) }); + expect(result.success).toBe(true); + }); + }); + + describe('decimal validation attributes', () => { + it('rejects balance = 0 (number) for @gt(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: 0 }); + expect(result.success).toBe(false); + }); + + it('rejects balance = "0.0" (string) for @gt(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: '0.0' }); + expect(result.success).toBe(false); + }); + + it('rejects balance = Decimal("0") for @gt(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: new Decimal('0') }); + expect(result.success).toBe(false); + }); + + it('accepts balance = 0.01 (number) for @gt(0)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, balance: 0.01 }); + expect(result.success).toBe(true); + }); + }); + + describe('enum fields', () => { + it('accepts valid enum values', () => { + const userSchema = factory.makeModelSchema('User'); + expect(userSchema.safeParse({ ...validUser, status: 'ACTIVE' }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, status: 'INACTIVE' }).success).toBe(true); + expect(userSchema.safeParse({ ...validUser, status: 'PENDING' }).success).toBe(true); + }); + + it('rejects invalid enum value', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, status: 'ADMIN' }); + expect(result.success).toBe(false); + }); + }); + + describe('typedef (embedded type) fields', () => { + it('accepts null for optional typedef field', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, address: null }); + expect(result.success).toBe(true); + }); + + it('accepts valid Address object', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + address: { street: '123 Main St', city: 'Springfield', zip: null }, + }); + expect(result.success).toBe(true); + }); + + it('accepts Address with optional zip present', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + address: { street: '123 Main St', city: 'Springfield', zip: '12345' }, + }); + expect(result.success).toBe(true); + }); + + it('rejects Address with extra fields (strict object)', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + address: { street: '123 Main St', city: 'Springfield', zip: null, extra: 'field' }, + }); + expect(result.success).toBe(false); + }); + + it('rejects Address missing required fields', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ + ...validUser, + address: { street: '123 Main St' }, + }); + expect(result.success).toBe(false); + }); + }); + + describe('@@validate custom validation', () => { + it('fails when @@validate condition is false (age < 18 passes field but fails model validation)', () => { + const userSchema = factory.makeModelSchema('User'); + // age: 16 passes @gt(0) and @lte(150) but fails @@validate(age >= 18) + const result = userSchema.safeParse({ ...validUser, age: 16 }); + expect(result.success).toBe(false); + }); + + it('@@validate error contains the configured message', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, age: 16 }); + expect(result.success).toBe(false); + if (!result.success) { + const messages = result.error.issues.map((i) => i.message); + expect(messages).toContain('Must be adult'); + } + }); + + it('@@validate error uses the configured path', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, age: 16 }); + expect(result.success).toBe(false); + if (!result.success) { + const paths = result.error.issues.map((i) => i.path); + expect(paths).toContainEqual(['age']); + } + }); + + it('passes when @@validate condition is satisfied', () => { + const userSchema = factory.makeModelSchema('User'); + const result = userSchema.safeParse({ ...validUser, age: 18 }); + expect(result.success).toBe(true); + }); + }); + + describe('error handling', () => { + it('throws when model is not found', () => { + expect(() => factory.makeModelSchema('Unknown' as any)).toThrow('Model "Unknown" not found in schema'); + }); + }); +}); + +describe('SchemaFactory - makeTypeSchema', () => { + it('generates schema for Address typedef', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: null }).success).toBe(true); + }); + + it('rejects Address with missing required field', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main' }); + expect(result.success).toBe(false); + }); + + it('rejects Address with extra fields (strict object)', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ + street: '123 Main', + city: 'Springfield', + zip: null, + extra: 'field', + }); + expect(result.success).toBe(false); + }); + + it('accepts Address with optional zip as null', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: null }).success).toBe(true); + }); + + it('accepts Address with optional zip as a string', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '12345' }).success).toBe(true); + }); + + describe('extra validations', () => { + it('passes when zip is null', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: null }).success).toBe(true); + }); + + it('passes when zip is omitted', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield' }).success).toBe(true); + }); + + it('passes when zip is exactly 5 characters', () => { + const addressSchema = factory.makeTypeSchema('Address'); + expect(addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '90210' }).success).toBe( + true, + ); + }); + + it('fails when zip is fewer than 5 characters', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '123' }); + expect(result.success).toBe(false); + }); + + it('fails when zip is more than 5 characters', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '123456' }); + expect(result.success).toBe(false); + }); + + it('error message matches the configured message', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '123' }); + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error.issues.map((i) => i.message)).toContain('Zip code must be exactly 5 characters'); + } + }); + + it('error path points to the zip field', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main', city: 'Springfield', zip: '123' }); + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error.issues.map((i) => i.path)).toContainEqual(['zip']); + } + }); + + it('fails when city is too short', () => { + const addressSchema = factory.makeTypeSchema('Address'); + const result = addressSchema.safeParse({ street: '123 Main', city: '', zip: '12345' }); + expect(result.success).toBe(false); + }); + + it('also validates when Address is embedded in User', () => { + const userSchema = factory.makeModelSchema('User'); + const validUser = { + id: 'u1', + email: 'a@b.com', + username: 'alice', + website: null, + code: 'USR01', + age: 20, + score: 50, + bigNum: BigInt(0), + balance: 1, + active: true, + birthdate: null, + avatar: null, + metadata: null, + status: 'ACTIVE', + address: { street: '123 Main', city: 'Springfield', zip: '90210' }, + }; + expect(userSchema.safeParse(validUser).success).toBe(true); + expect( + userSchema.safeParse({ ...validUser, address: { street: '123 Main', city: 'Springfield', zip: '123' } }) + .success, + ).toBe(false); + }); + }); +}); + +describe('SchemaFactory - makeEnumSchema', () => { + it('accepts all valid enum values', () => { + const statusSchema = factory.makeEnumSchema('Status'); + expect(statusSchema.safeParse('ACTIVE').success).toBe(true); + expect(statusSchema.safeParse('INACTIVE').success).toBe(true); + expect(statusSchema.safeParse('PENDING').success).toBe(true); + }); + + it('rejects values not in the enum', () => { + const statusSchema = factory.makeEnumSchema('Status'); + expect(statusSchema.safeParse('ADMIN').success).toBe(false); + expect(statusSchema.safeParse('active').success).toBe(false); + expect(statusSchema.safeParse('').success).toBe(false); + expect(statusSchema.safeParse(null).success).toBe(false); + expect(statusSchema.safeParse(42).success).toBe(false); + }); + + it('infers enum value union type', () => { + const _statusSchema = factory.makeEnumSchema('Status'); + type Status = z.infer; + expectTypeOf().toEqualTypeOf<'ACTIVE' | 'INACTIVE' | 'PENDING'>(); + }); + + it('throws when enum is not found', () => { + expect(() => factory.makeEnumSchema('Unknown' as any)).toThrow(); + }); +}); diff --git a/packages/zod/test/schema/schema.ts b/packages/zod/test/schema/schema.ts new file mode 100644 index 000000000..ae8c3be32 --- /dev/null +++ b/packages/zod/test/schema/schema.ts @@ -0,0 +1,186 @@ +////////////////////////////////////////////////////////////////////////////////////////////// +// DO NOT MODIFY THIS FILE // +// This file is automatically generated by ZenStack CLI and should not be manually updated. // +////////////////////////////////////////////////////////////////////////////////////////////// + +/* eslint-disable */ + +import { type SchemaDef, ExpressionUtils } from "@zenstackhq/schema"; +export class SchemaType implements SchemaDef { + provider = { + type: "sqlite" + } as const; + models = { + User: { + name: "User", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + email: { + name: "email", + type: "String", + attributes: [{ name: "@email" }] + }, + username: { + name: "username", + type: "String", + attributes: [{ name: "@length", args: [{ name: "min", value: ExpressionUtils.literal(3) }, { name: "max", value: ExpressionUtils.literal(50) }] }] + }, + website: { + name: "website", + type: "String", + optional: true, + attributes: [{ name: "@url" }] + }, + code: { + name: "code", + type: "String", + attributes: [{ name: "@startsWith", args: [{ name: "text", value: ExpressionUtils.literal("USR") }] }] + }, + age: { + name: "age", + type: "Int", + attributes: [{ name: "@gt", args: [{ name: "value", value: ExpressionUtils.literal(0) }] }, { name: "@lte", args: [{ name: "value", value: ExpressionUtils.literal(150) }] }] + }, + score: { + name: "score", + type: "Float", + attributes: [{ name: "@gte", args: [{ name: "value", value: ExpressionUtils.literal(0.0) }] }, { name: "@lt", args: [{ name: "value", value: ExpressionUtils.literal(100.0) }] }] + }, + bigNum: { + name: "bigNum", + type: "BigInt", + attributes: [{ name: "@gte", args: [{ name: "value", value: ExpressionUtils.literal(0) }] }] + }, + balance: { + name: "balance", + type: "Decimal", + attributes: [{ name: "@gt", args: [{ name: "value", value: ExpressionUtils.literal(0) }] }] + }, + active: { + name: "active", + type: "Boolean" + }, + birthdate: { + name: "birthdate", + type: "DateTime", + optional: true + }, + avatar: { + name: "avatar", + type: "Bytes", + optional: true + }, + metadata: { + name: "metadata", + type: "Json", + optional: true + }, + status: { + name: "status", + type: "Status" + }, + address: { + name: "address", + type: "Address", + optional: true, + attributes: [{ name: "@json" }] + }, + posts: { + name: "posts", + type: "Post", + array: true, + relation: { opposite: "author" } + } + }, + attributes: [ + { name: "@@validate", args: [{ name: "value", value: ExpressionUtils.binary(ExpressionUtils.field("age"), ">=", ExpressionUtils.literal(18)) }, { name: "message", value: ExpressionUtils.literal("Must be adult") }, { name: "path", value: ExpressionUtils.array("String", [ExpressionUtils.literal("age")]) }] } + ], + idFields: ["id"], + uniqueFields: { + id: { type: "String" } + } + }, + Post: { + name: "Post", + fields: { + id: { + name: "id", + type: "String", + id: true, + attributes: [{ name: "@id" }, { name: "@default", args: [{ name: "value", value: ExpressionUtils.call("cuid") }] }], + default: ExpressionUtils.call("cuid") + }, + title: { + name: "title", + type: "String" + }, + published: { + name: "published", + type: "Boolean" + }, + author: { + name: "author", + type: "User", + optional: true, + attributes: [{ name: "@relation", args: [{ name: "fields", value: ExpressionUtils.array("String", [ExpressionUtils.field("authorId")]) }, { name: "references", value: ExpressionUtils.array("String", [ExpressionUtils.field("id")]) }] }], + relation: { opposite: "posts", fields: ["authorId"], references: ["id"] } + }, + authorId: { + name: "authorId", + type: "String", + optional: true, + foreignKeyFor: [ + "author" + ] + } + }, + idFields: ["id"], + uniqueFields: { + id: { type: "String" } + } + } + } as const; + typeDefs = { + Address: { + name: "Address", + fields: { + street: { + name: "street", + type: "String" + }, + city: { + name: "city", + type: "String", + attributes: [{ name: "@length", args: [{ name: "min", value: ExpressionUtils.literal(2) }] }] + }, + zip: { + name: "zip", + type: "String", + optional: true + } + }, + attributes: [ + { name: "@@validate", args: [{ name: "value", value: ExpressionUtils.binary(ExpressionUtils.binary(ExpressionUtils.field("zip"), "==", ExpressionUtils._null()), "||", ExpressionUtils.binary(ExpressionUtils.call("length", [ExpressionUtils.field("zip")]), "==", ExpressionUtils.literal(5))) }, { name: "message", value: ExpressionUtils.literal("Zip code must be exactly 5 characters") }, { name: "path", value: ExpressionUtils.array("String", [ExpressionUtils.literal("zip")]) }] } + ] + } + } as const; + enums = { + Status: { + name: "Status", + values: { + ACTIVE: "ACTIVE", + INACTIVE: "INACTIVE", + PENDING: "PENDING" + } + } + } as const; + authType = "User" as const; + plugins = {}; +} +export const schema = new SchemaType(); diff --git a/packages/zod/test/schema/schema.zmodel b/packages/zod/test/schema/schema.zmodel new file mode 100644 index 000000000..fd651c927 --- /dev/null +++ b/packages/zod/test/schema/schema.zmodel @@ -0,0 +1,46 @@ +datasource db { + provider = 'sqlite' +} + +enum Status { + ACTIVE + INACTIVE + PENDING +} + +type Address { + street String + city String @length(2) + zip String? + + @@validate(zip == null || length(zip) == 5, "Zip code must be exactly 5 characters", ["zip"]) +} + +model User { + id String @id @default(cuid()) + email String @email + username String @length(3, 50) + website String? @url + code String @startsWith("USR") + age Int @gt(0) @lte(150) + score Float @gte(0.0) @lt(100.0) + bigNum BigInt @gte(0) + balance Decimal @gt(0) + active Boolean + birthdate DateTime? + avatar Bytes? + metadata Json? + status Status + address Address? @json + posts Post[] + + @@validate(age >= 18, "Must be adult", ["age"]) +} + +model Post { + id String @id @default(cuid()) + title String + published Boolean + author User? @relation(fields: [authorId], references: [id]) + authorId String? +} diff --git a/packages/zod/vitest.config.ts b/packages/zod/vitest.config.ts new file mode 100644 index 000000000..75a9f709c --- /dev/null +++ b/packages/zod/vitest.config.ts @@ -0,0 +1,4 @@ +import base from '@zenstackhq/vitest-config/base'; +import { defineConfig, mergeConfig } from 'vitest/config'; + +export default mergeConfig(base, defineConfig({})); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 5828dbe57..b484975aa 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -525,6 +525,9 @@ importers: '@zenstackhq/schema': specifier: workspace:* version: link:../schema + '@zenstackhq/zod': + specifier: workspace:* + version: link:../zod better-sqlite3: specifier: 'catalog:' version: 12.5.0 @@ -642,6 +645,9 @@ importers: '@zenstackhq/typescript-config': specifier: workspace:* version: link:../config/typescript-config + '@zenstackhq/vitest-config': + specifier: workspace:* + version: link:../config/vitest-config packages/sdk: dependencies: @@ -828,9 +834,18 @@ importers: packages/zod: dependencies: - '@zenstackhq/orm': + '@zenstackhq/common-helpers': specifier: workspace:* - version: link:../orm + version: link:../common-helpers + '@zenstackhq/schema': + specifier: workspace:* + version: link:../schema + decimal.js: + specifier: 'catalog:' + version: 10.6.0 + json-stable-stringify: + specifier: ^1.3.0 + version: 1.3.0 ts-pattern: specifier: 'catalog:' version: 5.7.1 @@ -841,6 +856,9 @@ importers: '@zenstackhq/typescript-config': specifier: workspace:* version: link:../config/typescript-config + '@zenstackhq/vitest-config': + specifier: workspace:* + version: link:../config/vitest-config zod: specifier: ^4.1.0 version: 4.1.12 diff --git a/scripts/test-generate.ts b/scripts/test-generate.ts index e799ecf45..0af24290e 100644 --- a/scripts/test-generate.ts +++ b/scripts/test-generate.ts @@ -21,9 +21,12 @@ async function main() { async function generate(schemaPath: string, options: string[]) { const cliPath = path.join(_dirname, '../packages/cli/dist/index.js'); const RUNTIME = process.env.RUNTIME ?? 'node'; - execSync(`${RUNTIME} ${cliPath} generate --schema ${schemaPath} ${options.join(' ')}`, { - cwd: path.dirname(schemaPath), - }); + execSync( + `${RUNTIME} ${cliPath} generate --schema ${schemaPath} ${options.join(' ')} --generate-models=false --generate-input=false`, + { + cwd: path.dirname(schemaPath), + }, + ); } main(); diff --git a/tests/e2e/orm/client-api/create-many.test.ts b/tests/e2e/orm/client-api/create-many.test.ts index c55f05ac4..af944f4f9 100644 --- a/tests/e2e/orm/client-api/create-many.test.ts +++ b/tests/e2e/orm/client-api/create-many.test.ts @@ -15,11 +15,6 @@ describe('Client createMany tests', () => { }); it('works with toplevel createMany', async () => { - // empty - await expect(client.user.createMany()).resolves.toMatchObject({ - count: 0, - }); - // single await expect( client.user.createMany({ diff --git a/tests/e2e/orm/client-api/zod.test-d.ts b/tests/e2e/orm/client-api/zod.test-d.ts new file mode 100644 index 000000000..148c47901 --- /dev/null +++ b/tests/e2e/orm/client-api/zod.test-d.ts @@ -0,0 +1,366 @@ +import { definePlugin, type ClientContract, type ClientOptions } from '@zenstackhq/orm'; +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expectTypeOf, it } from 'vitest'; +import z from 'zod'; +import { schema } from '../schemas/basic'; + +declare const client: ClientContract; + +describe('Zod schema typing tests', () => { + it('makeFindManySchema returns a typed schema', () => { + const s = client.$zod.makeFindManySchema('User'); + type Args = z.infer; + // all find args are optional + expectTypeOf>().toHaveProperty('where'); + expectTypeOf>().toHaveProperty('take'); + expectTypeOf>().toHaveProperty('skip'); + expectTypeOf>().toHaveProperty('orderBy'); + expectTypeOf>().toHaveProperty('select'); + expectTypeOf>().toHaveProperty('include'); + expectTypeOf>().toHaveProperty('cursor'); + }); + + it('makeFindUniqueSchema returns a typed schema', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + type Args = z.infer; + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('include'); + // where has id and email (unique fields) + expectTypeOf().toHaveProperty('id'); + expectTypeOf().toHaveProperty('email'); + }); + + it('makeFindFirstSchema returns a typed schema', () => { + const s = client.$zod.makeFindFirstSchema('User'); + type Args = z.infer; + expectTypeOf>().toHaveProperty('where'); + expectTypeOf>().toHaveProperty('take'); + }); + + it('makeExistsSchema returns a typed schema', () => { + const s = client.$zod.makeExistsSchema('User'); + type Args = NonNullable>; + expectTypeOf().toHaveProperty('where'); + expectTypeOf().not.toHaveProperty('select'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeCreateSchema returns a typed schema', () => { + const s = client.$zod.makeCreateSchema('User'); + type Args = z.infer; + // data is required + expectTypeOf().toHaveProperty('data'); + // select / include / omit are optional + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('include'); + expectTypeOf().toHaveProperty('omit'); + // data has required field email and optional field name + expectTypeOf().toHaveProperty('email'); + expectTypeOf().toHaveProperty('name'); + }); + + it('makeCreateManySchema returns a typed schema', () => { + const s = client.$zod.makeCreateManySchema('User'); + type Args = z.infer; + // data is required + expectTypeOf().toHaveProperty('data'); + // skipDuplicates is optional + expectTypeOf().toHaveProperty('skipDuplicates'); + // no select / include on createMany + expectTypeOf().not.toHaveProperty('select'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeCreateManyAndReturnSchema returns a typed schema', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + type Args = NonNullable>; + // data and skipDuplicates from createMany payload + expectTypeOf().toHaveProperty('data'); + expectTypeOf().toHaveProperty('skipDuplicates'); + // select and omit are supported; include is not + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('omit'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeUpdateSchema returns a typed schema', () => { + const s = client.$zod.makeUpdateSchema('User'); + type Args = z.infer; + // where (unique) and data are required + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('data'); + // select / include / omit are present + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('include'); + expectTypeOf().toHaveProperty('omit'); + // where is limited to unique fields (id and email) + expectTypeOf().toHaveProperty('id'); + expectTypeOf().toHaveProperty('email'); + // data has updatable fields + expectTypeOf().toHaveProperty('name'); + expectTypeOf().toHaveProperty('role'); + }); + + it('makeUpdateManySchema returns a typed schema', () => { + const s = client.$zod.makeUpdateManySchema('User'); + type Args = z.infer; + // data is required; where and limit are optional + expectTypeOf().toHaveProperty('data'); + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('limit'); + // no select / include on updateMany + expectTypeOf().not.toHaveProperty('select'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeUpdateManyAndReturnSchema returns a typed schema', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + type Args = z.infer; + // data is required; where and limit are optional + expectTypeOf().toHaveProperty('data'); + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('limit'); + // select and omit are present; include is not + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('omit'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeUpsertSchema returns a typed schema', () => { + const s = client.$zod.makeUpsertSchema('User'); + type Args = z.infer; + // where (unique), create, and update are all required + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('create'); + expectTypeOf().toHaveProperty('update'); + // select / include / omit are present + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('include'); + expectTypeOf().toHaveProperty('omit'); + // create has the required email field; update has optional fields + expectTypeOf().toHaveProperty('email'); + expectTypeOf().toHaveProperty('name'); + }); + + it('makeDeleteSchema returns a typed schema', () => { + const s = client.$zod.makeDeleteSchema('User'); + type Args = z.infer; + // where (unique) is required; no data field + expectTypeOf().toHaveProperty('where'); + expectTypeOf().not.toHaveProperty('data'); + // select / include / omit are present + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('include'); + expectTypeOf().toHaveProperty('omit'); + // where is limited to unique fields (id and email) + expectTypeOf().toHaveProperty('id'); + expectTypeOf().toHaveProperty('email'); + }); + + it('makeDeleteManySchema returns a typed schema', () => { + const s = client.$zod.makeDeleteManySchema('User'); + type Args = NonNullable>; + // where and limit are optional; no data field + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('limit'); + expectTypeOf().not.toHaveProperty('data'); + // no select / include on deleteMany + expectTypeOf().not.toHaveProperty('select'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeCountSchema returns a typed schema', () => { + const s = client.$zod.makeCountSchema('User'); + type Args = NonNullable>; + // where, select, skip, take, orderBy are present + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('select'); + expectTypeOf().toHaveProperty('skip'); + expectTypeOf().toHaveProperty('take'); + expectTypeOf().toHaveProperty('orderBy'); + // no data, include, omit + expectTypeOf().not.toHaveProperty('data'); + expectTypeOf().not.toHaveProperty('include'); + }); + + it('makeAggregateSchema returns a typed schema', () => { + const s = client.$zod.makeAggregateSchema('User'); + type Args = NonNullable>; + // standard query args + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('skip'); + expectTypeOf().toHaveProperty('take'); + expectTypeOf().toHaveProperty('orderBy'); + // aggregation operators + expectTypeOf().toHaveProperty('_count'); + expectTypeOf().toHaveProperty('_avg'); + expectTypeOf().toHaveProperty('_sum'); + expectTypeOf().toHaveProperty('_min'); + expectTypeOf().toHaveProperty('_max'); + }); + + it('makeGroupBySchema returns a typed schema', () => { + const s = client.$zod.makeGroupBySchema('User'); + type Args = z.infer; + // by is required; where, orderBy, having, skip, take, aggregations are optional + expectTypeOf().toHaveProperty('by'); + expectTypeOf().toHaveProperty('where'); + expectTypeOf().toHaveProperty('orderBy'); + expectTypeOf().toHaveProperty('having'); + expectTypeOf().toHaveProperty('skip'); + expectTypeOf().toHaveProperty('take'); + expectTypeOf().toHaveProperty('_count'); + }); +}); + +describe('Zod schema with slicing - typing', () => { + it('model exclusion removes relation field from include type', async () => { + type ExcludePostOptions = ClientOptions & { + slicing: { excludedModels: readonly ['Post'] }; + }; + const slicingClient = await createTestClient(schema, { + slicing: { excludedModels: ['Post'] as const }, + }); + const s = slicingClient.$zod.makeFindManySchema('User'); + type Include = NonNullable>['include']>; + // 'posts' relation is excluded → not in include type + expectTypeOf().not.toHaveProperty('posts'); + // 'profile' is not excluded → still in include type + expectTypeOf().toHaveProperty('profile'); + }); + + it('includedModels restricts relation fields in include type', async () => { + type IncludeUserProfileOptions = ClientOptions & { + slicing: { includedModels: readonly ['User', 'Profile'] }; + }; + const slicingClient = await createTestClient(schema, { + slicing: { includedModels: ['User', 'Profile'] as const }, + }); + const s = slicingClient.$zod.makeFindManySchema('User'); + type Include = NonNullable>['include']>; + // 'profile' points to Profile which is included + expectTypeOf().toHaveProperty('profile'); + // 'posts' points to Post which is NOT in includedModels → not in include type + expectTypeOf().not.toHaveProperty('posts'); + }); + + it('includedFilterKinds: Equality removes Range operators from number filter type', async () => { + type EqualityOnlyOptions = ClientOptions & { + slicing: { + models: { + user: { fields: { $all: { includedFilterKinds: readonly ['Equality'] } } }; + }; + }; + }; + const slicingClient = await createTestClient(schema, { + slicing: { + models: { user: { fields: { $all: { includedFilterKinds: ['Equality'] as const } } } }, + }, + }); + const s = slicingClient.$zod.makeFindManySchema('User'); + type Where = NonNullable>['where']>; + // Range operators are excluded → type error + // @ts-expect-error 'gt' is not a valid operator when only Equality is included + const _rangeInvalid: Where = { age: { gt: 25 } }; + void _rangeInvalid; + // Equality operators are still valid + const _equalityValid: Where = { age: { equals: 25 } }; + void _equalityValid; + }); + + it('includedFilterKinds: Equality removes Like operators from string filter type', async () => { + type EqualityOnlyOptions = ClientOptions & { + slicing: { + models: { + user: { fields: { $all: { includedFilterKinds: readonly ['Equality'] } } }; + }; + }; + }; + const slicingClient = await createTestClient(schema, { + slicing: { + models: { user: { fields: { $all: { includedFilterKinds: ['Equality'] as const } } } }, + }, + }); + const s = slicingClient.$zod.makeFindManySchema('User'); + type Where = NonNullable>['where']>; + // Like operators are excluded → type error + // @ts-expect-error 'contains' is not a valid operator when only Equality is included + const _likeInvalid: Where = { email: { contains: 'test' } }; + void _likeInvalid; + // Equality operators are still valid + const _equalityValid: Where = { email: { equals: 'test@example.com' } }; + void _equalityValid; + }); + + it('excludedFilterKinds: Range removes range operators while keeping equality and like', async () => { + type ExcludeRangeOptions = ClientOptions & { + slicing: { + models: { + user: { fields: { $all: { excludedFilterKinds: readonly ['Range'] } } }; + }; + }; + }; + const slicingClient = await createTestClient(schema, { + slicing: { + models: { user: { fields: { $all: { excludedFilterKinds: ['Range'] as const } } } }, + }, + }); + const s = slicingClient.$zod.makeFindManySchema('User'); + type Where = NonNullable>['where']>; + // Range operators are excluded → type error + // @ts-expect-error 'gt' is not a valid operator when Range is excluded + const _rangeInvalid: Where = { age: { gt: 25 } }; + void _rangeInvalid; + // Equality operators are still valid + const _equalityValid: Where = { age: { equals: 25 } }; + void _equalityValid; + // Like operators on string fields are still valid + const _likeValid: Where = { email: { contains: 'test' } }; + void _likeValid; + }); +}); + +describe('Zod schema with plugins - query args extension typing', () => { + const cachePlugin = definePlugin({ + id: 'cache', + queryArgs: { + $read: z.object({ cache: z.object({ ttl: z.number().optional() }).optional() }), + $create: z.object({ cache: z.object({ bust: z.boolean().optional() }).optional() }), + }, + }); + + it('find schema includes extended read args in type', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeFindManySchema('User'); + type Args = NonNullable>; + expectTypeOf().toHaveProperty('cache'); + expectTypeOf>().toHaveProperty('ttl'); + }); + + it('create schema includes create-specific extended args in type', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeCreateSchema('User'); + type Args = z.infer; + expectTypeOf().toHaveProperty('cache'); + // @ts-expect-error 'ttl' belongs to $read args, not $create args + const _invalid: Args = { data: { email: 'u@test.com' }, cache: { ttl: 1000 } }; + void _invalid; + }); + + it('$all extended args appear in all schema types', () => { + const sourcePlugin = definePlugin({ + id: 'source', + queryArgs: { + $all: z.object({ source: z.string().optional() }), + }, + }); + const extClient = client.$use(sourcePlugin); + const findSchema = extClient.$zod.makeFindManySchema('User'); + type FindArgs = NonNullable>; + expectTypeOf().toHaveProperty('source'); + const createSchema = extClient.$zod.makeCreateSchema('User'); + type CreateArgs = z.infer; + expectTypeOf().toHaveProperty('source'); + }); +}); diff --git a/tests/e2e/orm/client-api/zod.test.ts b/tests/e2e/orm/client-api/zod.test.ts new file mode 100644 index 000000000..377b68743 --- /dev/null +++ b/tests/e2e/orm/client-api/zod.test.ts @@ -0,0 +1,1063 @@ +import { createQuerySchemaFactory, definePlugin, type ClientContract } from '@zenstackhq/orm'; +import { createTestClient, getTestDbProvider } from '@zenstackhq/testtools'; +import { afterEach, beforeEach, describe as _describe, expect, it } from 'vitest'; +import z from 'zod'; +import { schema } from '../schemas/basic'; + +// only run for sqlite because schemas are provider independent +const describe = getTestDbProvider() === 'sqlite' ? _describe : _describe.skip; + +describe('Zod schema factory test', () => { + if (getTestDbProvider() !== 'sqlite') { + return; + } + + let client: ClientContract; + + beforeEach(async () => { + client = await createTestClient(schema); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + describe('CRUD schemas tests', () => { + // #region Find + + describe('makeFindManySchema', () => { + it('accepts undefined (all args optional)', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts valid where clause', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { email: { equals: 'u@test.com' } } }).success).toBe(true); + }); + + it('accepts string filter operators', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { email: { contains: 'test', startsWith: 'u' } } }).success).toBe(true); + }); + + it('accepts number filter operators', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { age: { gt: 18, lte: 65 } } }).success).toBe(true); + }); + + it('accepts enum filter', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { role: { in: ['USER', 'ADMIN'] } } }).success).toBe(true); + }); + + it('accepts logical combinators (AND/OR/NOT)', () => { + const s = client.$zod.makeFindManySchema('User'); + expect( + s.safeParse({ + where: { + AND: [{ email: { contains: 'test' } }, { role: 'USER' }], + }, + }).success, + ).toBe(true); + expect(s.safeParse({ where: { NOT: { email: { equals: 'admin@test.com' } } } }).success).toBe(true); + }); + + it('accepts relation filter', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { posts: { some: { published: true } } } }).success).toBe(true); + }); + + it('accepts pagination args (take, skip)', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ take: 10, skip: 20 }).success).toBe(true); + }); + + it('accepts orderBy', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ orderBy: { email: 'asc' } }).success).toBe(true); + expect(s.safeParse({ orderBy: [{ email: 'asc' }, { name: 'desc' }] }).success).toBe(true); + }); + + it('accepts select', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ select: { id: true, email: true } }).success).toBe(true); + }); + + it('accepts include', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ include: { posts: true, profile: true } }).success).toBe(true); + }); + + it('rejects select and include together', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ select: { id: true }, include: { posts: true } }).success).toBe(false); + }); + + it('rejects select and omit together', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ select: { id: true }, omit: { name: true } }).success).toBe(false); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + + it('rejects invalid enum value in where', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ where: { role: 'SUPERUSER' } }).success).toBe(false); + }); + + it('rejects non-integer take', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ take: 1.5 }).success).toBe(false); + }); + + it('accepts negative take (cursor-based pagination)', () => { + const s = client.$zod.makeFindManySchema('User'); + // negative take is valid: means "take the last N results" + expect(s.safeParse({ take: -1 }).success).toBe(true); + }); + }); + + describe('makeFindUniqueSchema', () => { + it('accepts where with unique id field', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect(s.safeParse({ where: { id: 'u1' } }).success).toBe(true); + }); + + it('accepts where with unique email field', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect(s.safeParse({ where: { email: 'u@test.com' } }).success).toBe(true); + }); + + it('accepts optional select/include', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, include: { posts: true } }).success).toBe(true); + }); + + it('rejects empty where', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect(s.safeParse({ where: {} }).success).toBe(false); + }); + + it('rejects non-unique where field', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + // name is not a unique field + expect(s.safeParse({ where: { name: 'Alice' } }).success).toBe(false); + }); + + it('rejects missing where', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect(s.safeParse({}).success).toBe(false); + }); + + it('rejects select and include together', () => { + const s = client.$zod.makeFindUniqueSchema('User'); + expect( + s.safeParse({ where: { id: 'u1' }, select: { id: true }, include: { posts: true } }).success, + ).toBe(false); + }); + }); + + describe('makeFindFirstSchema', () => { + it('accepts undefined (all args optional)', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts valid where clause', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ where: { email: { contains: 'test' } } }).success).toBe(true); + }); + + it('accepts non-unique where field (unlike findUnique)', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ where: { name: 'Alice' } }).success).toBe(true); + }); + + it('accepts pagination args (take, skip)', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ take: 1, skip: 5 }).success).toBe(true); + }); + + it('accepts orderBy', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ orderBy: { email: 'asc' } }).success).toBe(true); + }); + + it('accepts select', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ select: { id: true, email: true } }).success).toBe(true); + }); + + it('accepts include', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ include: { posts: true } }).success).toBe(true); + }); + + it('rejects select and include together', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ select: { id: true }, include: { posts: true } }).success).toBe(false); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + + it('rejects invalid enum value in where', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ where: { role: 'SUPERUSER' } }).success).toBe(false); + }); + + it('rejects non-integer take', () => { + const s = client.$zod.makeFindFirstSchema('User'); + expect(s.safeParse({ take: 1.5 }).success).toBe(false); + }); + }); + + describe('makeExistsSchema', () => { + it('accepts undefined (all optional)', () => { + const s = client.$zod.makeExistsSchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts empty object (where is optional)', () => { + const s = client.$zod.makeExistsSchema('User'); + expect(s.safeParse({}).success).toBe(true); + }); + + it('accepts where clause', () => { + const s = client.$zod.makeExistsSchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(true); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeExistsSchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + }); + + // #endregion + + // #region Create + + describe('makeCreateSchema', () => { + it('accepts minimal valid create input (required fields only)', () => { + const s = client.$zod.makeCreateSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' } }).success).toBe(true); + }); + + it('accepts full create input with optional fields', () => { + const s = client.$zod.makeCreateSchema('User'); + expect( + s.safeParse({ + data: { email: 'u@test.com', name: 'Alice', age: 30, role: 'ADMIN' }, + }).success, + ).toBe(true); + }); + + it('accepts nested relation in create (nested create)', () => { + const s = client.$zod.makeCreateSchema('User'); + expect( + s.safeParse({ + data: { + email: 'u@test.com', + posts: { create: { title: 'Hello' } }, + }, + }).success, + ).toBe(true); + }); + + it('accepts select/include in create args', () => { + const s = client.$zod.makeCreateSchema('User'); + expect( + s.safeParse({ + data: { email: 'u@test.com' }, + select: { id: true, email: true }, + }).success, + ).toBe(true); + }); + + it('rejects missing required field (email)', () => { + const s = client.$zod.makeCreateSchema('User'); + expect(s.safeParse({ data: {} }).success).toBe(false); + }); + + it('rejects invalid enum value', () => { + const s = client.$zod.makeCreateSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com', role: 'SUPERUSER' } }).success).toBe(false); + }); + + it('rejects unknown field in data', () => { + const s = client.$zod.makeCreateSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com', notAField: 'val' } }).success).toBe(false); + }); + + it('rejects missing data wrapper', () => { + const s = client.$zod.makeCreateSchema('User'); + expect(s.safeParse({ email: 'u@test.com' }).success).toBe(false); + }); + + it('rejects select and include together', () => { + const s = client.$zod.makeCreateSchema('User'); + expect( + s.safeParse({ data: { email: 'u@test.com' }, select: { id: true }, include: { posts: true } }) + .success, + ).toBe(false); + }); + }); + + describe('makeCreateManySchema', () => { + it('accepts single record as data', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' } }).success).toBe(true); + }); + + it('accepts array of records as data', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({ data: [{ email: 'a@test.com' }, { email: 'b@test.com' }] }).success).toBe(true); + }); + + it('accepts skipDuplicates flag', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' }, skipDuplicates: true }).success).toBe(true); + }); + + it('rejects missing data', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({}).success).toBe(false); + }); + + it('rejects unknown field in data', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com', notAField: 'val' } }).success).toBe(false); + }); + + it('rejects invalid enum value in data', () => { + const s = client.$zod.makeCreateManySchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com', role: 'SUPERUSER' } }).success).toBe(false); + }); + }); + + describe('makeCreateManyAndReturnSchema', () => { + it('accepts undefined (whole schema is optional)', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts single record as data', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' } }).success).toBe(true); + }); + + it('accepts array of records as data', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: [{ email: 'a@test.com' }, { email: 'b@test.com' }] }).success).toBe(true); + }); + + it('accepts skipDuplicates flag', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' }, skipDuplicates: true }).success).toBe(true); + }); + + it('accepts select', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' }, select: { id: true } }).success).toBe(true); + }); + + it('accepts omit', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' }, omit: { name: true } }).success).toBe(true); + }); + + it('rejects select and omit together', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect( + s.safeParse({ data: { email: 'u@test.com' }, select: { id: true }, omit: { name: true } }).success, + ).toBe(false); + }); + + it('rejects unknown field in data', () => { + const s = client.$zod.makeCreateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com', notAField: 'val' } }).success).toBe(false); + }); + }); + + // #endregion + + // #region Update + + describe('makeUpdateSchema', () => { + it('accepts valid update args', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, data: { name: 'Alice' } }).success).toBe(true); + }); + + it('accepts update with enum field', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, data: { role: 'ADMIN' } }).success).toBe(true); + }); + + it('accepts update with nested relation', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect( + s.safeParse({ + where: { id: 'u1' }, + data: { posts: { create: { title: 'New Post' } } }, + }).success, + ).toBe(true); + }); + + it('rejects missing where', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect(s.safeParse({ data: { name: 'Alice' } }).success).toBe(false); + }); + + it('rejects missing data', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect(s.safeParse({ where: { id: 'u1' } }).success).toBe(false); + }); + + it('rejects non-unique where', () => { + const s = client.$zod.makeUpdateSchema('User'); + // name is not a unique field + expect(s.safeParse({ where: { name: 'Alice' }, data: { name: 'Bob' } }).success).toBe(false); + }); + + it('rejects invalid enum value in data', () => { + const s = client.$zod.makeUpdateSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, data: { role: 'INVALID' } }).success).toBe(false); + }); + }); + + describe('makeUpdateManySchema', () => { + it('accepts valid updateMany (where is optional)', () => { + const s = client.$zod.makeUpdateManySchema('User'); + expect(s.safeParse({ data: { name: 'Updated' } }).success).toBe(true); + }); + + it('accepts updateMany with non-unique where', () => { + const s = client.$zod.makeUpdateManySchema('User'); + expect(s.safeParse({ where: { role: 'USER' }, data: { name: 'Updated' } }).success).toBe(true); + }); + + it('rejects missing data', () => { + const s = client.$zod.makeUpdateManySchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(false); + }); + }); + + describe('makeUpdateManyAndReturnSchema', () => { + it('accepts minimal valid args (data required)', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { name: 'Updated' } }).success).toBe(true); + }); + + it('accepts where clause', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ where: { role: 'USER' }, data: { name: 'Updated' } }).success).toBe(true); + }); + + it('accepts select', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { name: 'Updated' }, select: { id: true } }).success).toBe(true); + }); + + it('accepts omit', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { name: 'Updated' }, omit: { name: true } }).success).toBe(true); + }); + + it('rejects select and omit together', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect( + s.safeParse({ data: { name: 'Updated' }, select: { id: true }, omit: { name: true } }).success, + ).toBe(false); + }); + + it('rejects missing data', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(false); + }); + + it('rejects invalid enum in data', () => { + const s = client.$zod.makeUpdateManyAndReturnSchema('User'); + expect(s.safeParse({ data: { role: 'INVALID' } }).success).toBe(false); + }); + }); + + describe('makeUpsertSchema', () => { + it('accepts valid upsert args', () => { + const s = client.$zod.makeUpsertSchema('User'); + expect( + s.safeParse({ + where: { id: 'u1' }, + create: { email: 'u@test.com' }, + update: { name: 'Alice' }, + }).success, + ).toBe(true); + }); + + it('rejects missing create', () => { + const s = client.$zod.makeUpsertSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, update: { name: 'Alice' } }).success).toBe(false); + }); + + it('rejects missing update', () => { + const s = client.$zod.makeUpsertSchema('User'); + expect(s.safeParse({ where: { id: 'u1' }, create: { email: 'u@test.com' } }).success).toBe(false); + }); + + it('rejects missing where', () => { + const s = client.$zod.makeUpsertSchema('User'); + expect(s.safeParse({ create: { email: 'u@test.com' }, update: { name: 'Alice' } }).success).toBe(false); + }); + + it('rejects invalid enum in create', () => { + const s = client.$zod.makeUpsertSchema('User'); + expect( + s.safeParse({ + where: { id: 'u1' }, + create: { email: 'u@test.com', role: 'BAD' }, + update: {}, + }).success, + ).toBe(false); + }); + }); + + // #endregion + + // #region Delete + + describe('makeDeleteSchema', () => { + it('accepts valid delete args with unique where', () => { + const s = client.$zod.makeDeleteSchema('User'); + expect(s.safeParse({ where: { id: 'u1' } }).success).toBe(true); + }); + + it('accepts unique email in where', () => { + const s = client.$zod.makeDeleteSchema('User'); + expect(s.safeParse({ where: { email: 'u@test.com' } }).success).toBe(true); + }); + + it('rejects missing where', () => { + const s = client.$zod.makeDeleteSchema('User'); + expect(s.safeParse({}).success).toBe(false); + }); + + it('rejects empty where', () => { + const s = client.$zod.makeDeleteSchema('User'); + expect(s.safeParse({ where: {} }).success).toBe(false); + }); + + it('rejects non-unique where field', () => { + const s = client.$zod.makeDeleteSchema('User'); + expect(s.safeParse({ where: { name: 'Alice' } }).success).toBe(false); + }); + }); + + describe('makeDeleteManySchema', () => { + it('accepts undefined (where optional, deletes all)', () => { + const s = client.$zod.makeDeleteManySchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts non-unique where', () => { + const s = client.$zod.makeDeleteManySchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(true); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeDeleteManySchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + }); + + // #endregion + + // #region Aggregation + + describe('makeCountSchema', () => { + it('accepts undefined (all optional)', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts where clause', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(true); + }); + + it('accepts select: true (count all)', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ select: true }).success).toBe(true); + }); + + it('accepts select with field names', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ select: { id: true, email: true } }).success).toBe(true); + }); + + it('accepts take and skip', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ take: 10, skip: 5 }).success).toBe(true); + }); + + it('accepts orderBy', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ orderBy: { email: 'asc' } }).success).toBe(true); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeCountSchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + }); + + describe('makeAggregateSchema', () => { + it('accepts undefined (all optional)', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse(undefined).success).toBe(true); + }); + + it('accepts where clause', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ where: { role: 'USER' } }).success).toBe(true); + }); + + it('accepts _count: true', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ _count: true }).success).toBe(true); + }); + + it('accepts _count with field selection', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ _count: { id: true, email: true } }).success).toBe(true); + }); + + it('accepts _avg on numeric fields', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ _avg: { age: true } }).success).toBe(true); + }); + + it('accepts _sum on numeric fields', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ _sum: { age: true } }).success).toBe(true); + }); + + it('accepts _min and _max on non-array non-relation fields', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ _min: { age: true, email: true }, _max: { age: true } }).success).toBe(true); + }); + + it('accepts take and skip', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ take: 10, skip: 5 }).success).toBe(true); + }); + + it('accepts orderBy', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ orderBy: { age: 'asc' } }).success).toBe(true); + }); + + it('rejects unknown where field', () => { + const s = client.$zod.makeAggregateSchema('User'); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + }); + + describe('makeGroupBySchema', () => { + it('accepts single field in by', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role' }).success).toBe(true); + }); + + it('accepts multiple fields in by as array', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: ['role', 'name'] }).success).toBe(true); + }); + + it('rejects missing by', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({}).success).toBe(false); + }); + + it('rejects relation field in by', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'posts' }).success).toBe(false); + }); + + it('accepts where clause', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', where: { role: 'USER' } }).success).toBe(true); + }); + + it('accepts _count aggregation', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', _count: true }).success).toBe(true); + }); + + it('accepts _avg on numeric fields', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', _avg: { age: true } }).success).toBe(true); + }); + + it('accepts orderBy matching the by field', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', orderBy: { role: 'asc' } }).success).toBe(true); + }); + + it('rejects orderBy with a field not in by (without aggregation)', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', orderBy: { name: 'asc' } }).success).toBe(false); + }); + + it('accepts take and skip', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'role', take: 10, skip: 0 }).success).toBe(true); + }); + + it('rejects unknown field in by', () => { + const s = client.$zod.makeGroupBySchema('User'); + expect(s.safeParse({ by: 'notAField' }).success).toBe(false); + }); + }); + + // #endregion + }); + + // #region Slicing + + describe('slicing - model exclusion', () => { + it('excluded model relation is rejected in select', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { excludedModels: ['Post'] }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // 'posts' relation is excluded → rejected by strict schema + expect(s.safeParse({ select: { posts: true } }).success).toBe(false); + // scalar fields are unaffected + expect(s.safeParse({ select: { id: true, email: true } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('excluded model relation is rejected in include', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { excludedModels: ['Post'] }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + expect(s.safeParse({ include: { posts: true } }).success).toBe(false); + // 'profile' is not excluded, still allowed + expect(s.safeParse({ include: { profile: true } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('excluded model relation is rejected in create data', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { excludedModels: ['Post'] }, + }); + try { + const s = slicingClient.$zod.makeCreateSchema('User'); + expect( + s.safeParse({ + data: { email: 'u@test.com', posts: { create: { title: 'Hello' } } }, + }).success, + ).toBe(false); + // without the excluded relation, create still works + expect(s.safeParse({ data: { email: 'u@test.com' } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('excluded model relation is rejected in update data', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { excludedModels: ['Post'] }, + }); + try { + const s = slicingClient.$zod.makeUpdateSchema('User'); + expect( + s.safeParse({ + where: { id: 'u1' }, + data: { posts: { create: { title: 'New Post' } } }, + }).success, + ).toBe(false); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('includedModels restricts relations to allowed models only', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { includedModels: ['User', 'Profile'] }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // Post is not included → posts relation rejected + expect(s.safeParse({ include: { posts: true } }).success).toBe(false); + // Profile is included → profile relation accepted + expect(s.safeParse({ include: { profile: true } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + }); + + describe('slicing - filter kinds', () => { + it('includedFilterKinds restricts to equality operators only', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { + models: { + user: { fields: { $all: { includedFilterKinds: ['Equality'] as const } } }, + }, + }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // equality operators accepted + expect(s.safeParse({ where: { age: { equals: 25 } } }).success).toBe(true); + expect(s.safeParse({ where: { email: { in: ['a@b.com'] } } }).success).toBe(true); + // direct value still accepted (equality) + expect(s.safeParse({ where: { age: 25 } }).success).toBe(true); + // range operators rejected + expect(s.safeParse({ where: { age: { gt: 18 } } }).success).toBe(false); + expect(s.safeParse({ where: { age: { lte: 65 } } }).success).toBe(false); + expect(s.safeParse({ where: { age: { between: [10, 50] } } }).success).toBe(false); + // like operators rejected + expect(s.safeParse({ where: { email: { contains: 'test' } } }).success).toBe(false); + expect(s.safeParse({ where: { email: { startsWith: 'u' } } }).success).toBe(false); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('excludedFilterKinds removes specified operators while keeping others', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { + models: { + user: { fields: { $all: { excludedFilterKinds: ['Range'] as const } } }, + }, + }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // equality operators still work + expect(s.safeParse({ where: { age: { equals: 25 } } }).success).toBe(true); + // like operators still work for string fields + expect(s.safeParse({ where: { email: { contains: 'test' } } }).success).toBe(true); + // direct value still works + expect(s.safeParse({ where: { age: 25 } }).success).toBe(true); + // range operators rejected + expect(s.safeParse({ where: { age: { gt: 18 } } }).success).toBe(false); + expect(s.safeParse({ where: { age: { lte: 65 } } }).success).toBe(false); + expect(s.safeParse({ where: { age: { between: [10, 50] } } }).success).toBe(false); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('field-level filter overrides model-level $all', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { + models: { + user: { + fields: { + $all: { includedFilterKinds: ['Equality'] as const }, + // 'name' additionally allows Like operators + name: { includedFilterKinds: ['Equality', 'Like'] as const }, + }, + }, + }, + }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // 'name' has field-level override: allows Like + expect(s.safeParse({ where: { name: { contains: 'Alice' } } }).success).toBe(true); + expect(s.safeParse({ where: { name: { startsWith: 'A' } } }).success).toBe(true); + // 'email' falls back to $all: Equality only + expect(s.safeParse({ where: { email: { contains: 'test' } } }).success).toBe(false); + expect(s.safeParse({ where: { email: { equals: 'a@b.com' } } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('$all models fallback applies filter restrictions across all models', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { + models: { + $all: { fields: { $all: { includedFilterKinds: ['Equality'] as const } } }, + }, + }, + }); + try { + const userSchema = slicingClient.$zod.makeFindManySchema('User'); + const postSchema = slicingClient.$zod.makeFindManySchema('Post'); + // equality works for both models + expect(userSchema.safeParse({ where: { email: { equals: 'u@test.com' } } }).success).toBe(true); + expect(postSchema.safeParse({ where: { published: { equals: true } } }).success).toBe(true); + // range/like rejected for both models + expect(userSchema.safeParse({ where: { age: { gt: 18 } } }).success).toBe(false); + expect(postSchema.safeParse({ where: { title: { contains: 'hello' } } }).success).toBe(false); + } finally { + await slicingClient.$disconnect(); + } + }); + + it('Relation filter kind exclusion rejects relation-style filters on a field', async () => { + const slicingClient = await createTestClient(schema, { + slicing: { + models: { + user: { + fields: { + posts: { excludedFilterKinds: ['Relation'] as const }, + }, + }, + }, + }, + }); + try { + const s = slicingClient.$zod.makeFindManySchema('User'); + // relation-style filters on 'posts' are excluded + expect(s.safeParse({ where: { posts: { some: { published: true } } } }).success).toBe(false); + expect(s.safeParse({ where: { posts: { every: { published: true } } } }).success).toBe(false); + // scalar fields on the same model are unaffected + expect(s.safeParse({ where: { email: { equals: 'u@test.com' } } }).success).toBe(true); + } finally { + await slicingClient.$disconnect(); + } + }); + }); + + // #endregion + + // #region Plugin query args + + describe('plugin - query args extension', () => { + const cachePlugin = definePlugin({ + id: 'cache', + queryArgs: { + $read: z.object({ cache: z.strictObject({ ttl: z.number().min(0).optional() }).optional() }), + $create: z.object({ cache: z.strictObject({ bust: z.boolean().optional() }).optional() }), + }, + }); + + it('extended read args are accepted by find schema', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeFindManySchema('User'); + expect(s.safeParse({ cache: { ttl: 1000 } }).success).toBe(true); + }); + + it('extended read args are validated (min constraint)', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeFindManySchema('User'); + expect(s.safeParse({ cache: { ttl: -1 } }).success).toBe(false); + }); + + it('strict validation rejects unknown plugin arg keys', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeFindManySchema('User'); + expect(s.safeParse({ cache: { ttl: 100, unknown: true } }).success).toBe(false); + }); + + it('create-specific extended args are accepted by create schema', () => { + const extClient = client.$use(cachePlugin); + const s = extClient.$zod.makeCreateSchema('User'); + expect(s.safeParse({ data: { email: 'u@test.com' }, cache: { bust: true } }).success).toBe(true); + }); + + it('base client schema rejects extended args', () => { + const s = client.$zod.makeFindManySchema('User'); + expect(s.safeParse({ cache: { ttl: 1000 } }).success).toBe(false); + }); + + it('$all extended args appear in all operation schemas', () => { + const sourcePlugin = definePlugin({ + id: 'source', + queryArgs: { + $all: z.object({ source: z.string().optional() }), + }, + }); + const extClient = client.$use(sourcePlugin); + expect(extClient.$zod.makeFindManySchema('User').safeParse({ source: 'web' }).success).toBe(true); + expect( + extClient.$zod.makeCreateSchema('User').safeParse({ data: { email: 'u@test.com' }, source: 'web' }) + .success, + ).toBe(true); + }); + }); + + // #endregion + + // #region ZodSchemaFactory standalone constructor + + describe('create factory functions tests', () => { + it('can be constructed directly from client', async () => { + try { + const client = await createTestClient(schema); + const factory = createQuerySchemaFactory(client); + const s = factory.makeFindManySchema('User'); + expect(s.safeParse({ where: { email: 'u@test.com' } }).success).toBe(true); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + } finally { + await client.$disconnect(); + } + }); + + it('can be constructed directly from schema and options and produces equivalent schemas', () => { + const factory = createQuerySchemaFactory(schema); + const s = factory.makeFindManySchema('User'); + expect(s.safeParse({ where: { email: 'u@test.com' } }).success).toBe(true); + expect(s.safeParse({ where: { notAField: 'val' } }).success).toBe(false); + }); + }); + + // #endregion + + // #region makeProcedureParamSchema + + describe('makeProcedureParamSchema', () => { + it('works with scalar types', () => { + const s = client.$zod.makeProcedureParamSchema({ type: 'String' }); + expect(s.safeParse('hello').success).toBe(true); + expect(s.safeParse(42).success).toBe(false); + }); + + it('works with array types', () => { + const s = client.$zod.makeProcedureParamSchema({ type: 'String', array: true }); + expect(s.safeParse(['a', 'b', 'c']).success).toBe(true); + expect(s.safeParse('a').success).toBe(false); + expect(s.safeParse([1, 2, 3]).success).toBe(false); + }); + + it('works with optional types', () => { + const s = client.$zod.makeProcedureParamSchema({ type: 'String', optional: true }); + expect(s.safeParse('hello').success).toBe(true); + expect(s.safeParse(undefined).success).toBe(true); + expect(s.safeParse(42).success).toBe(false); + }); + + it('works with array and optional types combined', () => { + const s = client.$zod.makeProcedureParamSchema({ type: 'Int', array: true, optional: true }); + expect(s.safeParse([1, 2, 3]).success).toBe(true); + expect(s.safeParse(undefined).success).toBe(true); + expect(s.safeParse(1).success).toBe(false); + }); + + it('throws for unsupported type', () => { + expect(() => client.$zod.makeProcedureParamSchema({ type: 'NotAType' })).toThrow(); + }); + }); + + // #endregion +}); diff --git a/tests/e2e/vitest.config.ts b/tests/e2e/vitest.config.ts index 10606ce6c..f8b260ff4 100644 --- a/tests/e2e/vitest.config.ts +++ b/tests/e2e/vitest.config.ts @@ -6,6 +6,9 @@ export default mergeConfig( defineConfig({ test: { setupFiles: ['@zenstackhq/testtools'], + typecheck: { + enabled: true, + }, }, }), );