diff --git a/packages/cli/package.json b/packages/cli/package.json
index 2196fa74d..7ab08561f 100644
--- a/packages/cli/package.json
+++ b/packages/cli/package.json
@@ -36,7 +36,9 @@
         "./package.json": "./package.json"
     },
     "dependencies": {
+        "@dotenvx/dotenvx": "^1.51.0",
         "@zenstackhq/common-helpers": "workspace:*",
+        "@zenstackhq/schema": "workspace:*",
         "@zenstackhq/language": "workspace:*",
         "@zenstackhq/orm": "workspace:*",
         "@zenstackhq/sdk": "workspace:*",
diff --git a/packages/cli/src/actions/action-utils.ts b/packages/cli/src/actions/action-utils.ts
index d2e0ca2e9..86d55baa6 100644
--- a/packages/cli/src/actions/action-utils.ts
+++ b/packages/cli/src/actions/action-utils.ts
@@ -1,5 +1,5 @@
-import { loadDocument } from '@zenstackhq/language';
-import { isDataSource } from '@zenstackhq/language/ast';
+import { type ZModelServices, loadDocument } from '@zenstackhq/language';
+import { type Model, isDataSource } from '@zenstackhq/language/ast';
 import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
 import colors from 'colors';
 import fs from 'node:fs';
@@ -41,8 +41,22 @@ export function getSchemaFile(file?: string) {
     }
 }
 
-export async function loadSchemaDocument(schemaFile: string) {
-    const loadResult = await loadDocument(schemaFile);
+export async function loadSchemaDocument(
+    schemaFile: string,
+    opts?: { keepImports?: boolean; returnServices?: false },
+): Promise<Model>;
+export async function loadSchemaDocument(
+    schemaFile: string,
+    opts: { returnServices: true; keepImports?: boolean },
+): Promise<{ model: Model; services: ZModelServices }>;
+export async function loadSchemaDocument(
+    schemaFile: string,
+    opts: { returnServices?: boolean; keepImports?: boolean } = {},
+) {
+    const returnServices = opts.returnServices || false;
+    const keepImports = opts.keepImports || false;
+
+    const loadResult = await loadDocument(schemaFile, [], keepImports);
     if (!loadResult.success) {
         loadResult.errors.forEach((err) => {
             console.error(colors.red(err));
@@ -52,6 +66,9 @@
     loadResult.warnings.forEach((warn) => {
         console.warn(colors.yellow(warn));
     });
+
+    if (returnServices) return { model: loadResult.model, services: loadResult.services };
+
     return loadResult.model;
 }
diff --git a/packages/cli/src/actions/db.ts b/packages/cli/src/actions/db.ts
index 3d0108374..702268f0b 100644
--- a/packages/cli/src/actions/db.ts
+++ b/packages/cli/src/actions/db.ts
@@ -1,25 +1,54 @@
+import { config } from '@dotenvx/dotenvx';
+import { formatDocument, ZModelCodeGenerator } from '@zenstackhq/language';
+import { DataModel, Enum, type Model } from '@zenstackhq/language/ast';
+import colors from 'colors';
 import fs from 'node:fs';
+import path from 'node:path';
+import ora from 'ora';
 import { execPrisma } from '../utils/exec-utils';
-import { generateTempPrismaSchema, getSchemaFile, handleSubProcessError, requireDataSourceUrl } from './action-utils';
+import {
+    generateTempPrismaSchema,
+    getSchemaFile,
+    handleSubProcessError,
+    loadSchemaDocument,
+    requireDataSourceUrl,
+} from './action-utils';
+import { syncEnums, syncRelation, syncTable, type Relation } from './pull';
+import { providers } from './pull/provider';
+import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName } from './pull/utils';
+import type { DataSourceProviderType } from '@zenstackhq/schema';
 
-type Options = {
+type PushOptions = {
     schema?: string;
     acceptDataLoss?: boolean;
     forceReset?: boolean;
 };
 
+export type PullOptions = {
+    schema?: string;
+    out?: 
string; + modelCasing: 'pascal' | 'camel' | 'snake' | 'kebab' | 'none'; + fieldCasing: 'pascal' | 'camel' | 'snake' | 'kebab' | 'none'; + alwaysMap: boolean; + quote: 'single' | 'double'; + indent: number; +}; + /** * CLI action for db related commands */ -export async function run(command: string, options: Options) { +export async function run(command: string, options: any) { switch (command) { case 'push': await runPush(options); break; + case 'pull': + await runPull(options); + break; } } -async function runPush(options: Options) { +async function runPush(options: PushOptions) { const schemaFile = getSchemaFile(options.schema); // validate datasource url exists @@ -49,3 +78,347 @@ async function runPush(options: Options) { } } } + +async function runPull(options: PullOptions) { + const spinner = ora(); + try { + const schemaFile = getSchemaFile(options.schema); + const { model, services } = await loadSchemaDocument(schemaFile, { returnServices: true, keepImports: true }); + config({ + ignore: ['MISSING_ENV_FILE'], + }); + const SUPPORTED_PROVIDERS = Object.keys(providers) as DataSourceProviderType[]; + const datasource = getDatasource(model); + if (!datasource) { + throw new Error('No datasource found in the schema.'); + } + + if (!SUPPORTED_PROVIDERS.includes(datasource.provider)) { + throw new Error(`Unsupported datasource provider: ${datasource.provider}`); + } + + const provider = providers[datasource.provider]; + + if (!provider) { + throw new Error(`No introspection provider found for: ${datasource.provider}`); + } + + spinner.start('Introspecting database...'); + const { enums: allEnums, tables: allTables } = await provider.introspect(datasource.url); + spinner.succeed('Database introspected'); + + const enums = provider.isSupportedFeature('Schema') + ? allEnums.filter((e) => datasource.allSchemas.includes(e.schema_name)) + : allEnums; + const tables = provider.isSupportedFeature('Schema') + ? 
allTables.filter((t) => datasource.allSchemas.includes(t.schema)) + : allTables; + + console.log(colors.blue('Syncing schema...')); + + const newModel: Model = { + $type: 'Model', + $container: undefined, + $containerProperty: undefined, + $containerIndex: undefined, + declarations: [...model.declarations.filter((d) => ['DataSource'].includes(d.$type))], + imports: [], + }; + syncEnums({ + dbEnums: enums, + model: newModel, + services, + options, + defaultSchema: datasource.defaultSchema, + oldModel: model, + provider, + }); + + const resolvedRelations: Relation[] = []; + for (const table of tables) { + const relations = syncTable({ + table, + model: newModel, + provider, + services, + options, + defaultSchema: datasource.defaultSchema, + oldModel: model, + }); + resolvedRelations.push(...relations); + } + // sync relation fields + for (const relation of resolvedRelations) { + const similarRelations = resolvedRelations.filter((rr) => { + return ( + rr !== relation && + ((rr.schema === relation.schema && + rr.table === relation.table && + rr.references.schema === relation.references.schema && + rr.references.table === relation.references.table) || + (rr.schema === relation.references.schema && + rr.column === relation.references.column && + rr.references.schema === relation.schema && + rr.references.table === relation.table)) + ); + }).length; + const selfRelation = + relation.references.schema === relation.schema && relation.references.table === relation.table; + syncRelation({ + model: newModel, + relation, + services, + options, + selfRelation, + similarRelations: similarRelations, + }); + } + + console.log(colors.blue('Schema synced')); + + const cwd = new URL(`file://${process.cwd()}`).pathname; + const docs = services.shared.workspace.LangiumDocuments.all + .filter(({ uri }) => uri.path.toLowerCase().startsWith(cwd.toLowerCase())) + .toArray(); + const docsSet = new Set(docs.map((d) => d.uri.toString())); + + console.log(colors.bold('\nApplying changes to ZModel...')); + + const deletedModels: string[] = []; + const deletedEnums: string[] = []; + const addedFields: string[] = []; + const deletedAttributes: string[] = []; + const deletedFields: string[] = []; + + //Delete models + services.shared.workspace.IndexManager.allElements('DataModel', docsSet) + .filter( + (declaration) => + !newModel.declarations.find((d) => getDbName(d) === getDbName(declaration.node as any)), + ) + .forEach((decl) => { + const model = decl.node!.$container as Model; + const index = model.declarations.findIndex((d) => d === decl.node); + model.declarations.splice(index, 1); + deletedModels.push(colors.red(`- Model ${decl.name} deleted`)); + }); + + // Delete Enums + if (provider.isSupportedFeature('NativeEnum')) + services.shared.workspace.IndexManager.allElements('Enum', docsSet) + .filter( + (declaration) => + !newModel.declarations.find((d) => getDbName(d) === getDbName(declaration.node as any)), + ) + .forEach((decl) => { + const model = decl.node!.$container as Model; + const index = model.declarations.findIndex((d) => d === decl.node); + model.declarations.splice(index, 1); + deletedEnums.push(colors.red(`- Enum ${decl.name} deleted`)); + }); + // + newModel.declarations + .filter((d) => [DataModel, Enum].includes(d.$type)) + .forEach((_declaration) => { + const newDataModel = _declaration as DataModel | Enum; + const declarations = services.shared.workspace.IndexManager.allElements( + newDataModel.$type, + docsSet, + ).toArray(); + const originalDataModel = declarations.find((d) => getDbName(d.node as 
any) === getDbName(newDataModel))
+                    ?.node as DataModel | Enum | undefined;
+                if (!originalDataModel) {
+                    model.declarations.push(newDataModel);
+                    (newDataModel as any).$container = model;
+                    newDataModel.fields.forEach((f) => {
+                        if (f.$type === 'DataField' && f.type.reference?.ref) {
+                            const ref = declarations.find(
+                                (d) => getDbName(d.node as any) === getDbName(f.type.reference!.ref as any),
+                            )?.node;
+                            if (ref) (f.type.reference.ref as any) = ref;
+                        }
+                    });
+                    return;
+                }
+
+                newDataModel.fields.forEach((f) => {
+                    // Prioritized matching: exact db name > relation fields key > relation FK name > type reference
+                    let originalFields = originalDataModel.fields.filter((d) => getDbName(d) === getDbName(f));
+
+                    if (originalFields.length === 0) {
+                        // Try matching by relation fields key (the `fields` attribute in @relation)
+                        // This matches relation fields by their FK field references
+                        const newFieldsKey = getRelationFieldsKey(f as any);
+                        if (newFieldsKey) {
+                            originalFields = originalDataModel.fields.filter(
+                                (d) => getRelationFieldsKey(d as any) === newFieldsKey,
+                            );
+                        }
+                    }
+
+                    if (originalFields.length === 0) {
+                        // Try matching by relation FK name (the `map` attribute in @relation)
+                        originalFields = originalDataModel.fields.filter(
+                            (d) =>
+                                getRelationFkName(d as any) === getRelationFkName(f as any) &&
+                                !!getRelationFkName(d as any) &&
+                                !!getRelationFkName(f as any),
+                        );
+                    }
+
+                    if (originalFields.length === 0) {
+                        // Try matching by type reference
+                        originalFields = originalDataModel.fields.filter(
+                            (d) =>
+                                f.$type === 'DataField' &&
+                                d.$type === 'DataField' &&
+                                f.type.reference?.ref &&
+                                d.type.reference?.ref &&
+                                getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref),
+                        );
+                    }
+
+                    if (originalFields.length > 1) {
+                        // If this is a back-reference relation field (no `fields` attribute),
+                        // silently skip when there are multiple potential matches
+                        const isBackReferenceField = !getRelationFieldsKey(f as any);
+                        if (!isBackReferenceField) {
+                            console.warn(
+                                colors.yellow(
+                                    `Found multiple matching original fields; the matching heuristics may need tuning. 
${originalDataModel.name}->[${originalFields.map((of) => of.name).join(', ')}](${f.name})`, + ), + ); + } + return; + } + const originalField = originalFields.at(0); + Object.freeze(originalField); + if (!originalField) { + addedFields.push(colors.green(`+ Field ${f.name} added to ${originalDataModel.name}`)); + (f as any).$container = originalDataModel; + originalDataModel.fields.push(f as any); + if (f.$type === 'DataField' && f.type.reference?.ref) { + const ref = declarations.find( + (d) => getDbName(d.node as any) === getDbName(f.type.reference!.ref as any), + )?.node as DataModel | undefined; + if (ref) { + (f.type.reference.$refText as any) = ref.name; + (f.type.reference.ref as any) = ref; + } + } + return; + } + + originalField.attributes + .filter( + (attr) => + !f.attributes.find((d) => d.decl.$refText === attr.decl.$refText) && + !['@map', '@@map', '@default', '@updatedAt'].includes(attr.decl.$refText), + ) + .forEach((attr) => { + const field = attr.$container; + const index = field.attributes.findIndex((d) => d === attr); + field.attributes.splice(index, 1); + deletedAttributes.push( + colors.yellow(`- Attribute ${attr.decl.$refText} deleted from field: ${field.name}`), + ); + }); + }); + originalDataModel.fields + .filter((f) => { + // Prioritized matching: exact db name > relation fields key > relation FK name > type reference + const matchByDbName = newDataModel.fields.find((d) => getDbName(d) === getDbName(f)); + if (matchByDbName) return false; + + // Try matching by relation fields key (the `fields` attribute in @relation) + const originalFieldsKey = getRelationFieldsKey(f as any); + if (originalFieldsKey) { + const matchByFieldsKey = newDataModel.fields.find( + (d) => getRelationFieldsKey(d as any) === originalFieldsKey, + ); + if (matchByFieldsKey) return false; + } + + const matchByFkName = newDataModel.fields.find( + (d) => + getRelationFkName(d as any) === getRelationFkName(f as any) && + !!getRelationFkName(d as any) && + !!getRelationFkName(f as any), + ); + if (matchByFkName) return false; + + const matchByTypeRef = newDataModel.fields.find( + (d) => + f.$type === 'DataField' && + d.$type === 'DataField' && + f.type.reference?.ref && + d.type.reference?.ref && + getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref), + ); + return !matchByTypeRef; + }) + .forEach((f) => { + const _model = f.$container; + const index = _model.fields.findIndex((d) => d === f); + _model.fields.splice(index, 1); + deletedFields.push(colors.red(`- Field ${f.name} deleted from ${_model.name}`)); + }); + }); + + if (deletedModels.length > 0) { + console.log(colors.bold('\nDeleted Models:')); + deletedModels.forEach((msg) => console.log(msg)); + } + + if (deletedEnums.length > 0) { + console.log(colors.bold('\nDeleted Enums:')); + deletedEnums.forEach((msg) => console.log(msg)); + } + + if (addedFields.length > 0) { + console.log(colors.bold('\nAdded Fields:')); + addedFields.forEach((msg) => console.log(msg)); + } + + if (deletedAttributes.length > 0) { + console.log(colors.bold('\nDeleted Attributes:')); + deletedAttributes.forEach((msg) => console.log(msg)); + } + + if (deletedFields.length > 0) { + console.log(colors.bold('\nDeleted Fields:')); + deletedFields.forEach((msg) => console.log(msg)); + } + + if (options.out && fs.existsSync(options.out) && !fs.lstatSync(options.out).isFile()) { + throw new Error(`Output path ${options.out} exists but is not a file`); + } + + const generator = new ZModelCodeGenerator({ + quote: options.quote, + indent: options.indent, + }); + 
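The add/delete detection above resolves each field through a prioritized matching cascade: exact db name first, then the `fields` list of `@relation`, then the relation's FK `map` name, and finally the referenced type. A minimal standalone sketch of that first-match-wins idea, using a hypothetical flattened `FieldKey` shape instead of the real ZModel AST nodes (`matchField` and `FieldKey` are illustrative names, not part of the PR):

```ts
// Hypothetical, simplified field shape for illustration only; the real code
// matches Langium AST nodes via getDbName/getRelationFieldsKey/getRelationFkName.
interface FieldKey {
    dbName: string;
    relationFieldsKey?: string; // derived from @relation(fields: [...])
    relationFkName?: string;    // derived from @relation(map: "...")
    typeRef?: string;           // db name of the referenced model
}

// Run each matcher in priority order; the first one that yields any
// candidates wins, mirroring the `if (originalFields.length === 0)` chain.
function matchField(target: FieldKey, candidates: FieldKey[]): FieldKey[] {
    const matchers: Array<(c: FieldKey) => boolean> = [
        (c) => c.dbName === target.dbName,
        (c) => !!target.relationFieldsKey && c.relationFieldsKey === target.relationFieldsKey,
        (c) => !!target.relationFkName && c.relationFkName === target.relationFkName,
        (c) => !!target.typeRef && c.typeRef === target.typeRef,
    ];
    for (const matcher of matchers) {
        const hits = candidates.filter(matcher);
        if (hits.length > 0) return hits; // ambiguity (length > 1) is handled by the caller
    }
    return [];
}
```

First-match-wins keeps a renamed column attached to its original field as long as any stable identifier (FK reference, constraint name, or target type) survives the rename.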
+ if (options.out) { + const zmodelSchema = await formatDocument(generator.generate(newModel)); + + console.log(colors.blue(`Writing to ${options.out}`)); + + const outPath = options.out ? path.resolve(options.out) : schemaFile; + + fs.writeFileSync(outPath, zmodelSchema); + } else { + for (const { uri, parseResult: { value: model } } of docs) { + const zmodelSchema = await formatDocument(generator.generate(model)); + console.log(colors.blue(`Writing to ${uri.path}`)); + fs.writeFileSync(uri.fsPath, zmodelSchema); + } + } + + console.log(colors.green.bold('\nPull completed successfully!')); + } catch (error) { + spinner.fail('Pull failed'); + console.error(error); + throw error; + } +} \ No newline at end of file diff --git a/packages/cli/src/actions/pull/index.ts b/packages/cli/src/actions/pull/index.ts new file mode 100644 index 000000000..e68513961 --- /dev/null +++ b/packages/cli/src/actions/pull/index.ts @@ -0,0 +1,568 @@ +import type { ZModelServices } from '@zenstackhq/language'; +import colors from 'colors'; +import { + isEnum, + type DataField, + type DataModel, + type Enum, + type Model, +} from '@zenstackhq/language/ast'; +import { + DataFieldAttributeFactory, + DataFieldFactory, + DataModelFactory, + EnumFactory, +} from '@zenstackhq/language/factory'; +import type { PullOptions } from '../db'; +import type { Cascade, IntrospectedEnum, IntrospectedTable, IntrospectionProvider } from './provider'; +import { getAttributeRef, getDbName, getEnumRef } from './utils'; + +export function syncEnums({ + dbEnums, + model, + oldModel, + provider, + options, + services, + defaultSchema, +}: { + dbEnums: IntrospectedEnum[]; + model: Model; + oldModel: Model; + provider: IntrospectionProvider; + services: ZModelServices; + options: PullOptions; + defaultSchema: string; +}) { + if (provider.isSupportedFeature('NativeEnum')) { + for (const dbEnum of dbEnums) { + const { modified, name } = resolveNameCasing(options.modelCasing, dbEnum.enum_type); + if (modified) console.log(colors.gray(`Mapping enum ${dbEnum.enum_type} to ${name}`)); + const factory = new EnumFactory().setName(name); + if (modified || options.alwaysMap) + factory.addAttribute((builder) => + builder + .setDecl(getAttributeRef('@@map', services)) + .addArg((argBuilder) => argBuilder.StringLiteral.setValue(dbEnum.enum_type)), + ); + + dbEnum.values.forEach((v) => { + const { name, modified } = resolveNameCasing(options.fieldCasing, v); + factory.addField((builder) => { + builder.setName(name); + if (modified || options.alwaysMap) + builder.addAttribute((builder) => + builder + .setDecl(getAttributeRef('@map', services)) + .addArg((argBuilder) => argBuilder.StringLiteral.setValue(v)), + ); + + return builder; + }); + }); + + if (dbEnum.schema_name && dbEnum.schema_name !== '' && dbEnum.schema_name !== defaultSchema) { + factory.addAttribute((b) => + b + .setDecl(getAttributeRef('@@schema', services)) + .addArg((a) => a.StringLiteral.setValue(dbEnum.schema_name)), + ); + } + + model.declarations.push(factory.get({ $container: model })); + } + } else { + oldModel.declarations + .filter((d) => isEnum(d)) + .forEach((d) => { + const factory = new EnumFactory().setName(d.name); + // Copy enum-level comments + if (d.comments?.length) { + factory.update({ comments: [...d.comments] }); + } + // Copy enum-level attributes (@@map, @@schema, etc.) 
+ if (d.attributes?.length) { + factory.update({ attributes: [...d.attributes] }); + } + // Copy fields with their attributes and comments + d.fields.forEach((v) => { + factory.addField((builder) => { + builder.setName(v.name); + // Copy field-level comments + if (v.comments?.length) { + v.comments.forEach((c) => builder.addComment(c)); + } + // Copy field-level attributes (@map, etc.) + if (v.attributes?.length) { + builder.update({ attributes: [...v.attributes] }); + } + return builder; + }); + }); + model.declarations.push(factory.get({ $container: model })); + }); + } +} + +function resolveNameCasing(casing: 'pascal' | 'camel' | 'snake' | 'kebab' | 'none', originalName: string) { + let name = originalName; + const fieldPrefix = /[0-9]/g.test(name.charAt(0)) ? '_' : ''; + + switch (casing) { + case 'pascal': + name = toPascalCase(originalName); + break; + case 'camel': + name = toCamelCase(originalName); + break; + case 'snake': + name = toSnakeCase(originalName); + break; + case 'kebab': + name = toKebabCase(originalName); + break; + } + + return { + modified: name !== originalName || fieldPrefix !== '', + name: `${fieldPrefix}${name}`, + }; +} + +function toPascalCase(str: string): string { + return str.replace(/[_\- ]+(\w)/g, (_, c) => c.toUpperCase()).replace(/^\w/, (c) => c.toUpperCase()); +} + +function toCamelCase(str: string): string { + return str.replace(/[_\- ]+(\w)/g, (_, c) => c.toUpperCase()).replace(/^\w/, (c) => c.toLowerCase()); +} + +function toSnakeCase(str: string): string { + return str + .replace(/[- ]+/g, '_') + .replace(/([a-z0-9])([A-Z])/g, '$1_$2') + .toLowerCase(); +} + +function toKebabCase(str: string): string { + return str + .replace(/[_ ]+/g, '-') + .replace(/([a-z0-9])([A-Z])/g, '$1-$2') + .toLowerCase(); +} + +export type Relation = { + schema: string; + table: string; + column: string; + type: 'one' | 'many'; + fk_name: string; + foreign_key_on_update: Cascade; + foreign_key_on_delete: Cascade; + nullable: boolean; + references: { + schema: string | null; + table: string | null; + column: string | null; + type: 'one' | 'many'; + }; +}; + +export function syncTable({ + model, + provider, + table, + services, + options, + defaultSchema, +}: { + table: IntrospectedTable; + model: Model; + oldModel: Model; + provider: IntrospectionProvider; + services: ZModelServices; + options: PullOptions; + defaultSchema: string; +}) { + const idAttribute = getAttributeRef('@id', services); + const modelIdAttribute = getAttributeRef('@@id', services); + const uniqueAttribute = getAttributeRef('@unique', services); + const modelUniqueAttribute = getAttributeRef('@@unique', services); + const relationAttribute = getAttributeRef('@relation', services); + const fieldMapAttribute = getAttributeRef('@map', services); + const tableMapAttribute = getAttributeRef('@@map', services); + const modelindexAttribute = getAttributeRef('@@index', services); + + if ( + !idAttribute || + !uniqueAttribute || + !relationAttribute || + !fieldMapAttribute || + !tableMapAttribute || + !modelIdAttribute || + !modelUniqueAttribute || + !modelindexAttribute + ) { + throw new Error('Cannot find required attributes in the model.'); + } + + const relations: Relation[] = []; + const { name, modified } = resolveNameCasing(options.modelCasing, table.name); + const multiPk = table.columns.filter((c) => c.pk).length > 1; + + const modelFactory = new DataModelFactory().setName(name).setIsView(table.type === 'view'); + modelFactory.setContainer(model); + + if (modified || options.alwaysMap) { + 
modelFactory.addAttribute((builder) => + builder.setDecl(tableMapAttribute).addArg((argBuilder) => argBuilder.StringLiteral.setValue(table.name)), + ); + } + table.columns.forEach((column) => { + if (column.foreign_key_table) { + relations.push({ + schema: table.schema, + table: table.name, + column: column.name, + type: 'one', + fk_name: column.foreign_key_name!, + foreign_key_on_delete: column.foreign_key_on_delete, + foreign_key_on_update: column.foreign_key_on_update, + nullable: column.nullable, + references: { + schema: column.foreign_key_schema, + table: column.foreign_key_table, + column: column.foreign_key_column, + type: column.unique ? 'one' : 'many', + }, + }); + } + + const { name, modified } = resolveNameCasing(options.fieldCasing, column.name); + + const builtinType = provider.getBuiltinType(column.datatype); + + modelFactory.addField((builder) => { + builder.setName(name); + builder.setType((typeBuilder) => { + typeBuilder.setArray(builtinType.isArray); + typeBuilder.setOptional(column.nullable); + + if (column.options.length > 0) { + const ref = model.declarations.find((d) => isEnum(d) && getDbName(d) === column.datatype) as + | Enum + | undefined; + + if (!ref) { + throw new Error(`Enum ${column.datatype} not found`); + } + typeBuilder.setReference(ref); + } else { + if (builtinType.type !== 'Unsupported') { + typeBuilder.setType(builtinType.type); + } else { + typeBuilder.setUnsupported((unsupportedBuilder) => + unsupportedBuilder.setValue((lt) => lt.StringLiteral.setValue(column.datatype)), + ); + } + } + + return typeBuilder; + }); + + if (column.pk && !multiPk) { + builder.addAttribute((b) => b.setDecl(idAttribute)); + } + + // Add field-type-based attributes (e.g., @updatedAt for DateTime fields, @db.* attributes) + const fieldAttrs = provider.getFieldAttributes({ + fieldName: column.name, + fieldType: builtinType.type, + datatype: column.datatype, + length: column.length, + precision: column.precision, + services, + }); + fieldAttrs.forEach(builder.addAttribute.bind(builder)); + + if (column.default) { + const defaultExprBuilder = provider.getDefaultValue({ + fieldType: builtinType.type, + defaultValue: column.default, + services, + enums: model.declarations.filter((d) => d.$type === 'Enum') as Enum[], + }); + if (defaultExprBuilder) { + const defaultAttr = new DataFieldAttributeFactory() + .setDecl(getAttributeRef('@default', services)) + .addArg(defaultExprBuilder); + builder.addAttribute(defaultAttr); + } + } + + if (column.unique && !column.pk) { + builder.addAttribute((b) => { + b.setDecl(uniqueAttribute); + // Only add map if the unique constraint name differs from default patterns + // Default patterns: TableName_columnName_key (Prisma) or just columnName (MySQL) + const isDefaultName = !column.unique_name + || column.unique_name === `${table.name}_${column.name}_key` + || column.unique_name === column.name; + if (!isDefaultName) { + b.addArg((ab) => ab.StringLiteral.setValue(column.unique_name!), 'map'); + } + + return b; + }); + } + if (modified || options.alwaysMap) { + builder.addAttribute((ab) => + ab.setDecl(fieldMapAttribute).addArg((ab) => ab.StringLiteral.setValue(column.name)), + ); + } + + return builder; + }); + }); + + const pkColumns = table.columns.filter((c) => c.pk).map((c) => c.name); + if (multiPk) { + modelFactory.addAttribute((builder) => + builder.setDecl(modelIdAttribute).addArg((argBuilder) => { + const arrayExpr = argBuilder.ArrayExpr; + pkColumns.forEach((c) => { + const ref = modelFactory.node.fields.find((f) => getDbName(f) === c); 
+ if (!ref) { + throw new Error(`Field ${c} not found`); + } + arrayExpr.addItem((itemBuilder) => itemBuilder.ReferenceExpr.setTarget(ref)); + }); + return arrayExpr; + }), + ); + } + + const hasUniqueConstraint = + table.columns.some((c) => c.unique || c.pk) || + table.indexes.some((i) => i.unique); + if (!hasUniqueConstraint) { + modelFactory.addAttribute((a) => a.setDecl(getAttributeRef('@@ignore', services))); + modelFactory.comments.push( + '/// The underlying table does not contain a valid unique identifier and can therefore currently not be handled by Zenstack Client.', + ); + } + + // Sort indexes: unique indexes first, then other indexes + const sortedIndexes = table.indexes.reverse().sort((a, b) => { + if (a.unique && !b.unique) return -1; + if (!a.unique && b.unique) return 1; + return 0; + }); + + sortedIndexes.forEach((index) => { + if (index.predicate) { + //These constraints are not supported by Zenstack, because Zenstack currently does not fully support check constraints. Read more: https://pris.ly/d/check-constraints + console.warn( + colors.yellow( + `These constraints are not supported by Zenstack. Read more: https://pris.ly/d/check-constraints\n- Model: "${table.name}", constraint: "${index.name}"`, + ), + ); + return; + } + if (index.columns.find((c) => c.expression)) { + console.warn( + colors.yellow( + `These constraints are not supported by Zenstack. Read more: https://pris.ly/d/check-constraints\n- Model: "${table.name}", constraint: "${index.name}"`, + ), + ); + return; + } + + // Skip PRIMARY key index (handled via @id or @@id) + if (index.primary) { + return; + } + + // Skip single-column indexes that are already handled by @id or @unique on the field + if (index.columns.length === 1 && (index.columns.find((c) => pkColumns.includes(c.name)) || index.unique)) { + return; + } + + modelFactory.addAttribute((builder) => + { + const attr = builder + .setDecl(index.unique ? modelUniqueAttribute : modelindexAttribute) + .addArg((argBuilder) => { + const arrayExpr = argBuilder.ArrayExpr; + index.columns.forEach((c) => { + const ref = modelFactory.node.fields.find((f) => getDbName(f) === c.name); + if (!ref) { + throw new Error(`Column ${c.name} not found in model ${table.name}`); + } + arrayExpr.addItem((itemBuilder) => { + const refExpr = itemBuilder.ReferenceExpr.setTarget(ref); + if (c.order && c.order !== 'ASC') + refExpr.addArg((ab) => ab.StringLiteral.setValue('DESC'), 'sort'); + + return refExpr; + }); + }); + return arrayExpr; + }); + + const suffix = index.unique ? 
'_key' : '_idx'; + + if(index.name !== `${table.name}_${index.columns.map(c => c.name).join('_')}${suffix}`){ + attr.addArg((argBuilder) => argBuilder.StringLiteral.setValue(index.name), 'map'); + } + + return attr + } + + ); + }); + if (table.schema && table.schema !== '' && table.schema !== defaultSchema) { + modelFactory.addAttribute((b) => + b.setDecl(getAttributeRef('@@schema', services)).addArg((a) => a.StringLiteral.setValue(table.schema)), + ); + } + + model.declarations.push(modelFactory.node); + return relations; +} + +export function syncRelation({ + model, + relation, + services, + options, + selfRelation, + similarRelations, +}: { + model: Model; + relation: Relation; + services: ZModelServices; + options: PullOptions; + //self included + similarRelations: number; + selfRelation: boolean; +}) { + const idAttribute = getAttributeRef('@id', services); + const uniqueAttribute = getAttributeRef('@unique', services); + const relationAttribute = getAttributeRef('@relation', services); + const fieldMapAttribute = getAttributeRef('@map', services); + const tableMapAttribute = getAttributeRef('@@map', services); + + const includeRelationName = selfRelation || similarRelations > 0; + + if (!idAttribute || !uniqueAttribute || !relationAttribute || !fieldMapAttribute || !tableMapAttribute) { + throw new Error('Cannot find required attributes in the model.'); + } + + const sourceModel = model.declarations.find((d) => d.$type === 'DataModel' && getDbName(d) === relation.table) as + | DataModel + | undefined; + if (!sourceModel) return; + + const sourceFieldId = sourceModel.fields.findIndex((f) => getDbName(f) === relation.column); + const sourceField = sourceModel.fields[sourceFieldId] as DataField | undefined; + if (!sourceField) return; + + const targetModel = model.declarations.find( + (d) => d.$type === 'DataModel' && getDbName(d) === relation.references.table, + ) as DataModel | undefined; + if (!targetModel) return; + + const targetField = targetModel.fields.find((f) => getDbName(f) === relation.references.column); + if (!targetField) return; + + const fieldPrefix = /[0-9]/g.test(sourceModel.name.charAt(0)) ? '_' : ''; + + const relationName = `${relation.table}${similarRelations > 0 ? `_${relation.column}` : ''}To${relation.references.table}`; + + const sourceNameFromReference = sourceField.name.toLowerCase().endsWith('id') ? `${resolveNameCasing("camel", sourceField.name.slice(0, -2)).name}${relation.type === 'many'? 's' : ''}` : undefined; + + const sourceFieldFromReference = sourceModel.fields.find((f) => f.name === sourceNameFromReference); + + let { name: sourceFieldName } = resolveNameCasing( + options.fieldCasing, + similarRelations > 0 + ? `${fieldPrefix}${sourceModel.name.charAt(0).toLowerCase()}${sourceModel.name.slice(1)}_${relation.column}` + : `${(!sourceFieldFromReference? sourceNameFromReference : undefined) || resolveNameCasing("camel", targetModel.name).name}${relation.type === 'many'? 
's' : ''}`,
    );

    if (sourceModel.fields.find((f) => f.name === sourceFieldName)) {
        sourceFieldName = `${sourceFieldName}To${targetModel.name.charAt(0).toLowerCase()}${targetModel.name.slice(1)}_${relation.references.column}`;
    }

    const sourceFieldFactory = new DataFieldFactory()
        .setContainer(sourceModel)
        .setName(sourceFieldName)
        .setType((tb) =>
            tb
                .setOptional(relation.nullable)
                .setArray(relation.type === 'many')
                .setReference(targetModel),
        );
    sourceFieldFactory.addAttribute((ab) => {
        ab.setDecl(relationAttribute);
        if (includeRelationName) ab.addArg((ab) => ab.StringLiteral.setValue(relationName));
        ab.addArg((ab) => ab.ArrayExpr.addItem((aeb) => aeb.ReferenceExpr.setTarget(sourceField)), 'fields').addArg(
            (ab) => ab.ArrayExpr.addItem((aeb) => aeb.ReferenceExpr.setTarget(targetField)),
            'references',
        );

        // Prisma defaults: onDelete is SetNull for optional, Restrict for mandatory
        const onDeleteDefault = relation.nullable ? 'SET NULL' : 'RESTRICT';
        if (relation.foreign_key_on_delete && relation.foreign_key_on_delete !== onDeleteDefault) {
            const enumRef = getEnumRef('ReferentialAction', services);
            if (!enumRef) throw new Error('ReferentialAction enum not found');
            const enumFieldRef = enumRef.fields.find(
                (f) => f.name.toLowerCase() === relation.foreign_key_on_delete!.replace(/ /g, '').toLowerCase(),
            );
            if (!enumFieldRef) throw new Error(`ReferentialAction ${relation.foreign_key_on_delete} not found`);
            ab.addArg((a) => a.ReferenceExpr.setTarget(enumFieldRef), 'onDelete');
        }

        // Prisma default: onUpdate is Cascade
        if (relation.foreign_key_on_update && relation.foreign_key_on_update !== 'CASCADE') {
            const enumRef = getEnumRef('ReferentialAction', services);
            if (!enumRef) throw new Error('ReferentialAction enum not found');
            const enumFieldRef = enumRef.fields.find(
                (f) => f.name.toLowerCase() === relation.foreign_key_on_update!.replace(/ /g, '').toLowerCase(),
            );
            if (!enumFieldRef) throw new Error(`ReferentialAction ${relation.foreign_key_on_update} not found`);
            ab.addArg((a) => a.ReferenceExpr.setTarget(enumFieldRef), 'onUpdate');
        }

        if (relation.fk_name && relation.fk_name !== `${relation.table}_${relation.column}_fkey`) ab.addArg((ab) => ab.StringLiteral.setValue(relation.fk_name), 'map');

        return ab;
    });

    sourceModel.fields.splice(sourceFieldId, 0, sourceFieldFactory.node); // Insert the relation field right before the original scalar foreign key field

    const oppositeFieldPrefix = /[0-9]/g.test(targetModel.name.charAt(0)) ? '_' : '';
    const { name: oppositeFieldName } = resolveNameCasing(
        options.fieldCasing,
        similarRelations > 0
            ? `${oppositeFieldPrefix}${sourceModel.name.charAt(0).toLowerCase()}${sourceModel.name.slice(1)}_${relation.column}`
            : `${resolveNameCasing("camel", sourceModel.name).name}${relation.references.type === 'many'? 
's' : ''}`,
    );

    const targetFieldFactory = new DataFieldFactory()
        .setContainer(targetModel)
        .setName(oppositeFieldName)
        .setType((tb) =>
            tb
                .setOptional(relation.references.type === 'one')
                .setArray(relation.references.type === 'many')
                .setReference(sourceModel),
        );
    if (includeRelationName)
        targetFieldFactory.addAttribute((ab) =>
            ab.setDecl(relationAttribute).addArg((ab) => ab.StringLiteral.setValue(relationName)),
        );

    targetModel.fields.push(targetFieldFactory.node);
}
diff --git a/packages/cli/src/actions/pull/provider/index.ts b/packages/cli/src/actions/pull/provider/index.ts
new file mode 100644
index 000000000..7c93746d4
--- /dev/null
+++ b/packages/cli/src/actions/pull/provider/index.ts
@@ -0,0 +1,13 @@
+import type { DataSourceProviderType } from '@zenstackhq/schema';
+export * from './provider';
+
+import { mysql } from './mysql';
+import { postgresql } from './postgresql';
+import type { IntrospectionProvider } from './provider';
+import { sqlite } from './sqlite';
+
+export const providers: Partial<Record<DataSourceProviderType, IntrospectionProvider>> = {
+    mysql,
+    postgresql,
+    sqlite,
+};
diff --git a/packages/cli/src/actions/pull/provider/mysql.ts b/packages/cli/src/actions/pull/provider/mysql.ts
new file mode 100644
index 000000000..cb104eb1e
--- /dev/null
+++ b/packages/cli/src/actions/pull/provider/mysql.ts
@@ -0,0 +1,525 @@
+import type { Attribute, BuiltinType } from '@zenstackhq/language/ast';
+import { DataFieldAttributeFactory } from '@zenstackhq/language/factory';
+import { getAttributeRef, getDbName, getFunctionRef } from '../utils';
+import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider';
+
+// Note: We dynamically import mysql2 inside the async function to avoid
+// requiring it at module load time for environments that don't use MySQL.
+
+export const mysql: IntrospectionProvider = {
+    isSupportedFeature(feature) {
+        switch (feature) {
+            case 'NativeEnum':
+                // MySQL enums are defined inline in column definitions, not as separate types.
+                // They can't be shared across tables like PostgreSQL enums.
+                // Return false to preserve existing enums from the schema.
+                return false;
+            case 'Schema':
+            default:
+                return false;
+        }
+    },
+    getBuiltinType(type) {
+        const t = (type || '').toLowerCase().trim();
+
+        // MySQL doesn't have native array types
+        const isArray = false;
+
+        switch (t) {
+            // integers
+            case 'tinyint':
+            case 'smallint':
+            case 'mediumint':
+            case 'int':
+            case 'integer':
+                return { type: 'Int', isArray };
+            case 'bigint':
+                return { type: 'BigInt', isArray };
+
+            // decimals and floats
+            case 'decimal':
+            case 'numeric':
+                return { type: 'Decimal', isArray };
+            case 'float':
+            case 'double':
+            case 'real':
+                return { type: 'Float', isArray };
+
+            // boolean (MySQL uses TINYINT(1) for boolean)
+            case 'boolean':
+            case 'bool':
+                return { type: 'Boolean', isArray };
+
+            // strings
+            case 'char':
+            case 'varchar':
+            case 'tinytext':
+            case 'text':
+            case 'mediumtext':
+            case 'longtext':
+                return { type: 'String', isArray };
+
+            // dates/times
+            case 'date':
+            case 'time':
+            case 'datetime':
+            case 'timestamp':
+            case 'year':
+                return { type: 'DateTime', isArray };
+
+            // binary
+            case 'binary':
+            case 'varbinary':
+            case 'tinyblob':
+            case 'blob':
+            case 'mediumblob':
+            case 'longblob':
+                return { type: 'Bytes', isArray };
+
+            // json
+            case 'json':
+                return { type: 'Json', isArray };
+
+            default:
+                // Handle ENUM type - MySQL returns enum values like "enum('val1','val2')"
+                if (t.startsWith('enum(')) {
+                    return { type: 'String', isArray };
+                }
+                // Handle SET type
+                if (t.startsWith('set(')) {
+                    return { type: 'String', isArray };
+                }
+                return { type: 'Unsupported' as const, isArray };
+        }
+    },
+    getDefaultDatabaseType(type: BuiltinType) {
+        switch (type) {
+            case 'String':
+                return { type: 'varchar', precision: 191 };
+            case 'Boolean':
+                // Boolean maps to 'boolean' (our synthetic type from tinyint(1))
+                // No precision needed since we handle the mapping in the query
+                return { type: 'boolean' };
+            case 'Int':
+                return { type: 'int' };
+            case 'BigInt':
+                return { type: 'bigint' };
+            case 'Float':
+                return { type: 'double' };
+            case 'Decimal':
+                return { type: 'decimal', precision: 65 };
+            case 'DateTime':
+                return { type: 'datetime', precision: 3 };
+            case 'Json':
+                return { type: 'json' };
+            case 'Bytes':
+                return { type: 'longblob' };
+        }
+    },
+    async introspect(connectionString: string): Promise<IntrospectedSchema> {
+        const mysql = await import('mysql2/promise');
+        const connection = await mysql.createConnection(connectionString);
+
+        try {
+            // Extract database name from connection string
+            const url = new URL(connectionString);
+            const databaseName = url.pathname.replace('/', '');
+
+            if (!databaseName) {
+                throw new Error('Database name not found in connection string');
+            }
+
+            // Introspect tables
+            const [tableRows] = (await connection.execute(getTableIntrospectionQuery(databaseName))) as [
+                IntrospectedTable[],
+                unknown,
+            ];
+            const tables: IntrospectedTable[] = [];
+
+            for (const row of tableRows) {
+                const columns = typeof row.columns === 'string' ? JSON.parse(row.columns) : row.columns;
+                const indexes = typeof row.indexes === 'string' ? JSON.parse(row.indexes) : row.indexes;
+
+                // Sort columns by ordinal_position to preserve database column order
+                const sortedColumns = (columns || []).sort(
+                    (a: { ordinal_position?: number }, b: { ordinal_position?: number }) =>
+                        (a.ordinal_position ?? 0) - (b.ordinal_position ?? 
0) + ); + + // Filter out auto-generated FK indexes (MySQL creates these automatically) + // Pattern: {Table}_{column}_fkey for single-column FK indexes + const filteredIndexes = (indexes || []).filter( + (idx: { name: string; columns: { name: string }[] }) => + !(idx.columns.length === 1 && idx.name === `${row.name}_${idx.columns[0]?.name}_fkey`) + ); + + tables.push({ + schema: '', // MySQL doesn't support multi-schema + name: row.name, + type: row.type as 'table' | 'view', + definition: row.definition, + columns: sortedColumns, + indexes: filteredIndexes, + }); + } + + // Introspect enums (MySQL stores enum values in column definitions) + const [enumRows] = (await connection.execute(getEnumIntrospectionQuery(databaseName))) as [ + { table_name: string; column_name: string; column_type: string }[], + unknown, + ]; + + const enums: IntrospectedEnum[] = enumRows.map((row) => { + // Parse enum values from column_type like "enum('val1','val2','val3')" + const values = parseEnumValues(row.column_type); + return { + schema_name: '', // MySQL doesn't support multi-schema + // Create a unique enum type name based on table and column + enum_type: `${row.table_name}_${row.column_name}`, + values, + }; + }); + + return { tables, enums }; + } finally { + await connection.end(); + } + }, + getDefaultValue({ defaultValue, fieldType, services, enums }) { + const val = defaultValue.trim(); + + // Handle NULL early + if (val.toUpperCase() === 'NULL') { + return null; + } + + switch (fieldType) { + case 'DateTime': + if (/^CURRENT_TIMESTAMP(\(\d*\))?$/i.test(val) || val.toLowerCase() === 'current_timestamp()' || val.toLowerCase() === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + // Fallback to string literal for other DateTime defaults + return (ab) => ab.StringLiteral.setValue(val); + + case 'Int': + case 'BigInt': + if (val.toLowerCase() === 'auto_increment') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val); + } + break; + + case 'Float': + if (/^-?\d+\.\d+$/.test(val)) { + const numVal = parseFloat(val); + return (ab) => ab.NumberLiteral.setValue(numVal === Math.floor(numVal) ? 
numVal.toFixed(1) : String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.0'); + } + break; + + case 'Decimal': + if (/^-?\d+\.\d+$/.test(val)) { + const numVal = parseFloat(val); + if (numVal === Math.floor(numVal)) { + return (ab) => ab.NumberLiteral.setValue(numVal.toFixed(2)); + } + return (ab) => ab.NumberLiteral.setValue(String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.00'); + } + break; + + case 'Boolean': + if (val === 'true' || val === '1' || val === "b'1'") { + return (ab) => ab.BooleanLiteral.setValue(true); + } + if (val === 'false' || val === '0' || val === "b'0'") { + return (ab) => ab.BooleanLiteral.setValue(false); + } + break; + + case 'String': + if (val.startsWith("'") && val.endsWith("'")) { + const strippedValue = val.slice(1, -1).replace(/''/g, "'"); + const enumDef = enums.find((e) => e.fields.find((v) => getDbName(v) === strippedValue)); + if (enumDef) { + const enumField = enumDef.fields.find((v) => getDbName(v) === strippedValue); + if (enumField) { + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } + return (ab) => ab.StringLiteral.setValue(strippedValue); + } + if (val.toLowerCase() === 'uuid()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('uuid', services)); + } + if (/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(val)) { + return (ab) => ab.StringLiteral.setValue(val); + } + break; + } + + // Fallback handlers for values that don't match field type-specific patterns + if (/^CURRENT_TIMESTAMP(\(\d*\))?$/i.test(val) || val.toLowerCase() === 'current_timestamp()' || val.toLowerCase() === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + + if (val.toLowerCase() === 'auto_increment') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + + if (val === 'true' || val === "b'1'") { + return (ab) => ab.BooleanLiteral.setValue(true); + } + if (val === 'false' || val === "b'0'") { + return (ab) => ab.BooleanLiteral.setValue(false); + } + + if (/^-?\d+\.\d+$/.test(val) || /^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val); + } + + if (val.startsWith("'") && val.endsWith("'")) { + const strippedValue = val.slice(1, -1).replace(/''/g, "'"); + const enumDef = enums.find((e) => e.fields.find((v) => getDbName(v) === strippedValue)); + if (enumDef) { + const enumField = enumDef.fields.find((v) => getDbName(v) === strippedValue); + if (enumField) { + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } + return (ab) => ab.StringLiteral.setValue(strippedValue); + } + + // Handle function calls (e.g., uuid(), now()) + if (val.includes('(') && val.includes(')')) { + if (val.toLowerCase() === 'uuid()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('uuid', services)); + } + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + } + + // Handle unquoted string values + if (/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(val)) { + return (ab) => ab.StringLiteral.setValue(val); + } + + // For any other unhandled cases, use dbgenerated + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + }, + + getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) { + const factories: DataFieldAttributeFactory[] = []; + + // 
Add @updatedAt for DateTime fields named updatedAt or updated_at + if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) { + factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services))); + } + + // Add @db.* attribute if the datatype differs from the default + const dbAttr = services.shared.workspace.IndexManager.allElements('Attribute').find( + (d) => d.name.toLowerCase() === `@db.${datatype.toLowerCase()}`, + )?.node as Attribute | undefined; + + const defaultDatabaseType = this.getDefaultDatabaseType(fieldType as BuiltinType); + + if ( + dbAttr && + defaultDatabaseType && + (defaultDatabaseType.type !== datatype || + (defaultDatabaseType.precision && + defaultDatabaseType.precision !== (length || precision))) + ) { + const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr); + if (length || precision) { + dbAttrFactory.addArg((a) => a.NumberLiteral.setValue(length! || precision!)); + } + factories.push(dbAttrFactory); + } + + return factories; + }, +}; + +function getTableIntrospectionQuery(databaseName: string) { + // Note: We use subqueries with ORDER BY before JSON_ARRAYAGG to ensure ordering + // since MySQL < 8.0.21 doesn't support ORDER BY inside JSON_ARRAYAGG + // MySQL doesn't support multi-schema, so we don't include schema in the result + return ` +SELECT + t.TABLE_NAME AS \`name\`, + CASE t.TABLE_TYPE + WHEN 'BASE TABLE' THEN 'table' + WHEN 'VIEW' THEN 'view' + ELSE NULL + END AS \`type\`, + CASE + WHEN t.TABLE_TYPE = 'VIEW' THEN v.VIEW_DEFINITION + ELSE NULL + END AS \`definition\`, + ( + SELECT JSON_ARRAYAGG(col_json) + FROM ( + SELECT JSON_OBJECT( + 'ordinal_position', c.ORDINAL_POSITION, + 'name', c.COLUMN_NAME, + 'datatype', CASE + WHEN c.DATA_TYPE = 'tinyint' AND c.COLUMN_TYPE = 'tinyint(1)' THEN 'boolean' + ELSE c.DATA_TYPE + END, + 'datatype_schema', '', + 'length', c.CHARACTER_MAXIMUM_LENGTH, + 'precision', COALESCE(c.NUMERIC_PRECISION, c.DATETIME_PRECISION), + 'nullable', c.IS_NULLABLE = 'YES', + 'default', CASE + WHEN c.EXTRA LIKE '%auto_increment%' THEN 'auto_increment' + ELSE c.COLUMN_DEFAULT + END, + 'pk', c.COLUMN_KEY = 'PRI', + 'unique', c.COLUMN_KEY = 'UNI', + 'unique_name', CASE WHEN c.COLUMN_KEY = 'UNI' THEN c.COLUMN_NAME ELSE NULL END, + 'computed', c.GENERATION_EXPRESSION IS NOT NULL AND c.GENERATION_EXPRESSION != '', + 'options', JSON_ARRAY(), + 'foreign_key_schema', NULL, + 'foreign_key_table', kcu_fk.REFERENCED_TABLE_NAME, + 'foreign_key_column', kcu_fk.REFERENCED_COLUMN_NAME, + 'foreign_key_name', kcu_fk.CONSTRAINT_NAME, + 'foreign_key_on_update', rc.UPDATE_RULE, + 'foreign_key_on_delete', rc.DELETE_RULE + ) AS col_json + FROM INFORMATION_SCHEMA.COLUMNS c + LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu_fk + ON c.TABLE_SCHEMA = kcu_fk.TABLE_SCHEMA + AND c.TABLE_NAME = kcu_fk.TABLE_NAME + AND c.COLUMN_NAME = kcu_fk.COLUMN_NAME + AND kcu_fk.REFERENCED_TABLE_NAME IS NOT NULL + LEFT JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc + ON kcu_fk.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA + AND kcu_fk.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + WHERE c.TABLE_SCHEMA = t.TABLE_SCHEMA + AND c.TABLE_NAME = t.TABLE_NAME + ORDER BY c.ORDINAL_POSITION + ) AS cols_ordered + ) AS \`columns\`, + ( + SELECT JSON_ARRAYAGG(idx_json) + FROM ( + SELECT JSON_OBJECT( + 'name', s.INDEX_NAME, + 'method', s.INDEX_TYPE, + 'unique', s.NON_UNIQUE = 0, + 'primary', s.INDEX_NAME = 'PRIMARY', + 'valid', TRUE, + 'ready', TRUE, + 'partial', FALSE, + 'predicate', NULL, + 
'columns', ( + SELECT JSON_ARRAYAGG(idx_col_json) + FROM ( + SELECT JSON_OBJECT( + 'name', s2.COLUMN_NAME, + 'expression', NULL, + 'order', CASE s2.COLLATION WHEN 'A' THEN 'ASC' WHEN 'D' THEN 'DESC' ELSE NULL END, + 'nulls', NULL + ) AS idx_col_json + FROM INFORMATION_SCHEMA.STATISTICS s2 + WHERE s2.TABLE_SCHEMA = s.TABLE_SCHEMA + AND s2.TABLE_NAME = s.TABLE_NAME + AND s2.INDEX_NAME = s.INDEX_NAME + ORDER BY s2.SEQ_IN_INDEX + ) AS idx_cols_ordered + ) + ) AS idx_json + FROM ( + SELECT DISTINCT INDEX_NAME, INDEX_TYPE, NON_UNIQUE, TABLE_SCHEMA, TABLE_NAME + FROM INFORMATION_SCHEMA.STATISTICS + WHERE TABLE_SCHEMA = t.TABLE_SCHEMA AND TABLE_NAME = t.TABLE_NAME + ) s + ) AS idxs_ordered + ) AS \`indexes\` +FROM INFORMATION_SCHEMA.TABLES t +LEFT JOIN INFORMATION_SCHEMA.VIEWS v + ON t.TABLE_SCHEMA = v.TABLE_SCHEMA AND t.TABLE_NAME = v.TABLE_NAME +WHERE t.TABLE_SCHEMA = '${databaseName}' + AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') + AND t.TABLE_NAME NOT LIKE '_prisma_migrations' +ORDER BY t.TABLE_NAME; +`; +} + +function getEnumIntrospectionQuery(databaseName: string) { + return ` +SELECT + c.TABLE_NAME AS table_name, + c.COLUMN_NAME AS column_name, + c.COLUMN_TYPE AS column_type +FROM INFORMATION_SCHEMA.COLUMNS c +WHERE c.TABLE_SCHEMA = '${databaseName}' + AND c.DATA_TYPE = 'enum' +ORDER BY c.TABLE_NAME, c.COLUMN_NAME; +`; +} + +/** + * Parse enum values from MySQL COLUMN_TYPE string like "enum('val1','val2','val3')" + */ +function parseEnumValues(columnType: string): string[] { + // Match the content inside enum(...) + const match = columnType.match(/^enum\((.+)\)$/i); + if (!match || !match[1]) return []; + + const valuesString = match[1]; + const values: string[] = []; + + // Parse quoted values, handling escaped quotes + let current = ''; + let inQuote = false; + let i = 0; + + while (i < valuesString.length) { + const char = valuesString[i]; + + if (char === "'" && !inQuote) { + inQuote = true; + i++; + continue; + } + + if (char === "'" && inQuote) { + // Check for escaped quote ('') + if (valuesString[i + 1] === "'") { + current += "'"; + i += 2; + continue; + } + // End of value + values.push(current); + current = ''; + inQuote = false; + i++; + // Skip comma and any whitespace + while (i < valuesString.length && (valuesString[i] === ',' || valuesString[i] === ' ')) { + i++; + } + continue; + } + + if (inQuote) { + current += char; + } + i++; + } + + return values; +} diff --git a/packages/cli/src/actions/pull/provider/postgresql.ts b/packages/cli/src/actions/pull/provider/postgresql.ts new file mode 100644 index 000000000..08a041b56 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/postgresql.ts @@ -0,0 +1,475 @@ +import type { Attribute, BuiltinType, Enum, Expression } from '@zenstackhq/language/ast'; +import { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory'; +import { Client } from 'pg'; +import { getAttributeRef, getDbName, getFunctionRef } from '../utils'; +import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider'; +import type { ZModelServices } from '@zenstackhq/language'; + +export const postgresql: IntrospectionProvider = { + isSupportedFeature(feature) { + switch (feature) { + case 'Schema': + return true; + default: + return false; + } + }, + getBuiltinType(type) { + const t = (type || '').toLowerCase(); + + const isArray = t.startsWith('_'); + + switch (t.replace(/^_/, '')) { + // integers + case 'int2': + case 'smallint': + case 'int4': + case 'integer': + return { 
type: 'Int', isArray };
+            case 'int8':
+            case 'bigint':
+                return { type: 'BigInt', isArray };
+
+            // decimals and floats
+            case 'numeric':
+            case 'decimal':
+                return { type: 'Decimal', isArray };
+            case 'float4':
+            case 'real':
+            case 'float8':
+            case 'double precision':
+                return { type: 'Float', isArray };
+
+            // boolean
+            case 'bool':
+            case 'boolean':
+                return { type: 'Boolean', isArray };
+
+            // strings
+            case 'text':
+            case 'varchar':
+            case 'bpchar':
+            case 'character varying':
+            case 'character':
+                return { type: 'String', isArray };
+
+            // uuid
+            case 'uuid':
+                return { type: 'String', isArray };
+
+            // dates/times
+            case 'date':
+            case 'time':
+            case 'timestamp':
+            case 'timestamptz':
+                return { type: 'DateTime', isArray };
+
+            // binary
+            case 'bytea':
+                return { type: 'Bytes', isArray };
+
+            // json
+            case 'json':
+            case 'jsonb':
+                return { type: 'Json', isArray };
+            default:
+                return { type: 'Unsupported' as const, isArray };
+        }
+    },
+    async introspect(connectionString: string): Promise<IntrospectedSchema> {
+        const client = new Client({ connectionString });
+        await client.connect();
+
+        try {
+            const { rows: tables } = await client.query(tableIntrospectionQuery);
+            const { rows: enums } = await client.query(enumIntrospectionQuery);
+
+            return {
+                enums,
+                tables,
+            };
+        } finally {
+            // Close the connection so the CLI process can exit cleanly (mirrors the MySQL provider)
+            await client.end();
+        }
+    },
+    getDefaultDatabaseType(type: BuiltinType) {
+        switch (type) {
+            case 'String':
+                return { type: 'text' };
+            case 'Boolean':
+                return { type: 'boolean' };
+            case 'Int':
+                return { type: 'integer' };
+            case 'BigInt':
+                return { type: 'bigint' };
+            case 'Float':
+                return { type: 'double precision' };
+            case 'Decimal':
+                return { type: 'decimal' };
+            case 'DateTime':
+                return { type: 'timestamp', precision: 3 };
+            case 'Json':
+                return { type: 'jsonb' };
+            case 'Bytes':
+                return { type: 'bytea' };
+        }
+    },
+    getDefaultValue({ defaultValue, fieldType, services, enums }) {
+        const val = defaultValue.trim();
+
+        switch (fieldType) {
+            case 'DateTime':
+                if (val === 'CURRENT_TIMESTAMP' || val === 'now()') {
+                    return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services));
+                }
+
+                if (val.includes('::')) {
+                    return typeCastingConvert({ defaultValue, enums, val, services });
+                }
+
+                // Fallback to string literal for other DateTime defaults
+                return (ab) => ab.StringLiteral.setValue(val);
+
+            case 'Int':
+            case 'BigInt':
+                if (val.startsWith('nextval(')) {
+                    return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services));
+                }
+
+                if (val.includes('::')) {
+                    return typeCastingConvert({ defaultValue, enums, val, services });
+                }
+
+                if (/^-?\d+$/.test(val)) {
+                    return (ab) => ab.NumberLiteral.setValue(val);
+                }
+                break;
+
+            case 'Float':
+                if (val.includes('::')) {
+                    return typeCastingConvert({ defaultValue, enums, val, services });
+                }
+
+                if (/^-?\d+\.\d+$/.test(val)) {
+                    const numVal = parseFloat(val);
+                    return (ab) => ab.NumberLiteral.setValue(numVal === Math.floor(numVal) ? 
numVal.toFixed(1) : String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.0'); + } + break; + + case 'Decimal': + if (val.includes('::')) { + return typeCastingConvert({defaultValue,enums,val,services}); + } + + if (/^-?\d+\.\d+$/.test(val)) { + const numVal = parseFloat(val); + if (numVal === Math.floor(numVal)) { + return (ab) => ab.NumberLiteral.setValue(numVal.toFixed(2)); + } + return (ab) => ab.NumberLiteral.setValue(String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.00'); + } + break; + + case 'Boolean': + if (val === 'true') { + return (ab) => ab.BooleanLiteral.setValue(true); + } + if (val === 'false') { + return (ab) => ab.BooleanLiteral.setValue(false); + } + break; + + case 'String': + if (val.includes('::')) { + return typeCastingConvert({defaultValue,enums,val,services}); + } + + if (val.startsWith("'") && val.endsWith("'")) { + return (ab) => ab.StringLiteral.setValue(val.slice(1, -1).replace(/''/g, "'")); + } + break; + } + + if (val.includes('::')) { + return typeCastingConvert({defaultValue,enums,val,services}); + } + + // Fallback handlers for values that don't match field type-specific patterns + if (val === 'CURRENT_TIMESTAMP' || val === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + + if (val.startsWith('nextval(')) { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + + if (val.includes('(') && val.includes(')')) { + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + } + + if (val === 'true' || val === 'false') { + return (ab) => ab.BooleanLiteral.setValue(val === 'true'); + } + + if (/^-?\d+\.\d+$/.test(val) || /^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val); + } + + if (val.startsWith("'") && val.endsWith("'")) { + return (ab) => ab.StringLiteral.setValue(val.slice(1, -1).replace(/''/g, "'")); + } + + return null; + }, + + getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) { + const factories: DataFieldAttributeFactory[] = []; + + // Add @updatedAt for DateTime fields named updatedAt or updated_at + if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) { + factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services))); + } + + // Add @db.* attribute if the datatype differs from the default + const dbAttr = services.shared.workspace.IndexManager.allElements('Attribute').find( + (d) => d.name.toLowerCase() === `@db.${datatype.toLowerCase()}`, + )?.node as Attribute | undefined; + + const defaultDatabaseType = this.getDefaultDatabaseType(fieldType as BuiltinType); + + if ( + dbAttr && + defaultDatabaseType && + (defaultDatabaseType.type !== datatype || + (defaultDatabaseType.precision && + defaultDatabaseType.precision !== (length || precision))) + ) { + const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr); + if (length || precision) { + dbAttrFactory.addArg((a) => a.NumberLiteral.setValue(length! 
|| precision!)); + } + factories.push(dbAttrFactory); + } + + return factories; + }, +}; + +const enumIntrospectionQuery = ` +SELECT + n.nspname AS schema_name, + t.typname AS enum_type, + coalesce(json_agg(e.enumlabel ORDER BY e.enumsortorder), '[]') AS values +FROM pg_type t +JOIN pg_enum e ON t.oid = e.enumtypid +JOIN pg_namespace n ON n.oid = t.typnamespace +GROUP BY schema_name, enum_type +ORDER BY schema_name, enum_type;`; + +const tableIntrospectionQuery = ` +SELECT + "ns"."nspname" AS "schema", + "cls"."relname" AS "name", + CASE "cls"."relkind" + WHEN 'r' THEN 'table' + WHEN 'v' THEN 'view' + ELSE NULL + END AS "type", + CASE + WHEN "cls"."relkind" = 'v' THEN pg_get_viewdef("cls"."oid", true) + ELSE NULL + END AS "definition", + ( + SELECT coalesce(json_agg(agg), '[]') + FROM ( + SELECT + "att"."attname" AS "name", + "typ"."typname" AS "datatype", + "tns"."nspname" AS "datatype_schema", + "c"."character_maximum_length" AS "length", + COALESCE("c"."numeric_precision", "c"."datetime_precision") AS "precision", + "fk_ns"."nspname" AS "foreign_key_schema", + "fk_cls"."relname" AS "foreign_key_table", + "fk_att"."attname" AS "foreign_key_column", + "fk_con"."conname" AS "foreign_key_name", + CASE "fk_con"."confupdtype" + WHEN 'a' THEN 'NO ACTION' + WHEN 'r' THEN 'RESTRICT' + WHEN 'c' THEN 'CASCADE' + WHEN 'n' THEN 'SET NULL' + WHEN 'd' THEN 'SET DEFAULT' + ELSE NULL + END AS "foreign_key_on_update", + CASE "fk_con"."confdeltype" + WHEN 'a' THEN 'NO ACTION' + WHEN 'r' THEN 'RESTRICT' + WHEN 'c' THEN 'CASCADE' + WHEN 'n' THEN 'SET NULL' + WHEN 'd' THEN 'SET DEFAULT' + ELSE NULL + END AS "foreign_key_on_delete", + "pk_con"."conkey" IS NOT NULL AS "pk", + ( + EXISTS ( + SELECT 1 + FROM "pg_catalog"."pg_constraint" AS "u_con" + WHERE "u_con"."contype" = 'u' + AND "u_con"."conrelid" = "cls"."oid" + AND array_length("u_con"."conkey", 1) = 1 + AND "att"."attnum" = ANY ("u_con"."conkey") + ) + OR EXISTS ( + SELECT 1 + FROM "pg_catalog"."pg_index" AS "u_idx" + WHERE "u_idx"."indrelid" = "cls"."oid" + AND "u_idx"."indisunique" = TRUE + AND "u_idx"."indnkeyatts" = 1 + AND "att"."attnum" = ANY ("u_idx"."indkey"::int2[]) + ) + ) AS "unique", + ( + SELECT COALESCE( + ( + SELECT "u_con"."conname" + FROM "pg_catalog"."pg_constraint" AS "u_con" + WHERE "u_con"."contype" = 'u' + AND "u_con"."conrelid" = "cls"."oid" + AND array_length("u_con"."conkey", 1) = 1 + AND "att"."attnum" = ANY ("u_con"."conkey") + LIMIT 1 + ), + ( + SELECT "u_idx_cls"."relname" + FROM "pg_catalog"."pg_index" AS "u_idx" + JOIN "pg_catalog"."pg_class" AS "u_idx_cls" ON "u_idx"."indexrelid" = "u_idx_cls"."oid" + WHERE "u_idx"."indrelid" = "cls"."oid" + AND "u_idx"."indisunique" = TRUE + AND "u_idx"."indnkeyatts" = 1 + AND "att"."attnum" = ANY ("u_idx"."indkey"::int2[]) + LIMIT 1 + ) + ) + ) AS "unique_name", + "att"."attgenerated" != '' AS "computed", + pg_get_expr("def"."adbin", "def"."adrelid") AS "default", + "att"."attnotnull" != TRUE AS "nullable", + coalesce( + ( + SELECT json_agg("enm"."enumlabel") AS "o" + FROM "pg_catalog"."pg_enum" AS "enm" + WHERE "enm"."enumtypid" = "typ"."oid" + ), + '[]' + ) AS "options" + + FROM "pg_catalog"."pg_attribute" AS "att" + + INNER JOIN "pg_catalog"."pg_type" AS "typ" ON "typ"."oid" = "att"."atttypid" + + INNER JOIN "pg_catalog"."pg_namespace" AS "tns" ON "tns"."oid" = "typ"."typnamespace" + + LEFT JOIN "information_schema"."columns" AS "c" ON "c"."table_schema" = "ns"."nspname" + AND "c"."table_name" = "cls"."relname" + AND "c"."column_name" = "att"."attname" + LEFT JOIN 
"pg_catalog"."pg_constraint" AS "pk_con" ON "pk_con"."contype" = 'p' + + AND "pk_con"."conrelid" = "cls"."oid" + AND "att"."attnum" = ANY ("pk_con"."conkey") + LEFT JOIN "pg_catalog"."pg_constraint" AS "fk_con" ON "fk_con"."contype" = 'f' + AND "fk_con"."conrelid" = "cls"."oid" + AND "att"."attnum" = ANY ("fk_con"."conkey") + LEFT JOIN "pg_catalog"."pg_class" AS "fk_cls" ON "fk_cls"."oid" = "fk_con"."confrelid" + LEFT JOIN "pg_catalog"."pg_namespace" AS "fk_ns" ON "fk_ns"."oid" = "fk_cls"."relnamespace" + LEFT JOIN "pg_catalog"."pg_attribute" AS "fk_att" ON "fk_att"."attrelid" = "fk_cls"."oid" + AND "fk_att"."attnum" = ANY ("fk_con"."confkey") + LEFT JOIN "pg_catalog"."pg_attrdef" AS "def" ON "def"."adrelid" = "cls"."oid" AND "def"."adnum" = "att"."attnum" + WHERE + "att"."attrelid" = "cls"."oid" + AND "att"."attnum" >= 0 + AND "att"."attisdropped" != TRUE + ORDER BY "att"."attnum" + ) AS agg + ) AS "columns", + ( + SELECT coalesce(json_agg(agg), '[]') + FROM ( + SELECT + "idx_cls"."relname" AS "name", + "am"."amname" AS "method", + "idx"."indisunique" AS "unique", + "idx"."indisprimary" AS "primary", + "idx"."indisvalid" AS "valid", + "idx"."indisready" AS "ready", + ("idx"."indpred" IS NOT NULL) AS "partial", + pg_get_expr("idx"."indpred", "idx"."indrelid") AS "predicate", + ( + SELECT json_agg( + json_build_object( + 'name', COALESCE("att"."attname", pg_get_indexdef("idx"."indexrelid", "s"."i", true)), + 'expression', CASE WHEN "att"."attname" IS NULL THEN pg_get_indexdef("idx"."indexrelid", "s"."i", true) ELSE NULL END, + 'order', CASE ((( "idx"."indoption"::int2[] )["s"."i"] & 1)) WHEN 1 THEN 'DESC' ELSE 'ASC' END, + 'nulls', CASE (((( "idx"."indoption"::int2[] )["s"."i"] >> 1) & 1)) WHEN 1 THEN 'NULLS FIRST' ELSE 'NULLS LAST' END + ) + ORDER BY "s"."i" + ) + FROM generate_subscripts("idx"."indkey"::int2[], 1) AS "s"("i") + LEFT JOIN "pg_catalog"."pg_attribute" AS "att" + ON "att"."attrelid" = "cls"."oid" + AND "att"."attnum" = ("idx"."indkey"::int2[])["s"."i"] + ) AS "columns" + FROM "pg_catalog"."pg_index" AS "idx" + JOIN "pg_catalog"."pg_class" AS "idx_cls" ON "idx"."indexrelid" = "idx_cls"."oid" + JOIN "pg_catalog"."pg_am" AS "am" ON "idx_cls"."relam" = "am"."oid" + WHERE "idx"."indrelid" = "cls"."oid" + ORDER BY "idx_cls"."relname" + ) AS agg + ) AS "indexes" +FROM "pg_catalog"."pg_class" AS "cls" +INNER JOIN "pg_catalog"."pg_namespace" AS "ns" ON "cls"."relnamespace" = "ns"."oid" +WHERE + "ns"."nspname" !~ '^pg_' + AND "ns"."nspname" != 'information_schema' + AND "cls"."relkind" IN ('r', 'v') + AND "cls"."relname" !~ '^pg_' + AND "cls"."relname" !~ '_prisma_migrations' + ORDER BY "ns"."nspname", "cls"."relname" ASC; +`; + +function typeCastingConvert({defaultValue, enums, val, services}:{val: string, enums: Enum[], defaultValue:string, services:ZModelServices}): ((builder: ExpressionBuilder) => AstFactory) | null { + const [value, type] = val + .replace(/'/g, '') + .split('::') + .map((s) => s.trim()) as [string, string]; + switch (type) { + case 'character varying': + case 'uuid': + case 'json': + case 'jsonb': + case 'text': + if (value === 'NULL') return null; + return (ab) => ab.StringLiteral.setValue(value); + case 'real': + return (ab) => ab.NumberLiteral.setValue(value); + default: { + const enumDef = enums.find((e) => getDbName(e, true) === type); + if (!enumDef) { + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + } + const enumField = enumDef.fields.find((v) 
+function typeCastingConvert({ defaultValue, enums, val, services }: { val: string; enums: Enum[]; defaultValue: string; services: ZModelServices }): ((builder: ExpressionBuilder) => AstFactory) | null { + const [value, type] = val + .replace(/'/g, '') + .split('::') + .map((s) => s.trim()) as [string, string]; + switch (type) { + case 'character varying': + case 'uuid': + case 'json': + case 'jsonb': + case 'text': + if (value === 'NULL') return null; + return (ab) => ab.StringLiteral.setValue(value); + case 'real': + return (ab) => ab.NumberLiteral.setValue(value); + default: { + const enumDef = enums.find((e) => getDbName(e, true) === type); + if (!enumDef) { + return (ab) => + ab.InvocationExpr.setFunction(getFunctionRef('dbgenerated', services)).addArg((a) => + a.setValue((v) => v.StringLiteral.setValue(val)), + ); + } + const enumField = enumDef.fields.find((v) => getDbName(v) === value); + if (!enumField) { + throw new Error( + `Enum value ${value} not found in enum ${type} for default value ${defaultValue}`, + ); + } + return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + } +} \ No newline at end of file diff --git a/packages/cli/src/actions/pull/provider/provider.ts b/packages/cli/src/actions/pull/provider/provider.ts new file mode 100644 index 000000000..a3922b7a7 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/provider.ts @@ -0,0 +1,94 @@ +import type { ZModelServices } from '@zenstackhq/language'; +import type { BuiltinType, Enum, Expression } from '@zenstackhq/language/ast'; +import type { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory'; + +export type Cascade = 'NO ACTION' | 'RESTRICT' | 'CASCADE' | 'SET NULL' | 'SET DEFAULT' | null; + +export interface IntrospectedTable { + schema: string; + name: string; + type: 'table' | 'view'; + definition: string | null; + columns: { + name: string; + datatype: string; + length: number | null; + precision: number | null; + datatype_schema: string; + foreign_key_schema: string | null; + foreign_key_table: string | null; + foreign_key_column: string | null; + foreign_key_name: string | null; + foreign_key_on_update: Cascade; + foreign_key_on_delete: Cascade; + pk: boolean; + computed: boolean; + nullable: boolean; + options: string[]; + unique: boolean; + unique_name: string | null; + default: string | null; + }[]; + indexes: { + name: string; + method: string | null; + unique: boolean; + primary: boolean; + valid: boolean; + ready: boolean; + partial: boolean; + predicate: string | null; + columns: { + name: string; + expression: string | null; + order: 'ASC' | 'DESC' | null; + nulls: string | null; + }[]; + }[]; +} + +export type IntrospectedEnum = { + schema_name: string; + enum_type: string; + values: string[]; +}; + +export type IntrospectedSchema = { + tables: IntrospectedTable[]; + enums: IntrospectedEnum[]; +}; + +export type DatabaseFeature = 'Schema' | 'NativeEnum'; + +export interface IntrospectionProvider { + introspect(connectionString: string): Promise<IntrospectedSchema>; + getBuiltinType(type: string): { + type: BuiltinType | 'Unsupported'; + isArray: boolean; + }; + getDefaultDatabaseType(type: BuiltinType): { precision?: number; type: string } | undefined; + /** + * Get the expression builder callback for a field's @default attribute value. + * Returns null if no @default attribute should be added. + * The callback will be passed to DataFieldAttributeFactory.addArg(). + */ + getDefaultValue(args: { + fieldType: BuiltinType | 'Unsupported'; + defaultValue: string; + services: ZModelServices; + enums: Enum[]; + }): ((builder: ExpressionBuilder) => AstFactory) | null; + /** + * Get additional field attributes based on field type and name (e.g., @updatedAt for DateTime fields, @db.* attributes). + * This is separate from getDefaultValue to keep concerns separated. 
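+ * For example, a `varchar(255)` column would pick up `@db.VarChar(255)`-style metadata, while a plain `text` + * column (the default String mapping) needs no `@db.*` attribute; exact names depend on the stdlib attribute set.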
+ */ + getFieldAttributes(args: { + fieldName: string; + fieldType: BuiltinType | 'Unsupported'; + datatype: string; + length: number | null; + precision: number | null; + services: ZModelServices; + }): DataFieldAttributeFactory[]; + isSupportedFeature(feature: DatabaseFeature): boolean; +} diff --git a/packages/cli/src/actions/pull/provider/sqlite.ts b/packages/cli/src/actions/pull/provider/sqlite.ts new file mode 100644 index 000000000..fcdbfbad7 --- /dev/null +++ b/packages/cli/src/actions/pull/provider/sqlite.ts @@ -0,0 +1,378 @@ +import type { Attribute, BuiltinType } from '@zenstackhq/language/ast'; +import { DataFieldAttributeFactory } from '@zenstackhq/language/factory'; +import { getAttributeRef, getDbName, getFunctionRef } from '../utils'; +import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider'; + +// Note: We dynamically import better-sqlite3 inside the async function to avoid +// requiring it at module load time for environments that don't use SQLite. + +export const sqlite: IntrospectionProvider = { + isSupportedFeature(feature) { + switch (feature) { + case 'Schema': + // Multi-schema feature is not available for SQLite because it doesn't have + // the same concept of schemas as namespaces (unlike PostgreSQL, CockroachDB, SQL Server). + return false; + case 'NativeEnum': + // SQLite doesn't support native enum types + return false; + default: + return false; + } + }, + getBuiltinType(type) { + const t = (type || '').toLowerCase().trim(); + // SQLite has no array types + const isArray = false; + switch (t) { + case 'integer': + return { type: 'Int', isArray }; + case 'text': + return { type: 'String', isArray }; + case 'bigint': + return { type: 'BigInt', isArray }; + case 'blob': + return { type: 'Bytes', isArray }; + case 'real': + return { type: 'Float', isArray }; + case 'numeric': + case 'decimal': + return { type: 'Decimal', isArray }; + case 'datetime': + return { type: 'DateTime', isArray }; + case 'jsonb': + return { type: 'Json', isArray }; + case 'boolean': + return { type: 'Boolean', isArray }; + default: { + return { type: 'Unsupported' as const, isArray }; + } + } + }, + + getDefaultDatabaseType() { + return undefined; + }, + + async introspect(connectionString: string): Promise<IntrospectedSchema> { + const SQLite = (await import('better-sqlite3')).default; + const db = new SQLite(connectionString, { readonly: true }); + + try { + const all = <T>(sql: string): T[] => { + const stmt: any = db.prepare(sql); + return stmt.all() as T[]; + }; + + // List user tables and views (exclude internal sqlite_*) + const tablesRaw = all<{ name: string; type: 'table' | 'view'; definition: string | null }>( + "SELECT name, type, sql AS definition FROM sqlite_schema WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' ORDER BY name", + ); + + // Detect AUTOINCREMENT by parsing the CREATE TABLE statement + // The sqlite_sequence table only has entries after rows are inserted, + // so we need to check the actual table definition instead + const autoIncrementTables = new Set<string>(); + for (const t of tablesRaw) { + if (t.type === 'table' && t.definition) { + // AUTOINCREMENT keyword appears in PRIMARY KEY definition + // e.g., PRIMARY KEY("id" AUTOINCREMENT) or PRIMARY KEY(id AUTOINCREMENT) + if (/\bAUTOINCREMENT\b/i.test(t.definition)) { + autoIncrementTables.add(t.name); + } + } + } + + const tables: IntrospectedTable[] = []; +
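+ // For each table/view: read column, index and foreign-key metadata via PRAGMA calls, then assemble the provider-neutral IntrospectedTable shape.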
+ for (const t of tablesRaw) { + const tableName = t.name; + const schema = ''; + + // Check if this table has autoincrement (detected from the CREATE TABLE parse above) + const hasAutoIncrement = autoIncrementTables.has(tableName); + + // Columns with extended info; filter out hidden=1 (internal/rowid), mark computed if hidden=2 (generated) + const columnsInfo = all<{ + cid: number; + name: string; + type: string; + notnull: number; + dflt_value: string | null; + pk: number; + hidden?: number; + }>(`PRAGMA table_xinfo('${tableName.replace(/'/g, "''")}')`); + + // Index list (used for both unique inference and index collection) + const tableNameEsc = tableName.replace(/'/g, "''"); + const idxList = all<{ + seq: number; + name: string; + unique: number; + origin: string; + partial: number; + }>(`PRAGMA index_list('${tableNameEsc}')`).filter((r) => !r.name.startsWith('sqlite_autoindex_')); + + // Unique columns detection via unique indexes with single column + const uniqueSingleColumn = new Set<string>(); + const uniqueIndexRows = idxList.filter((r) => r.unique === 1 && r.partial !== 1); + for (const idx of uniqueIndexRows) { + const idxCols = all<{ name: string }>(`PRAGMA index_info('${idx.name.replace(/'/g, "''")}')`); + if (idxCols.length === 1 && idxCols[0]?.name) { + uniqueSingleColumn.add(idxCols[0].name); + } + } + + // Indexes details + const indexes: IntrospectedTable['indexes'] = idxList.map((idx) => { + const idxCols = all<{ name: string }>(`PRAGMA index_info('${idx.name.replace(/'/g, "''")}')`); + return { + name: idx.name, + method: null, // SQLite does not expose index method + unique: idx.unique === 1, + primary: false, // SQLite does not expose this directly; handled via pk in columns + valid: true, // SQLite does not expose index validity + ready: true, // SQLite does not expose index readiness + partial: idx.partial === 1, + predicate: idx.partial === 1 ? '[partial]' : null, // SQLite does not expose index predicate + columns: idxCols.map((col) => ({ + name: col.name, + expression: null, + order: null, + nulls: null, + })), + }; + }); + + // Foreign keys mapping by column name + const fkRows = all<{ + id: number; + seq: number; + table: string; + from: string; + to: string | null; + on_update: any; + on_delete: any; + }>(`PRAGMA foreign_key_list('${tableName.replace(/'/g, "''")}')`); + + // Extract FK constraint names from CREATE TABLE statement + // Pattern: CONSTRAINT "name" FOREIGN KEY("column") or CONSTRAINT name FOREIGN KEY(column) + const fkConstraintNames = new Map<string, string>(); + if (t.definition) { + // Match: CONSTRAINT "name" FOREIGN KEY("col") or CONSTRAINT name FOREIGN KEY(col) + // Use [^"'`]+ for quoted names to capture full identifier including underscores and other chars + const fkRegex = /CONSTRAINT\s+(?:["'`]([^"'`]+)["'`]|(\w+))\s+FOREIGN\s+KEY\s*\(\s*(?:["'`]([^"'`]+)["'`]|(\w+))\s*\)/gi; + let match; + while ((match = fkRegex.exec(t.definition)) !== null) { + // match[1] = quoted constraint name, match[2] = unquoted constraint name + // match[3] = quoted column name, match[4] = unquoted column name + const constraintName = match[1] || match[2]; + const columnName = match[3] || match[4]; + if (constraintName && columnName) { + fkConstraintNames.set(columnName, constraintName); + } + } + } + + const fkByColumn = new Map< + string, + { + foreign_key_schema: string | null; + foreign_key_table: string | null; + foreign_key_column: string | null; + foreign_key_name: string | null; + foreign_key_on_update: IntrospectedTable['columns'][number]['foreign_key_on_update']; + foreign_key_on_delete: IntrospectedTable['columns'][number]['foreign_key_on_delete']; + } + >(); + + for (const fk of fkRows) { 
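+ // Key each FK row by its source column; constraint names come from the CREATE TABLE parse above because PRAGMA foreign_key_list does not report them.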
fkByColumn.set(fk.from, { + foreign_key_schema: '', + foreign_key_table: fk.table || null, + foreign_key_column: fk.to || null, + foreign_key_name: fkConstraintNames.get(fk.from) ?? null, + foreign_key_on_update: (fk.on_update as any) ?? null, + foreign_key_on_delete: (fk.on_delete as any) ?? null, + }); + } + + const columns: IntrospectedTable['columns'] = []; + for (const c of columnsInfo) { + // hidden: 1 (hidden/internal) -> skip; 2 (generated) -> mark computed + const hidden = c.hidden ?? 0; + if (hidden === 1) continue; + + const fk = fkByColumn.get(c.name); + + // Determine default value - check for autoincrement + // AUTOINCREMENT in SQLite can only be on INTEGER PRIMARY KEY column + let defaultValue = c.dflt_value; + if (hasAutoIncrement && c.pk) { + defaultValue = 'autoincrement'; + } + + columns.push({ + name: c.name, + datatype: c.type || '', + length: null, + precision: null, + datatype_schema: schema, + foreign_key_schema: fk?.foreign_key_schema ?? null, + foreign_key_table: fk?.foreign_key_table ?? null, + foreign_key_column: fk?.foreign_key_column ?? null, + foreign_key_name: fk?.foreign_key_name ?? null, + foreign_key_on_update: fk?.foreign_key_on_update ?? null, + foreign_key_on_delete: fk?.foreign_key_on_delete ?? null, + pk: !!c.pk, + computed: hidden === 2, + nullable: c.notnull !== 1, + default: defaultValue, + options: [], + unique: uniqueSingleColumn.has(c.name), + unique_name: null, + }); + } + + tables.push({ schema, name: tableName, columns, type: t.type, definition: t.definition, indexes }); + } + + const enums: IntrospectedEnum[] = []; // SQLite doesn't support enums + + return { tables, enums }; + } finally { + db.close(); + } + }, + + getDefaultValue({ defaultValue, fieldType, services, enums }) { + const val = defaultValue.trim(); + + switch (fieldType) { + case 'DateTime': + if (val === 'CURRENT_TIMESTAMP' || val === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + // Fallback to string literal for other DateTime defaults + return (ab) => ab.StringLiteral.setValue(val); + + case 'Int': + case 'BigInt': + if (val === 'autoincrement') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val); + } + break; + + case 'Float': + if (/^-?\d+\.\d+$/.test(val)) { + const numVal = parseFloat(val); + return (ab) => ab.NumberLiteral.setValue(numVal === Math.floor(numVal) ? 
numVal.toFixed(1) : String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.0'); + } + break; + + case 'Decimal': + if (/^-?\d+\.\d+$/.test(val)) { + const numVal = parseFloat(val); + if (numVal === Math.floor(numVal)) { + return (ab) => ab.NumberLiteral.setValue(numVal.toFixed(2)); + } + return (ab) => ab.NumberLiteral.setValue(String(numVal)); + } + if (/^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val + '.00'); + } + break; + + case 'Boolean': + if (val === 'true' || val === '1') { + return (ab) => ab.BooleanLiteral.setValue(true); + } + if (val === 'false' || val === '0') { + return (ab) => ab.BooleanLiteral.setValue(false); + } + break; + + case 'String': + if (val.startsWith("'") && val.endsWith("'")) { + const strippedName = val.slice(1, -1); + const enumDef = enums.find((e) => e.fields.find((v) => getDbName(v) === strippedName)); + if (enumDef) { + const enumField = enumDef.fields.find((v) => getDbName(v) === strippedName); + if (enumField) return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + return (ab) => ab.StringLiteral.setValue(strippedName); + } + break; + } + + // Fallback handlers for values that don't match field type-specific patterns + if (val === 'CURRENT_TIMESTAMP' || val === 'now()') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('now', services)); + } + + if (val === 'autoincrement') { + return (ab) => ab.InvocationExpr.setFunction(getFunctionRef('autoincrement', services)); + } + + if (val === 'true' || val === 'false') { + return (ab) => ab.BooleanLiteral.setValue(val === 'true'); + } + + if (/^-?\d+\.\d+$/.test(val) || /^-?\d+$/.test(val)) { + return (ab) => ab.NumberLiteral.setValue(val); + } + + if (val.startsWith("'") && val.endsWith("'")) { + const strippedName = val.slice(1, -1); + const enumDef = enums.find((e) => e.fields.find((v) => getDbName(v) === strippedName)); + if (enumDef) { + const enumField = enumDef.fields.find((v) => getDbName(v) === strippedName); + if (enumField) return (ab) => ab.ReferenceExpr.setTarget(enumField); + } + return (ab) => ab.StringLiteral.setValue(strippedName); + } + + // TODO: add more default value factories as needed + throw new Error( + `This default value type is currently not supported. Please open an issue on GitHub. Value: "${defaultValue}"`, + ); + }, + + getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) { + const factories: DataFieldAttributeFactory[] = []; + + // Add @updatedAt for DateTime fields named updatedAt or updated_at + if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) { + factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services))); + } + + // Add @db.* attribute if the datatype differs from the default + const dbAttr = services.shared.workspace.IndexManager.allElements('Attribute').find( + (d) => d.name.toLowerCase() === `@db.${datatype.toLowerCase()}`, + )?.node as Attribute | undefined; + + const defaultDatabaseType = this.getDefaultDatabaseType(fieldType as BuiltinType); + + if ( + dbAttr && + defaultDatabaseType && + (defaultDatabaseType.type !== datatype || + (defaultDatabaseType.precision && + defaultDatabaseType.precision !== (length || precision))) + ) { + const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr);
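+ // Forward the column's length/precision as the attribute argument when present (e.g. a varchar length).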
+ if (length || precision) { + dbAttrFactory.addArg((a) => a.NumberLiteral.setValue(length! || precision!)); + } + factories.push(dbAttrFactory); + } + + return factories; + }, +}; diff --git a/packages/cli/src/actions/pull/utils.ts b/packages/cli/src/actions/pull/utils.ts new file mode 100644 index 000000000..b46693afe --- /dev/null +++ b/packages/cli/src/actions/pull/utils.ts @@ -0,0 +1,169 @@ +import type { ZModelServices } from '@zenstackhq/language'; +import { + type AbstractDeclaration, + type DataField, + type DataModel, + type Enum, + type EnumField, + type FunctionDecl, + isInvocationExpr, + type Attribute, + type Model, + type ReferenceExpr, + type StringLiteral, +} from '@zenstackhq/language/ast'; +import { getLiteralArray, getStringLiteral } from '@zenstackhq/language/utils'; +import type { DataSourceProviderType } from '@zenstackhq/schema'; +import type { Reference } from 'langium'; + +export function getAttribute(model: Model, attrName: string) { + if (!model.$document) throw new Error('Model is not associated with a document.'); + + const references = model.$document.references as Reference[]; + return references.find((a) => a.ref?.$type === 'Attribute' && a.ref?.name === attrName)?.ref as + | Attribute + | undefined; +} + +export function getDatasource(model: Model) { + const datasource = model.declarations.find((d) => d.$type === 'DataSource'); + if (!datasource) { + throw new Error('No datasource declaration found in the schema.'); + } + + const urlField = datasource.fields.find((f) => f.name === 'url'); + + if (!urlField) throw new Error(`No url field found in the datasource declaration.`); + + let url = getStringLiteral(urlField.value); + + if (!url && isInvocationExpr(urlField.value)) { + const envName = getStringLiteral(urlField.value.args[0]?.value); + if (!envName) { + throw new Error('The url field must be a string literal or an env().'); + } + if (!process.env[envName]) { + throw new Error( + `Environment variable ${envName} is not set, please set it to the database connection string.`, + ); + } + url = process.env[envName]; + } + + if (!url) { + throw new Error('The url field must be a string literal or an env().'); + } + + if (url.startsWith('file:')) { + url = new URL(url, `file:${model.$document!.uri.path}`).pathname; + if (process.platform === 'win32' && url[0] === '/') url = url.slice(1); + } + + const defaultSchemaField = datasource.fields.find((f) => f.name === 'defaultSchema'); + const defaultSchema = (defaultSchemaField && getStringLiteral(defaultSchemaField.value)) || 'public'; + + const schemasField = datasource.fields.find((f) => f.name === 'schemas'); + const schemas = + (schemasField && + getLiteralArray(schemasField.value) + ?.filter((s) => s !== undefined)) as string[] || + []; + + return { + name: datasource.name, + provider: getStringLiteral( + datasource.fields.find((f) => f.name === 'provider')?.value, + ) as DataSourceProviderType, + url, + defaultSchema, + schemas, + allSchemas: [defaultSchema, ...schemas], + }; +} + +export function getDbName(decl: AbstractDeclaration | DataField | EnumField, includeSchema: boolean = false): string { + if (!('attributes' in decl)) return decl.name; + + const schemaAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@schema'); + const schemaAttrValue = schemaAttr?.args[0]?.value; + // Default to 'public'; only a @@schema attribute with a string literal argument overrides it + let schema = 'public'; + if (schemaAttrValue?.$type === 'StringLiteral') schema = schemaAttrValue.value; + + const formatName = (name: string) => `${schema && includeSchema ? `${schema}.` : ''}${name}`; + + const nameAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@map' || a.decl.ref?.name === '@map'); + if (!nameAttr) return formatName(decl.name); + const attrValue = nameAttr.args[0]?.value; + + if (attrValue?.$type !== 'StringLiteral') return formatName(decl.name); + + return formatName(attrValue.value); +}
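+ +// e.g. for a model carrying @@map("users") and @@schema("auth"), getDbName(decl, true) returns "auth.users".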
+ +export function getRelationFkName(decl: DataField): string | undefined { + const relationAttr = decl?.attributes.find((a) => a.decl.ref?.name === '@relation'); + const schemaAttrValue = relationAttr?.args.find((a) => a.name === 'map')?.value as StringLiteral; + return schemaAttrValue?.value; +} + +/** + * Gets the FK field names from the @relation attribute's `fields` argument. + * Returns a sorted, comma-separated string of field names for comparison. + * e.g., @relation(fields: [userId], references: [id]) -> "userId" + * e.g., @relation(fields: [postId, tagId], references: [id, id]) -> "postId,tagId" + */ +export function getRelationFieldsKey(decl: DataField): string | undefined { + const relationAttr = decl?.attributes.find((a) => a.decl.ref?.name === '@relation'); + if (!relationAttr) return undefined; + + const fieldsArg = relationAttr.args.find((a) => a.name === 'fields')?.value; + if (!fieldsArg || fieldsArg.$type !== 'ArrayExpr') return undefined; + + const fieldNames = fieldsArg.items + .filter((item): item is ReferenceExpr => item.$type === 'ReferenceExpr') + .map((item) => item.target?.$refText || item.target?.ref?.name) + .filter((name): name is string => !!name) + .sort(); + + return fieldNames.length > 0 ? fieldNames.join(',') : undefined; +} + +export function getDbSchemaName(decl: DataModel | Enum): string { + const schemaAttr = decl.attributes.find((a) => a.decl.ref?.name === '@@schema'); + if (!schemaAttr) return 'public'; + const attrValue = schemaAttr.args[0]?.value; + + if (attrValue?.$type !== 'StringLiteral') return 'public'; + + return attrValue.value; +} + +export function getDeclarationRef<T extends AbstractDeclaration>( + type: T['$type'], + name: string, + services: ZModelServices, +) { + const node = services.shared.workspace.IndexManager.allElements(type).find( + (m) => m.node && getDbName(m.node as T) === name, + )?.node; + if (!node) throw new Error(`Declaration not found: ${name}`); + return node as T; +} + +export function getEnumRef(name: string, services: ZModelServices) { + return getDeclarationRef<Enum>('Enum', name, services); +} + +export function getModelRef(name: string, services: ZModelServices) { + return getDeclarationRef<DataModel>('DataModel', name, services); +} + +export function getAttributeRef(name: string, services: ZModelServices) { + return getDeclarationRef<Attribute>('Attribute', name, services); +} + +export function getFunctionRef(name: string, services: ZModelServices) { + return getDeclarationRef<FunctionDecl>('FunctionDecl', name, services); +} diff --git a/packages/cli/src/index.ts b/packages/cli/src/index.ts index 4efc86fd9..8d253cc3e 100644 --- a/packages/cli/src/index.ts +++ b/packages/cli/src/index.ts @@ -143,6 +143,31 @@ function createProgram() { .addOption(new Option('--force-reset', 'force a reset of the database before push')) .action((options) => dbAction('push', options)); + dbCommand + .command('pull') + .description('Introspect your database') + .addOption(schemaOption) + .addOption(noVersionCheckOption) + .addOption(new Option('-o, --out <path>', 'add custom output path for the introspected schema')) + .addOption( + new Option('--model-casing <casing>', 'set the casing of generated models').default( + 'none', + ), + ) + .addOption( + new Option('--field-casing <casing>', 'set the casing of generated fields').default( + 'none', + ), + ) + .addOption( + new Option('--always-map', 'always add @map and @@map attributes to models and fields').default(false), + ) + .addOption( + new Option('--quote <style>', 'set the quote style of generated schema files').default('double'), + ) + .addOption(new Option('--indent <size>', 'set the indentation of the generated schema files').default(4)) + .action((options) => dbAction('pull', options)); +
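+ // e.g. `db pull --model-casing pascal --field-casing camel --quote single --indent 2`; the defaults keep database identifiers as-is.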
 dbCommand .command('seed') .description('Seed the database') diff --git a/packages/cli/test/check.test.ts b/packages/cli/test/check.test.ts index 287bb6b80..60f80903e 100644 --- a/packages/cli/test/check.test.ts +++ b/packages/cli/test/check.test.ts @@ -83,17 +83,12 @@ describe('CLI validate command test', () => { it('should validate schema with syntax errors', () => { const modelWithSyntaxError = ` -datasource db { - provider = "sqlite" - url = "file:./dev.db" -} - model User { id String @id @default(cuid()) email String @unique // Missing closing brace - syntax error `; - const workDir = createProject(modelWithSyntaxError, false); + const workDir = createProject(modelWithSyntaxError); // Should throw an error due to syntax error expect(() => runCli('check', workDir)).toThrow(); diff --git a/packages/cli/test/db.test.ts b/packages/cli/test/db.test.ts index 636dcff8f..b17f92e5e 100644 --- a/packages/cli/test/db.test.ts +++ b/packages/cli/test/db.test.ts @@ -11,13 +11,13 @@ model User { describe('CLI db commands test', () => { it('should generate a database with db push', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('db push', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); it('should seed the database with db seed with seed script', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); const pkgJson = JSON.parse(fs.readFileSync(path.join(workDir, 'package.json'), 'utf8')); pkgJson.zenstack = { seed: 'node seed.js', @@ -36,7 +36,7 @@ fs.writeFileSync('seed.txt', 'success'); }); it('should seed the database after migrate reset', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); const pkgJson = JSON.parse(fs.readFileSync(path.join(workDir, 'package.json'), 'utf8')); pkgJson.zenstack = { seed: 'node seed.js', @@ -55,7 +55,7 @@ fs.writeFileSync('seed.txt', 'success'); }); it('should skip seeding the database without seed script', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('db seed', workDir); }); }); diff --git a/packages/cli/test/db/pull.test.ts b/packages/cli/test/db/pull.test.ts new file mode 100644 index 000000000..d8d677258 --- /dev/null +++ b/packages/cli/test/db/pull.test.ts @@ -0,0 +1,450 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { createFormattedProject, createProject, getDefaultPrelude, runCli } from '../utils'; +import { loadSchemaDocument } from '../../src/actions/action-utils'; +import { ZModelCodeGenerator, formatDocument } from '@zenstackhq/language'; +import { getTestDbProvider } from '@zenstackhq/testtools'; + +const getSchema = (workDir: string) => fs.readFileSync(path.join(workDir, 'zenstack/schema.zmodel')).toString(); +const generator = new 
ZModelCodeGenerator({ + quote: 'double', + indent: 4, +}); + +describe('DB pull - Common features (all providers)', () => { + describe('Pull from zero - restore complete schema from database', () => { + it('should restore basic schema with all supported types', async () => { + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + age Int @default(0) + balance Decimal @default(0.00) + isActive Boolean @default(true) + bigCounter BigInt @default(0) + score Float @default(0.0) + bio String? + avatar Bytes? + metadata Json? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +}`, + ); + runCli('db push', workDir); + + // Capture the schema as loaded after db push; this is the canonical form the pull should restore + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + // Remove schema content to simulate restoration from zero + fs.writeFileSync(schemaFile, getDefaultPrelude()); + + // Pull should fully restore the schema + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + it('should restore schema with relations', async () => { + const workDir = await createFormattedProject( + `model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + posts Post[] +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + it('should restore schema with many-to-many relations', async () => { + const workDir = await createFormattedProject( + `model Post { + id Int @id @default(autoincrement()) + title String + postTags PostTag[] +} + +model PostTag { + post Post @relation(fields: [postId], references: [id], onDelete: Cascade) + postId Int + tag Tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + tagId Int + + @@id([postId, tagId]) +} + +model Tag { + id Int @id @default(autoincrement()) + name String @unique + postTags PostTag[] +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + it('should restore schema with indexes and unique constraints', async () => { + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + username String + firstName String + lastName String + role String + + @@unique([username, email]) + @@index([role]) + @@index([firstName, lastName]) + @@index([email, username, role]) +}`, + ); + runCli('db push', workDir); + + const 
schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + it('should restore schema with composite primary keys', async () => { + const workDir = await createFormattedProject( + `model UserRole { + userId String + role String + grantedAt DateTime @default(now()) + + @@id([userId, role]) +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + }); + + describe('Pull with existing schema - preserve schema features', () => { + it('should preserve field and table mappings', async () => { + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique @map("email_address") + firstName String @map("first_name") + lastName String @map("last_name") + + @@map("users") +}`, + ); + runCli('db push', workDir); + + const originalSchema = getSchema(workDir); + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(originalSchema); + }); + + it('should not modify a comprehensive schema with all features', async () => { + const workDir = await createFormattedProject(`model User { + id Int @id @default(autoincrement()) + email String @unique @map("email_address") + name String? @default("Anonymous") + role Role @default(USER) + profile Profile? + shared_profile Profile? @relation("shared") + posts Post[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + jsonData Json? + balance Decimal @default(0.00) + isActive Boolean @default(true) + bigCounter BigInt @default(0) + bytes Bytes? + + @@index([role]) + @@map("users") +} + +model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int @unique + user_shared User @relation("shared", fields: [shared_userId], references: [id], onDelete: Cascade) + shared_userId Int @unique + bio String? + avatarUrl String? + + @@map("profiles") +} + +model Post { + id Int @id @default(autoincrement()) + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + title String + content String? + published Boolean @default(false) + tags PostTag[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + slug String + score Float @default(0.0) + metadata Json? + + @@unique([authorId, slug]) + @@index([authorId, published]) + @@map("posts") +} + +model Tag { + id Int @id @default(autoincrement()) + name String @unique + posts PostTag[] + createdAt DateTime @default(now()) + + @@index([name], name: "tag_name_idx") + @@map("tags") +} + +model PostTag { + post Post @relation(fields: [postId], references: [id], onDelete: Cascade) + postId Int + tag Tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + tagId Int + assignedAt DateTime @default(now()) + note String? 
@default("initial") + + @@id([postId, tagId]) + @@map("post_tags") +} + +enum Role { + USER + ADMIN + MODERATOR +}`, + ); + runCli('db push', workDir); + + const originalSchema = getSchema(workDir); + runCli('db pull --indent 4', workDir); + expect(getSchema(workDir)).toEqual(originalSchema); + }); + + it('should preserve imports when pulling with multi-file schema', async () => { + const workDir = createProject('', { customPrelude: true }); + const schemaPath = path.join(workDir, 'zenstack/schema.zmodel'); + const modelsDir = path.join(workDir, 'zenstack/models'); + + fs.mkdirSync(modelsDir, { recursive: true }); + + // Create main schema with imports + const mainSchema = await formatDocument(`import "./models/user" +import "./models/post" + +${getDefaultPrelude()}`); + fs.writeFileSync(schemaPath, mainSchema); + + // Create user model + const userModel = await formatDocument(`import "./post" + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + posts Post[] + createdAt DateTime @default(now()) +}`); + fs.writeFileSync(path.join(modelsDir, 'user.zmodel'), userModel); + + // Create post model + const postModel = await formatDocument(`import "./user" + +model Post { + id Int @id @default(autoincrement()) + title String + content String? + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + createdAt DateTime @default(now()) +}`); + fs.writeFileSync(path.join(modelsDir, 'post.zmodel'), postModel); + + runCli('db push', workDir); + + // Pull and verify imports are preserved + runCli('db pull --indent 4', workDir); + + const pulledMainSchema = fs.readFileSync(schemaPath).toString(); + const pulledUserSchema = fs.readFileSync(path.join(modelsDir, 'user.zmodel')).toString(); + const pulledPostSchema = fs.readFileSync(path.join(modelsDir, 'post.zmodel')).toString(); + + expect(pulledMainSchema).toEqual(mainSchema); + expect(pulledUserSchema).toEqual(userModel); + expect(pulledPostSchema).toEqual(postModel); + }); + }); +}); + +describe('DB pull - PostgreSQL specific features', () => { + it('should restore schema with multiple database schemas', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + posts Post[] + + @@schema("auth") +} + +model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + + @@schema("content") +}`, + { provider: 'postgresql', extra:{ schemas: ["public", "content", "auth"] } }, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + const { model } = await loadSchemaDocument(schemaFile, { returnServices: true }); + const expectedSchema = generator.generate(model); + + fs.writeFileSync(schemaFile, getDefaultPrelude({ provider: 'postgresql', extra:{ schemas: ["public", "content", "auth"]} })); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(expectedSchema); + }); + + it('should preserve native PostgreSQL enums when schema exists', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + 
status UserStatus @default(ACTIVE) + role UserRole @default(USER) +} + +enum UserStatus { + ACTIVE + INACTIVE + SUSPENDED +} + +enum UserRole { + USER + ADMIN + MODERATOR +}`, + { provider: 'postgresql' }, + ); + runCli('db push', workDir); + + const originalSchema = getSchema(workDir); + runCli('db pull --indent 4', workDir); + const pulledSchema = getSchema(workDir); + + expect(pulledSchema).toEqual(originalSchema); + }); + + it('should not modify schema with PostgreSQL-specific features', async ({ skip }) => { + const provider = getTestDbProvider(); + if (provider !== 'postgresql') { + skip(); + return; + } + const workDir = await createFormattedProject( + `model User { + id Int @id @default(autoincrement()) + email String @unique + status UserStatus @default(ACTIVE) + posts Post[] + metadata Json? + + @@schema("auth") + @@index([status]) +} + +model Post { + id Int @id @default(autoincrement()) + title String + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + tags String[] + + @@schema("content") + @@index([authorId]) +} + +enum UserStatus { + ACTIVE + INACTIVE + SUSPENDED +}`, + { provider: 'postgresql', extra:{ schemas: ["public", "content", "auth"] } }, + ); + runCli('db push', workDir); + + const originalSchema = getSchema(workDir); + runCli('db pull --indent 4', workDir); + + expect(getSchema(workDir)).toEqual(originalSchema); + }); +}); diff --git a/packages/cli/test/db/push.test.ts b/packages/cli/test/db/push.test.ts new file mode 100644 index 000000000..9c688df4d --- /dev/null +++ b/packages/cli/test/db/push.test.ts @@ -0,0 +1,18 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import { describe, expect, it } from 'vitest'; +import { createProject, runCli } from '../utils'; + +const model = ` +model User { + id String @id @default(cuid()) +} +`; + +describe('CLI db commands test', () => { + it('should generate a database with db push', () => { + const workDir = createProject(model, { provider: 'sqlite' }); + runCli('db push', workDir); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); + }); +}); diff --git a/packages/cli/test/migrate.test.ts b/packages/cli/test/migrate.test.ts index 56a0fec83..86abc3576 100644 --- a/packages/cli/test/migrate.test.ts +++ b/packages/cli/test/migrate.test.ts @@ -11,36 +11,36 @@ model User { describe('CLI migrate commands test', () => { it('should generate a database with migrate dev', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); expect(fs.existsSync(path.join(workDir, 'zenstack/migrations'))).toBe(true); }); it('should reset the database with migrate reset', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('db push', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); runCli('migrate reset --force', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); it('should reset the database with migrate deploy', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); 
runCli('migrate dev --name init', workDir); - fs.rmSync(path.join(workDir, 'zenstack/dev.db')); + fs.rmSync(path.join(workDir, 'zenstack/test.db')); runCli('migrate deploy', workDir); - expect(fs.existsSync(path.join(workDir, 'zenstack/dev.db'))).toBe(true); + expect(fs.existsSync(path.join(workDir, 'zenstack/test.db'))).toBe(true); }); it('supports migrate status', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); runCli('migrate status', workDir); }); it('supports migrate resolve', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); runCli('migrate dev --name init', workDir); // find the migration record "timestamp_init" @@ -51,7 +51,7 @@ describe('CLI migrate commands test', () => { fs.writeFileSync(path.join(workDir, 'zenstack/migrations', migration!, 'migration.sql'), 'invalid content'); // redeploy the migration, which will fail - fs.rmSync(path.join(workDir, 'zenstack/dev.db'), { force: true }); + fs.rmSync(path.join(workDir, 'zenstack/test.db'), { force: true }); try { runCli('migrate deploy', workDir); } catch { @@ -66,7 +66,7 @@ }); it('should throw error when neither applied nor rolled-back is provided', () => { - const workDir = createProject(model); + const workDir = createProject(model, { provider: 'sqlite' }); expect(() => runCli('migrate resolve', workDir)).toThrow(); }); }); diff --git a/packages/cli/test/utils.ts b/packages/cli/test/utils.ts index 2fafb2074..4a58598c2 100644 --- a/packages/cli/test/utils.ts +++ b/packages/cli/test/utils.ts @@ -1,22 +1,106 @@ -import { createTestProject } from '@zenstackhq/testtools'; +import { createTestProject, getTestDbProvider } from '@zenstackhq/testtools'; +import { createHash } from 'node:crypto'; import { execSync } from 'node:child_process'; import fs from 'node:fs'; import path from 'node:path'; +import { expect } from 'vitest'; +import { formatDocument } from '@zenstackhq/language'; -const ZMODEL_PRELUDE = `datasource db { - provider = "sqlite" - url = "file:./dev.db" +const TEST_PG_CONFIG = { + host: process.env['TEST_PG_HOST'] ?? 'localhost', + port: process.env['TEST_PG_PORT'] ? parseInt(process.env['TEST_PG_PORT']) : 5432, + user: process.env['TEST_PG_USER'] ?? 'postgres', + password: process.env['TEST_PG_PASSWORD'] ?? 'postgres', +}; + +const TEST_MYSQL_CONFIG = { + host: process.env['TEST_MYSQL_HOST'] ?? 'localhost', + port: process.env['TEST_MYSQL_PORT'] ? parseInt(process.env['TEST_MYSQL_PORT']) : 3306, + user: process.env['TEST_MYSQL_USER'] ?? 'root', + password: process.env['TEST_MYSQL_PASSWORD'] ?? 'mysql', +}; + +function getTestDbName(provider: string) { + if (provider === 'sqlite') { + return './test.db'; + } + const testName = expect.getState().currentTestName ?? 'unnamed'; + const testPath = expect.getState().testPath ?? ''; + // digest test name + const digest = createHash('md5') + .update(testName + testPath) + .digest('hex'); + // compute a database name based on test name + return ( + 'test_' + + testName + .toLowerCase() + .replace(/[^a-z0-9_]/g, '_') + .replace(/_+/g, '_') + .substring(0, 30) + + digest.slice(0, 6) + ); } -`; -export function createProject(zmodel: string, addPrelude = true) {
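+// Example prelude for the sqlite provider (illustrative; '=' is aligned to the longest field name): +// datasource db { +//     provider = "sqlite" +//     url      = "file:./test.db" +// }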
+export function getDefaultPrelude(options?: { provider?: 'sqlite' | 'postgresql' | 'mysql', extra?: Record<string, unknown> }) { + const provider = (options?.provider || getTestDbProvider()) ?? 'sqlite'; + const dbName = getTestDbName(provider); + let dbUrl: string; + + switch (provider) { + case 'sqlite': + dbUrl = `file:${dbName}`; + break; + case 'postgresql': + dbUrl = `postgres://${TEST_PG_CONFIG.user}:${TEST_PG_CONFIG.password}@${TEST_PG_CONFIG.host}:${TEST_PG_CONFIG.port}/${dbName}`; + break; + case 'mysql': + dbUrl = `mysql://${TEST_MYSQL_CONFIG.user}:${TEST_MYSQL_CONFIG.password}@${TEST_MYSQL_CONFIG.host}:${TEST_MYSQL_CONFIG.port}/${dbName}`; + break; + default: + throw new Error(`Unsupported provider: ${provider}`); + } + // Build fields array for proper alignment (matching ZModelCodeGenerator) + const fields: [string, string][] = [ + ['provider', `"${provider}"`], + ['url', `"${dbUrl}"`], + ...Object.entries(options?.extra || {}).map(([k, v]) => { + const value = Array.isArray(v) ? `[${v.map(item => `"${item}"`).join(', ')}]` : `"${v}"`; + return [k, value] as [string, string]; + }), + ]; + + // Calculate alignment padding based on longest field name + const longestName = Math.max(...fields.map(([name]) => name.length)); + const formattedFields = fields.map(([name, value]) => { + const padding = ' '.repeat(longestName - name.length + 1); + return ` ${name}${padding}= ${value}`; + }).join('\n'); + + const ZMODEL_PRELUDE = `datasource db {\n${formattedFields}\n}`; + return ZMODEL_PRELUDE; +} + +export function createProject( + zmodel: string, + options?: { customPrelude?: boolean; provider?: 'sqlite' | 'postgresql' | 'mysql' }, +) { const workDir = createTestProject(); fs.mkdirSync(path.join(workDir, 'zenstack'), { recursive: true }); const schemaPath = path.join(workDir, 'zenstack/schema.zmodel'); - fs.writeFileSync(schemaPath, addPrelude ? `${ZMODEL_PRELUDE}\n\n${zmodel}` : zmodel); + fs.writeFileSync(schemaPath, !options?.customPrelude ? `${getDefaultPrelude({ provider: options?.provider })}\n\n${zmodel}` : zmodel); return workDir; } +export async function createFormattedProject( + zmodel: string, + options?: { provider?: 'sqlite' | 'postgresql' | 'mysql', extra?: Record<string, unknown> }, +) { + const fullContent = `${getDefaultPrelude({ provider: options?.provider, extra: options?.extra })}\n\n${zmodel}`; + const formatted = await formatDocument(fullContent); + return createProject(formatted, { customPrelude: true, provider: options?.provider }); +} + export function runCli(command: string, cwd: string) { const cli = path.join(__dirname, '../dist/index.js'); execSync(`node ${cli} ${command}`, { cwd }); diff --git a/packages/language/package.json b/packages/language/package.json index c5f60b107..ca3dc9a9e 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -49,6 +49,16 @@ "default": "./dist/utils.cjs" } }, + "./factory": { + "import": { + "types": "./dist/factory.d.ts", + "default": "./dist/factory.js" + }, + "require": { + "types": "./dist/factory.d.cts", + "default": "./dist/factory.cjs" + } + }, "./package.json": { "import": "./package.json", "require": "./package.json" diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index d0c3c0003..82cd78362 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -120,7 +120,7 @@ function dbgenerated(expr: String?): Any { /** * Checks if the field value contains the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. 
If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. */ function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { @@ -135,7 +135,7 @@ function contains(field: String, search: String, caseInSensitive: Boolean?): Boo /** * Checks the field value starts with the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. */ function startsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { @@ -144,7 +144,7 @@ function startsWith(field: String, search: String, caseInSensitive: Boolean?): B /** * Checks if the field value ends with the search string. By default, the search is case-sensitive, and * "LIKE" operator is used to match. If `caseInSensitive` is true, "ILIKE" operator is used if - * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's + * supported, otherwise it still falls back to "LIKE" and delivers whatever the database's * behavior is. */ function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boolean { diff --git a/packages/language/src/document.ts b/packages/language/src/document.ts index 9642e61d5..026d3d23e 100644 --- a/packages/language/src/document.ts +++ b/packages/language/src/document.ts @@ -13,7 +13,7 @@ import path from 'node:path'; import { fileURLToPath } from 'node:url'; import { isDataModel, isDataSource, type Model } from './ast'; import { DB_PROVIDERS_SUPPORTING_LIST_TYPE, STD_LIB_MODULE_NAME } from './constants'; -import { createZModelServices } from './module'; +import { createZModelServices, type ZModelServices } from './module'; import { getAllFields, getDataModelAndTypeDefs, @@ -32,8 +32,10 @@ import type { ZModelFormatter } from './zmodel-formatter'; export async function loadDocument( fileName: string, additionalModelFiles: string[] = [], + keepImports: boolean = false, ): Promise< - { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } + | { success: true; model: Model; warnings: string[]; services: ZModelServices } + | { success: false; errors: string[]; warnings: string[] } > { const { ZModelLanguage: services } = createZModelServices(false); const extensions = services.LanguageMetaData.fileExtensions; @@ -121,14 +123,16 @@ export async function loadDocument( const model = document.parseResult.value as Model; - // merge all declarations into the main document - const imported = mergeImportsDeclarations(langiumDocuments, model); + if (keepImports === false) { + // merge all declarations into the main document + const imported = mergeImportsDeclarations(langiumDocuments, model); - // remove imported documents - imported.forEach((model) => { - langiumDocuments.deleteDocument(model.$document!.uri); - services.shared.workspace.IndexManager.remove(model.$document!.uri); - }); + // remove imported documents + imported.forEach((model) => { + langiumDocuments.deleteDocument(model.$document!.uri); + services.shared.workspace.IndexManager.remove(model.$document!.uri); + }); + } // extra validation after merging 
diff --git a/packages/language/src/factory/ast-factory.ts b/packages/language/src/factory/ast-factory.ts
new file mode 100644
index 000000000..e01dd7ced
--- /dev/null
+++ b/packages/language/src/factory/ast-factory.ts
@@ -0,0 +1,56 @@
+import { type AstNode } from '../ast';
+
+export type ContainerProps<T extends AstNode = AstNode> = {
+    $container: T;
+    $containerProperty?: string;
+    $containerIndex?: number;
+};
+
+type NodeFactoriesFor<N extends AstNode> = {
+    [K in keyof N as {} extends Pick<N, K> ? never : K]: N[K] extends (infer U)[]
+        ? (AstFactory<U & AstNode> | U)[]
+        : AstFactory<N[K] & AstNode> | N[K];
+} & {
+    [K in keyof N as {} extends Pick<N, K> ? K : never]?: N[K] extends (infer U)[]
+        ? (AstFactory<U & AstNode> | U)[]
+        : AstFactory<N[K] & AstNode> | N[K];
+};
+
+export abstract class AstFactory<T extends AstNode> {
+    node = {} as T;
+    constructor({ type, node }: { type: T['$type']; node?: Partial<T> }) {
+        (this.node as any).$type = type;
+        if (node) {
+            this.update(node);
+        }
+    }
+    setContainer(container: T['$container']) {
+        (this.node as any).$container = container;
+        return this;
+    }
+
+    get(params?: ContainerProps): T {
+        if (params) this.update(params as any);
+        return this.node;
+    }
+    update(nodeArg: Partial<NodeFactoriesFor<T>> | Partial<T>): T {
+        const keys = Object.keys(nodeArg as object);
+        keys.forEach((key) => {
+            const child = (nodeArg as any)[key];
+            if (child instanceof AstFactory) {
+                (this.node as any)[key] = child.get({ $container: this.node as any });
+            } else if (Array.isArray(child)) {
+                (this.node as any)[key] = child.map((item: any) =>
+                    item instanceof AstFactory ? item.get({ $container: this.node as any }) : item,
+                );
+            } else {
+                (this.node as any)[key] = child;
+            }
+        });
+        return this.node;
+    }
+
+    resolveChilds(nodeArg: T | NodeFactoriesFor<T>): T {
+        return this.update(nodeArg);
+    }
+}
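A minimal sketch of the base-class mechanics, assuming the `./factory` subpath export added above and the concrete `StringLiteralFactory` defined later in this diff:

```ts
import { StringLiteralFactory } from '@zenstackhq/language/factory';

// update()/get() materialize the node; child factories passed to update()
// are resolved recursively and get their $container wired to the parent.
const lit = new StringLiteralFactory().setValue('hello');
const node = lit.get();
console.log(node.$type, node.value); // StringLiteral hello
```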
diff --git a/packages/language/src/factory/attribute.ts b/packages/language/src/factory/attribute.ts
new file mode 100644
index 000000000..138d41c8f
--- /dev/null
+++ b/packages/language/src/factory/attribute.ts
@@ -0,0 +1,281 @@
+import { AstFactory } from '.';
+import {
+    Attribute,
+    AttributeArg,
+    AttributeParam,
+    AttributeParamType,
+    DataFieldAttribute,
+    DataModelAttribute,
+    Expression,
+    InternalAttribute,
+    TypeDeclaration,
+    type Reference,
+    type RegularID,
+} from '../ast';
+import { ExpressionBuilder } from './expression';
+
+export class DataFieldAttributeFactory extends AstFactory<DataFieldAttribute> {
+    args: AttributeArgFactory[] = [];
+    decl?: Reference<Attribute>;
+    constructor() {
+        super({ type: DataFieldAttribute, node: { args: [] } });
+    }
+    setDecl(decl: Attribute) {
+        if (!decl) {
+            throw new Error('Attribute declaration is required');
+        }
+        this.decl = { $refText: decl.name, ref: decl };
+        this.update({ decl: this.decl });
+        return this;
+    }
+    addArg(builder: (b: ExpressionBuilder) => AstFactory<Expression>, name?: string) {
+        const factory = new AttributeArgFactory().setValue(builder);
+        if (name) {
+            factory.setName(name);
+        }
+        this.args.push(factory);
+        this.update({ args: this.args });
+        return this;
+    }
+}
+
+export class DataModelAttributeFactory extends AstFactory<DataModelAttribute> {
+    args: AttributeArgFactory[] = [];
+    decl?: Reference<Attribute>;
+    constructor() {
+        super({ type: DataModelAttribute, node: { args: [] } });
+    }
+    setDecl(decl: Attribute) {
+        if (!decl) {
+            throw new Error('Attribute declaration is required');
+        }
+        this.decl = { $refText: decl.name, ref: decl };
+        this.update({ decl: this.decl });
+        return this;
+    }
+    addArg(builder: (b: ExpressionBuilder) => AstFactory<Expression>, name?: string) {
+        const factory = new AttributeArgFactory().setValue(builder);
+        if (name) {
+            factory.setName(name);
+        }
+        this.args.push(factory);
+        this.update({ args: this.args });
+        return this;
+    }
+}
+
+export class AttributeArgFactory extends AstFactory<AttributeArg> {
+    name?: RegularID = '';
+    value?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: AttributeArg });
+    }
+
+    setName(name: RegularID) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    setValue(builder: (b: ExpressionBuilder) => AstFactory<Expression>) {
+        this.value = builder(ExpressionBuilder());
+        this.update({ value: this.value });
+        return this;
+    }
+}
+
+export class InternalAttributeFactory extends AstFactory<InternalAttribute> {
+    decl?: Reference<Attribute>;
+    args: AttributeArgFactory[] = [];
+
+    constructor() {
+        super({ type: InternalAttribute, node: { args: [] } });
+    }
+
+    setDecl(decl: Attribute) {
+        this.decl = { $refText: decl.name, ref: decl };
+        this.update({ decl: this.decl });
+        return this;
+    }
+
+    addArg(builder: (b: ExpressionBuilder) => AstFactory<Expression>, name?: string) {
+        const factory = new AttributeArgFactory().setValue(builder);
+        if (name) {
+            factory.setName(name);
+        }
+        this.args.push(factory);
+        this.update({ args: this.args });
+        return this;
+    }
+}
+
+export class AttributeParamFactory extends AstFactory<AttributeParam> {
+    attributes: InternalAttributeFactory[] = [];
+    comments: string[] = [];
+    default?: boolean;
+    name?: RegularID;
+    type?: AttributeParamTypeFactory;
+
+    constructor() {
+        super({ type: AttributeParam, node: { comments: [], attributes: [] } });
+    }
+
+    addAttribute(builder: (b: InternalAttributeFactory) => InternalAttributeFactory) {
+        this.attributes.push(builder(new InternalAttributeFactory()));
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    setComments(comments: string[]) {
+        this.comments = comments;
+        this.update({ comments: this.comments });
+        return this;
+    }
+
+    setDefault(defaultValue: boolean) {
+        this.default = defaultValue;
+        this.update({ default: this.default });
+        return this;
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    setType(builder: (b: AttributeParamTypeFactory) => AttributeParamTypeFactory) {
+        this.type = builder(new AttributeParamTypeFactory());
+        this.update({ type: this.type });
+        return this;
+    }
+}
+
+export class AttributeParamTypeFactory extends AstFactory<AttributeParamType> {
+    array?: boolean;
+    optional?: boolean;
+    reference?: Reference<TypeDeclaration>;
+    type?: AttributeParamType['type'];
+    constructor() {
+        super({ type: AttributeParamType });
+    }
+    setArray(array: boolean) {
+        this.array = array;
+        this.update({ array: this.array });
+        return this;
+    }
+
+    setOptional(optional: boolean) {
+        this.optional = optional;
+        this.update({ optional: this.optional });
+        return this;
+    }
+
+    setReference(reference: TypeDeclaration) {
+        this.reference = { $refText: reference.name, ref: reference };
+        this.update({ reference: this.reference });
+        return this;
+    }
+
+    setType(type: AttributeParamType['type']) {
+        this.type = type;
+        this.update({ type: this.type });
+        return this;
+    }
+}
+
+export class AttributeFactory extends AstFactory<Attribute> {
+    name?: string;
+    comments: string[] = [];
+    attributes: InternalAttributeFactory[] = [];
+    params: AttributeParamFactory[] = [];
+
+    constructor() {
+        super({ type: Attribute, node: { comments: [], attributes: [], params: [] } });
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    setComments(comments: string[]) {
+        this.comments = comments;
+        this.update({ comments: this.comments });
+        return this;
+    }
+
+    addAttribute(builder: (b: InternalAttributeFactory) => InternalAttributeFactory) {
+        this.attributes.push(builder(new InternalAttributeFactory()));
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    addParam(builder: (b: AttributeParamFactory) => AttributeParamFactory) {
+        this.params.push(builder(new AttributeParamFactory()));
+        this.update({ params: this.params });
+        return this;
+    }
+}
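A sketch of how a field attribute might be assembled with these factories. `defaultAttr` is assumed to be the stdlib `@default` Attribute node, e.g. looked up from a loaded model; the argument value is illustrative:

```ts
import { DataFieldAttributeFactory } from '@zenstackhq/language/factory';
import type { Attribute } from '@zenstackhq/language/ast';

// Builds a node equivalent to: @default("uuid()")
function makeDefaultAttr(defaultAttr: Attribute) {
    return new DataFieldAttributeFactory()
        .setDecl(defaultAttr) // wires the cross-reference to the attribute declaration
        .addArg((b) => b.StringLiteral.setValue('uuid()'));
}
```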
diff --git a/packages/language/src/factory/declaration.ts b/packages/language/src/factory/declaration.ts
new file mode 100644
index 000000000..1f514982b
--- /dev/null
+++ b/packages/language/src/factory/declaration.ts
@@ -0,0 +1,363 @@
+import { AstFactory } from '.';
+import { AbstractDeclaration, type Reference } from '../ast';
+import {
+    type BuiltinType,
+    DataField,
+    DataFieldType,
+    DataModel,
+    Enum,
+    EnumField,
+    LiteralExpr,
+    Model,
+    ModelImport,
+    type RegularID,
+    type RegularIDWithTypeNames,
+    TypeDeclaration,
+    type TypeDef,
+    UnsupportedFieldType,
+} from '../generated/ast';
+import { AttributeFactory, DataFieldAttributeFactory, DataModelAttributeFactory } from './attribute';
+import { ExpressionBuilder } from './expression';
+
+export const DeclarationBuilder = () =>
+    ({
+        get Attribute() {
+            return new AttributeFactory();
+        },
+        get DataModel() {
+            return new DataModelFactory();
+        },
+        get DataSource(): any {
+            throw new Error('DataSource is not implemented');
+        },
+        get Enum() {
+            return new EnumFactory();
+        },
+        get FunctionDecl(): any {
+            throw new Error('FunctionDecl is not implemented');
+        },
+        get GeneratorDecl(): any {
+            throw new Error('GeneratorDecl is not implemented');
+        },
+        get Plugin(): any {
+            throw new Error('Plugin is not implemented');
+        },
+        get Procedure(): any {
+            throw new Error('Procedure is not implemented');
+        },
+        get TypeDef(): any {
+            throw new Error('TypeDef is not implemented');
+        },
+    }) satisfies DeclarationBuilderType;
+
+type DeclarationBuilderType<T extends AbstractDeclaration = AbstractDeclaration> = {
+    [K in T['$type']]: AstFactory<Extract<T, { $type: K }>>;
+};
+type DeclarationBuilderMap = ReturnType<typeof DeclarationBuilder>;
+
+export type DeclarationBuilder = Pick<
+    DeclarationBuilderMap,
+    Extract<keyof DeclarationBuilderMap, AbstractDeclaration['$type']>
+>;
+
+export class DataModelFactory extends AstFactory<DataModel> {
+    attributes: DataModelAttributeFactory[] = [];
+    baseModel?: Reference<DataModel>;
+    comments: string[] = [];
+    fields: DataFieldFactory[] = [];
+    isView?: boolean;
+    mixins: Reference<DataModel>[] = [];
+    name?: RegularID;
+
+    constructor() {
+        super({
+            type: DataModel,
+            node: { attributes: [], comments: [], fields: [], mixins: [] },
+        });
+    }
+
+    addAttribute(builder: (attr: DataModelAttributeFactory) => DataModelAttributeFactory) {
+        this.attributes.push(builder(new DataModelAttributeFactory()));
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    setBaseModel(model: Reference<DataModel>) {
+        this.baseModel = model;
+        this.update({ baseModel: this.baseModel });
+        return this;
+    }
+
+    setComments(comments: string[]) {
+        this.comments = comments;
+        this.update({ comments: this.comments });
+        return this;
+    }
+
+    addField(builder: (field: DataFieldFactory) => DataFieldFactory) {
+        this.fields.push(builder(new DataFieldFactory()));
+        this.update({ fields: this.fields });
+        return this;
+    }
+
+    setIsView(isView: boolean) {
+        this.isView = isView;
+        this.update({ isView: this.isView });
+        return this;
+    }
+
+    addMixin(mixin: Reference<DataModel>) {
+        this.mixins.push(mixin);
+        this.update({ mixins: this.mixins });
+        return this;
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+}
+
+export class DataFieldFactory extends AstFactory<DataField> {
+    attributes: DataFieldAttributeFactory[] = [];
+    comments: string[] = [];
+    name?: string;
+    type?: DataFieldTypeFactory;
+
+    constructor() {
+        super({ type: DataField, node: { attributes: [], comments: [] } });
+    }
+
+    addAttribute(
+        builder: ((attr: DataFieldAttributeFactory) => DataFieldAttributeFactory) | DataFieldAttributeFactory,
+    ) {
+        if (builder instanceof DataFieldAttributeFactory) {
+            builder.setContainer(this.node);
+            this.attributes.push(builder);
+        } else {
+            this.attributes.push(builder(new DataFieldAttributeFactory()));
+        }
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    setComments(comments: string[]) {
+        this.comments = comments;
+        this.update({ comments: this.comments });
+        return this;
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    setType(builder: (type: DataFieldTypeFactory) => DataFieldTypeFactory) {
+        this.type = builder(new DataFieldTypeFactory());
+        this.update({ type: this.type });
+        return this;
+    }
+}
+
+export class DataFieldTypeFactory extends AstFactory<DataFieldType> {
+    array?: boolean;
+    optional?: boolean;
+    reference?: Reference<TypeDeclaration>;
+    type?: BuiltinType;
+    unsupported?: UnsupportedFieldTypeFactory;
+
+    constructor() {
+        super({ type: DataFieldType });
+    }
+
+    setArray(array: boolean) {
+        this.array = array;
+        this.update({ array: this.array });
+        return this;
+    }
+
+    setOptional(optional: boolean) {
+        this.optional = optional;
+        this.update({ optional: this.optional });
+        return this;
+    }
+
+    setReference(reference: TypeDeclaration) {
+        this.reference = { $refText: reference.name, ref: reference };
+        this.update({ reference: this.reference });
+        return this;
+    }
+
+    setType(type: BuiltinType) {
+        this.type = type;
+        this.update({ type: this.type });
+        return this;
+    }
+
+    setUnsupported(builder: (a: UnsupportedFieldTypeFactory) => UnsupportedFieldTypeFactory) {
+        this.unsupported = builder(new UnsupportedFieldTypeFactory());
+        this.update({ unsupported: this.unsupported });
+        return this;
+    }
+}
+
+export class UnsupportedFieldTypeFactory extends AstFactory<UnsupportedFieldType> {
+    value?: AstFactory<LiteralExpr>;
+    constructor() {
+        super({ type: UnsupportedFieldType });
+    }
+    setValue(builder: (value: ExpressionBuilder) => AstFactory<LiteralExpr>) {
+        this.value = builder(ExpressionBuilder());
+        this.update({ value: this.value! });
+        return this;
+    }
+}
+
+export class ModelFactory extends AstFactory<Model> {
+    declarations: AstFactory<AbstractDeclaration>[] = [];
+    imports: ModelImportFactory[] = [];
+    constructor() {
+        super({ type: Model, node: { declarations: [], imports: [] } });
+    }
+    addImport(builder: (b: ModelImportFactory) => ModelImportFactory) {
+        this.imports.push(builder(new ModelImportFactory()));
+        this.update({ imports: this.imports });
+        return this;
+    }
+    addDeclaration(builder: (b: DeclarationBuilder) => AstFactory<AbstractDeclaration>) {
+        this.declarations.push(builder(DeclarationBuilder()));
+        this.update({ declarations: this.declarations });
+        return this;
+    }
+}
+
+export class ModelImportFactory extends AstFactory<ModelImport> {
+    path?: string | undefined;
+
+    constructor() {
+        super({ type: ModelImport });
+    }
+
+    setPath(path: string) {
+        this.path = path;
+        this.update({ path: this.path });
+        return this;
+    }
+}
+
+export class EnumFactory extends AstFactory<Enum> {
+    name?: string;
+    comments: string[] = [];
+    fields: EnumFieldFactory[] = [];
+    attributes: DataModelAttributeFactory[] = [];
+
+    constructor() {
+        super({ type: Enum, node: { comments: [], fields: [], attributes: [] } });
+    }
+
+    addField(builder: (b: EnumFieldFactory) => EnumFieldFactory) {
+        this.fields.push(builder(new EnumFieldFactory()));
+        this.update({ fields: this.fields });
+        return this;
+    }
+
+    addAttribute(builder: (b: DataModelAttributeFactory) => DataModelAttributeFactory) {
+        this.attributes.push(builder(new DataModelAttributeFactory()));
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+}
+
+export class EnumFieldFactory extends AstFactory<EnumField> {
+    name?: RegularIDWithTypeNames;
+    comments: string[] = [];
+    attributes: DataFieldAttributeFactory[] = [];
+
+    constructor() {
+        super({ type: EnumField, node: { comments: [], attributes: [] } });
+    }
+
+    setName(name: RegularIDWithTypeNames) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    addAttribute(builder: (b: DataFieldAttributeFactory) => DataFieldAttributeFactory) {
+        this.attributes.push(builder(new DataFieldAttributeFactory()));
+        this.update({ attributes: this.attributes });
+        return this;
+    }
+
+    addComment(comment: string) {
+        this.comments.push(comment);
+        this.update({ comments: this.comments });
+        return this;
+    }
+}
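Putting the declaration factories together, a model can plausibly be built like this (names from this diff; the result is an ordinary `Model` AST node):

```ts
import { ModelFactory } from '@zenstackhq/language/factory';

const model = new ModelFactory()
    .addDeclaration((b) =>
        // `b` is the DeclarationBuilder map; each getter returns a fresh factory.
        b.DataModel.setName('User').addField((f) => f.setName('id').setType((t) => t.setType('Int'))),
    )
    .get();

console.log(model.declarations[0]?.name); // User
```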
diff --git a/packages/language/src/factory/expression.ts b/packages/language/src/factory/expression.ts
new file mode 100644
index 000000000..a0ba84001
--- /dev/null
+++ b/packages/language/src/factory/expression.ts
@@ -0,0 +1,307 @@
+import type { Reference } from 'langium';
+import { AstFactory } from '.';
+import {
+    Argument,
+    ArrayExpr,
+    BinaryExpr,
+    FieldInitializer,
+    FunctionDecl,
+    InvocationExpr,
+    MemberAccessExpr,
+    MemberAccessTarget,
+    ObjectExpr,
+    ReferenceArg,
+    ReferenceExpr,
+    ReferenceTarget,
+    UnaryExpr,
+    type Expression,
+    type RegularID,
+} from '../ast';
+import {
+    BooleanLiteralFactory,
+    NullExprFactory,
+    NumberLiteralFactory,
+    StringLiteralFactory,
+    ThisExprFactory,
+} from './primitives';
+
+export const ExpressionBuilder = () =>
+    ({
+        get ArrayExpr() {
+            return new ArrayExprFactory();
+        },
+        get BinaryExpr() {
+            return new BinaryExprFactory();
+        },
+        get BooleanLiteral() {
+            return new BooleanLiteralFactory();
+        },
+        get InvocationExpr() {
+            return new InvocationExprFactory();
+        },
+        get MemberAccessExpr() {
+            return new MemberAccessExprFactory();
+        },
+        get NullExpr() {
+            return new NullExprFactory();
+        },
+        get NumberLiteral() {
+            return new NumberLiteralFactory();
+        },
+        get ObjectExpr() {
+            return new ObjectExprFactory();
+        },
+        get ReferenceExpr() {
+            return new ReferenceExprFactory();
+        },
+        get StringLiteral() {
+            return new StringLiteralFactory();
+        },
+        get ThisExpr() {
+            return new ThisExprFactory();
+        },
+        get UnaryExpr() {
+            return new UnaryExprFactory();
+        },
+    }) satisfies ExpressionBuilderType;
+
+type ExpressionBuilderType<T extends Expression = Expression> = {
+    [K in T['$type']]: AstFactory<Extract<T, { $type: K }>>;
+};
+
+type ExpressionFactoryMap = ReturnType<typeof ExpressionBuilder>;
+
+export type ExpressionBuilder = Pick<
+    ExpressionFactoryMap,
+    Extract<keyof ExpressionFactoryMap, Expression['$type']>
+>;
+
+export class UnaryExprFactory extends AstFactory<UnaryExpr> {
+    operand?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: UnaryExpr, node: { operator: '!' } });
+    }
+
+    setOperand(builder: (a: ExpressionBuilder) => AstFactory<Expression>) {
+        this.operand = builder(ExpressionBuilder());
+        this.update({ operand: this.operand });
+        return this;
+    }
+}
+
+export class ReferenceExprFactory extends AstFactory<ReferenceExpr> {
+    target?: Reference<ReferenceTarget>;
+    args: ReferenceArgFactory[] = [];
+
+    constructor() {
+        super({ type: ReferenceExpr, node: { args: [] } });
+    }
+
+    setTarget(target: ReferenceTarget) {
+        this.target = { $refText: target.name, ref: target };
+        this.update({ target: this.target });
+        return this;
+    }
+
+    addArg(builder: (a: ExpressionBuilder) => AstFactory<Expression>, name?: string) {
+        const arg = new ReferenceArgFactory().setValue(builder);
+        if (name) {
+            arg.setName(name);
+        }
+        this.args.push(arg);
+        this.update({ args: this.args });
+        return this;
+    }
+}
+
+export class ReferenceArgFactory extends AstFactory<ReferenceArg> {
+    name?: string;
+    value?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: ReferenceArg });
+    }
+
+    setName(name: string) {
+        this.name = name;
+        this.update({ name: this.name });
+        return this;
+    }
+
+    setValue(builder: (a: ExpressionBuilder) => AstFactory<Expression>) {
+        this.value = builder(ExpressionBuilder());
+        this.update({ value: this.value });
+        return this;
+    }
+}
+
+export class MemberAccessExprFactory extends AstFactory<MemberAccessExpr> {
+    member?: Reference<MemberAccessTarget>;
+    operand?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: MemberAccessExpr });
+    }
+
+    setMember(target: Reference<MemberAccessTarget>) {
+        this.member = target;
+        this.update({ member: this.member });
+        return this;
+    }
+
+    setOperand(builder: (b: ExpressionBuilder) => AstFactory<Expression>) {
+        this.operand = builder(ExpressionBuilder());
+        this.update({ operand: this.operand });
+        return this;
+    }
+}
+
+export class ObjectExprFactory extends AstFactory<ObjectExpr> {
+    fields: FieldInitializerFactory[] = [];
+
+    constructor() {
+        super({ type: ObjectExpr, node: { fields: [] } });
+    }
+
+    addField(builder: (b: FieldInitializerFactory) => FieldInitializerFactory) {
+        this.fields.push(builder(new FieldInitializerFactory()));
+        this.update({ fields: this.fields });
+        return this;
+    }
+}
+
+export class FieldInitializerFactory extends AstFactory<FieldInitializer> {
+    name?: RegularID;
+    value?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: FieldInitializer });
+    }
+
+    setName(name: RegularID) {
+        this.name = name;
+        this.update({ name: this.name! });
+        return this;
+    }
+
+    setValue(builder: (a: ExpressionBuilder) => AstFactory<Expression>) {
+        this.value = builder(ExpressionBuilder());
+        this.update({ value: this.value! });
+        return this;
+    }
+}
+
+export class InvocationExprFactory extends AstFactory<InvocationExpr> {
+    args: ArgumentFactory[] = [];
+    function?: Reference<FunctionDecl>;
+
+    constructor() {
+        super({ type: InvocationExpr, node: { args: [] } });
+    }
+
+    addArg(builder: (arg: ArgumentFactory) => ArgumentFactory) {
+        this.args.push(builder(new ArgumentFactory()));
+        this.update({ args: this.args });
+        return this;
+    }
+
+    setFunction(value: FunctionDecl) {
+        this.function = { $refText: value.name, ref: value };
+        this.update({ function: this.function! });
+        return this;
+    }
+}
+
+export class ArgumentFactory extends AstFactory<Argument> {
+    value?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: Argument });
+    }
+
+    setValue(builder: (a: ExpressionBuilder) => AstFactory<Expression>) {
+        this.value = builder(ExpressionBuilder());
+        this.update({ value: this.value! });
+        return this;
+    }
+}
+
+export class ArrayExprFactory extends AstFactory<ArrayExpr> {
+    items: AstFactory<Expression>[] = [];
+
+    constructor() {
+        super({ type: ArrayExpr, node: { items: [] } });
+    }
+
+    addItem(builder: (a: ExpressionBuilder) => AstFactory<Expression>) {
+        this.items.push(builder(ExpressionBuilder()));
+        this.update({ items: this.items });
+        return this;
+    }
+}
+
+export class BinaryExprFactory extends AstFactory<BinaryExpr> {
+    operator?: BinaryExpr['operator'];
+    right?: AstFactory<Expression>;
+    left?: AstFactory<Expression>;
+
+    constructor() {
+        super({ type: BinaryExpr });
+    }
+
+    setOperator(operator: BinaryExpr['operator']) {
+        this.operator = operator;
+        this.update({ operator: this.operator! });
+        return this;
+    }
+    setRight(builder: (arg: ExpressionBuilder) => AstFactory<Expression>) {
+        this.right = builder(ExpressionBuilder());
+        this.update({ right: this.right! });
+        return this;
+    }
+    setLeft(builder: (arg: ExpressionBuilder) => AstFactory<Expression>) {
+        this.left = builder(ExpressionBuilder());
+        this.update({ left: this.left! });
+        return this;
+    }
+}
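A sketch of composing expressions through the builder map, e.g. an AST for `true && !false` (operator strings are taken from the `BinaryExpr`/`UnaryExpr` grammar types):

```ts
import { BinaryExprFactory } from '@zenstackhq/language/factory';

const expr = new BinaryExprFactory()
    .setOperator('&&')
    .setLeft((b) => b.BooleanLiteral.setValue(true))
    .setRight((b) => b.UnaryExpr.setOperand((u) => u.BooleanLiteral.setValue(false)))
    .get();
```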
diff --git a/packages/language/src/factory/index.ts b/packages/language/src/factory/index.ts
new file mode 100644
index 000000000..1ea2a286b
--- /dev/null
+++ b/packages/language/src/factory/index.ts
@@ -0,0 +1,5 @@
+export * from './ast-factory';
+export * from './primitives';
+export * from './expression';
+export * from './declaration';
+export * from './attribute';
diff --git a/packages/language/src/factory/primitives.ts b/packages/language/src/factory/primitives.ts
new file mode 100644
index 000000000..1db7e0515
--- /dev/null
+++ b/packages/language/src/factory/primitives.ts
@@ -0,0 +1,61 @@
+import { AstFactory } from '.';
+import { BooleanLiteral, NullExpr, NumberLiteral, StringLiteral, ThisExpr } from '../ast';
+
+export class ThisExprFactory extends AstFactory<ThisExpr> {
+    constructor() {
+        super({ type: ThisExpr, node: { value: 'this' } });
+    }
+}
+
+export class NullExprFactory extends AstFactory<NullExpr> {
+    constructor() {
+        super({ type: NullExpr, node: { value: 'null' } });
+    }
+}
+
+export class NumberLiteralFactory extends AstFactory<NumberLiteral> {
+    value?: number | string;
+
+    constructor() {
+        super({ type: NumberLiteral });
+    }
+
+    setValue(value: number | string) {
+        this.value = value;
+        this.update({ value: this.value.toString() });
+        return this;
+    }
+}
+
+export class StringLiteralFactory extends AstFactory<StringLiteral> {
+    value?: string;
+
+    constructor() {
+        super({ type: StringLiteral });
+    }
+
+    setValue(value: string) {
+        this.value = value;
+        this.update({ value: this.value });
+        return this;
+    }
+}
+
+export class BooleanLiteralFactory extends AstFactory<BooleanLiteral> {
+    value?: boolean;
+
+    constructor() {
+        super({ type: BooleanLiteral });
+    }
+
+    setValue(value: boolean) {
+        this.value = value;
+        this.update({ value: this.value });
+        return this;
+    }
+}
diff --git a/packages/language/src/validators/datamodel-validator.ts b/packages/language/src/validators/datamodel-validator.ts
index 6c5d18ffd..d2fcd155d 100644
--- a/packages/language/src/validators/datamodel-validator.ts
+++ b/packages/language/src/validators/datamodel-validator.ts
@@ -44,13 +44,15 @@ export default class DataModelValidator implements AstValidator<DataModel> {
         const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique'));
         const modelLevelIds = getModelIdFields(dm);
         const modelUniqueFields = getModelUniqueFields(dm);
+        const ignore = hasAttribute(dm, '@@ignore');
 
         if (
             !dm.isView &&
             idFields.length === 0 &&
             modelLevelIds.length === 0 &&
             uniqueFields.length === 0 &&
-            modelUniqueFields.length === 0
+            modelUniqueFields.length === 0 &&
+            !ignore
         ) {
             accept(
                 'error',
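The validator change can be checked with a schema like the following (a sketch; the datasource block is only there to make the document load, and the file path is arbitrary):

```ts
import fs from 'node:fs';
import { loadDocument } from '@zenstackhq/language';

async function main() {
    fs.writeFileSync(
        'schema.zmodel',
        `datasource db {
    provider = "sqlite"
    url      = "file:./dev.db"
}

model AuditLog {
    data String
    @@ignore
}
`,
    );
    // Previously this failed with "must have at least one unique criteria";
    // with the change above, @@ignore exempts the model from that check.
    const result = await loadDocument('schema.zmodel');
    console.log(result.success); // true
}

main();
```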
diff --git a/packages/language/src/zmodel-code-generator.ts b/packages/language/src/zmodel-code-generator.ts
index 1e0366ede..c6059ebe6 100644
--- a/packages/language/src/zmodel-code-generator.ts
+++ b/packages/language/src/zmodel-code-generator.ts
@@ -28,6 +28,7 @@ import {
     LiteralExpr,
     MemberAccessExpr,
     Model,
+    ModelImport,
     NullExpr,
     NumberLiteral,
     ObjectExpr,
@@ -70,7 +71,7 @@ function gen(name: string) {
  */
 export class ZModelCodeGenerator {
     private readonly options: ZModelCodeOptions;
-
+    private readonly quote: string;
     constructor(options?: Partial<ZModelCodeOptions>) {
         this.options = {
             binaryExprNumberOfSpaces: options?.binaryExprNumberOfSpaces ?? 1,
@@ -78,6 +79,7 @@ export class ZModelCodeGenerator {
             indent: options?.indent ?? 4,
             quote: options?.quote ?? 'single',
         };
+        this.quote = this.options.quote === 'double' ? '"' : "'";
     }
 
@@ -91,9 +93,17 @@ export class ZModelCodeGenerator {
         return handler.value.call(this, ast);
     }
 
+    private quotedStr(val: string): string {
+        const trimmedVal = val.replace(new RegExp(`${this.quote}`, 'g'), `\\${this.quote}`);
+        return `${this.quote}${trimmedVal}${this.quote}`;
+    }
+
     @gen(Model)
     private _generateModel(ast: Model) {
-        return ast.declarations.map((d) => this.generate(d)).join('\n\n');
+        return `${ast.imports.map((d) => this.generate(d)).join('\n')}${ast.imports.length > 0 ? '\n\n' : ''}${ast.declarations
+            .sort((a, b) => (a.$type === 'Enum' ? 1 : 0) - (b.$type === 'Enum' ? 1 : 0))
+            .map((d) => this.generate(d))
+            .join('\n\n')}`;
     }
 
     @gen(DataSource)
@@ -103,10 +113,19 @@
 ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}
 }`;
     }
 
+    @gen(ModelImport)
+    private _generateModelImport(ast: ModelImport) {
+        return `import ${this.quotedStr(ast.path)}`;
+    }
+
     @gen(Enum)
     private _generateEnum(ast: Enum) {
         return `enum ${ast.name} {
-${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}
+${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${
+            ast.attributes.length > 0
+                ? '\n\n' + ast.attributes.map((x) => this.indent + this.generate(x)).join('\n')
+                : ''
+        }
 }`;
     }
 
@@ -126,7 +145,9 @@
 
     @gen(ConfigField)
     private _generateConfigField(ast: ConfigField) {
-        return `${ast.name} = ${this.generate(ast.value)}`;
+        const longestName = Math.max(...ast.$container.fields.map((x) => x.name.length));
+        const padding = ' '.repeat(longestName - ast.name.length + 1);
+        return `${ast.name}${padding}= ${this.generate(ast.value)}`;
     }
 
     @gen(ConfigArrayExpr)
@@ -154,15 +175,24 @@
 
     @gen(PluginField)
     private _generatePluginField(ast: PluginField) {
-        return `${ast.name} = ${this.generate(ast.value)}`;
+        const longestName = Math.max(...ast.$container.fields.map((x) => x.name.length));
+        const padding = ' '.repeat(longestName - ast.name.length + 1);
+        return `${ast.name}${padding}= ${this.generate(ast.value)}`;
     }
 
     @gen(DataModel)
     private _generateDataModel(ast: DataModel) {
-        return `${ast.isView ? 'view' : 'model'} ${ast.name}${
+        const comments = `${ast.comments.join('\n')}\n`;
+
+        return `${ast.comments.length > 0 ? comments : ''}${ast.isView ? 'view' : 'model'} ${ast.name}${
             ast.mixins.length > 0 ? ' mixes ' + ast.mixins.map((x) => x.$refText).join(', ') : ''
         } {
-${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${
+${ast.fields
+    .map((x) => {
+        const comments = x.comments.map((c) => `${this.indent}${c}`).join('\n');
+        return (x.comments.length ? `${comments}\n` : '') + this.indent + this.generate(x);
+    })
+    .join('\n')}${
             ast.attributes.length > 0
                 ?
'\n\n' + ast.attributes.map((x) => this.indent + this.generate(x)).join('\n') : '' @@ -172,7 +202,11 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(DataField) private _generateDataField(ast: DataField) { - return `${ast.name} ${this.fieldType(ast.type)}${ + const longestFieldName = Math.max(...ast.$container.fields.map((f) => f.name.length)); + const longestType = Math.max(...ast.$container.fields.map((f) => this.fieldType(f.type).length)); + const paddingLeft = longestFieldName - ast.name.length; + const paddingRight = ast.attributes.length > 0 ? longestType - this.fieldType(ast.type).length : 0; + return `${ast.name}${' '.repeat(paddingLeft)} ${this.fieldType(ast.type)}${' '.repeat(paddingRight)}${ ast.attributes.length > 0 ? ' ' + ast.attributes.map((x) => this.generate(x)).join(' ') : '' }`; } @@ -226,7 +260,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(StringLiteral) private _generateLiteralExpr(ast: LiteralExpr) { - return this.options.quote === 'single' ? `'${ast.value}'` : `"${ast.value}"`; + return this.quotedStr(ast.value as string); } @gen(NumberLiteral) @@ -271,7 +305,7 @@ ${ast.fields.map((x) => this.indent + this.generate(x)).join('\n')}${ @gen(ReferenceArg) private _generateReferenceArg(ast: ReferenceArg) { - return `${ast.name}:${this.generate(ast.value)}`; + return `${ast.name}: ${this.generate(ast.value)}`; } @gen(MemberAccessExpr) diff --git a/packages/language/tsup.config.ts b/packages/language/tsup.config.ts index 0d5d2b6c4..48282a08c 100644 --- a/packages/language/tsup.config.ts +++ b/packages/language/tsup.config.ts @@ -5,6 +5,7 @@ export default defineConfig({ index: 'src/index.ts', ast: 'src/ast.ts', utils: 'src/utils.ts', + factory: 'src/factory/index.ts', }, outDir: 'dist', splitting: false, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3f519e5c3..3412d6554 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -192,6 +192,9 @@ importers: packages/cli: dependencies: + '@dotenvx/dotenvx': + specifier: ^1.51.0 + version: 1.52.0 '@zenstackhq/common-helpers': specifier: workspace:* version: link:../common-helpers @@ -201,6 +204,9 @@ importers: '@zenstackhq/orm': specifier: workspace:* version: link:../orm + '@zenstackhq/schema': + specifier: workspace:* + version: link:../schema '@zenstackhq/sdk': specifier: workspace:* version: link:../sdk @@ -1558,12 +1564,22 @@ packages: resolution: {integrity: sha512-Vd/9EVDiu6PPJt9yAh6roZP6El1xHrdvIVGjyBsHR0RYwNHgL7FJPyIIW4fANJNG6FtyZfvlRPpFI4ZM/lubvw==} engines: {node: '>=18'} + '@dotenvx/dotenvx@1.52.0': + resolution: {integrity: sha512-CaQcc8JvtzQhUSm9877b6V4Tb7HCotkcyud9X2YwdqtQKwgljkMRwU96fVYKnzN3V0Hj74oP7Es+vZ0mS+Aa1w==} + hasBin: true + '@dxup/nuxt@0.2.2': resolution: {integrity: sha512-RNpJjDZs9+JcT9N87AnOuHsNM75DEd58itADNd/s1LIF6BZbTLZV0xxilJZb55lntn4TYvscTaXLCBX2fq9CXg==} '@dxup/unimport@0.1.2': resolution: {integrity: sha512-/B8YJGPzaYq1NbsQmwgP8EZqg40NpTw4ZB3suuI0TplbxKHeK94jeaawLmVhCv+YwUnOpiWEz9U6SeThku/8JQ==} + '@ecies/ciphers@0.2.5': + resolution: {integrity: sha512-GalEZH4JgOMHYYcYmVqnFirFsjZHeoGMDt9IxEnM9F7GRUUyUksJ7Ou53L83WHJq3RWKD3AcBpo0iQh0oMpf8A==} + engines: {bun: '>=1', deno: '>=2', node: '>=16'} + peerDependencies: + '@noble/ciphers': ^1.0.0 + '@edge-runtime/primitives@6.0.0': resolution: {integrity: sha512-FqoxaBT+prPBHBwE1WXS1ocnu/VLTQyZ6NMUBAdbP7N2hsFTTxMC/jMu2D/8GAlMQfxeuppcPuCUk/HO3fpIvA==} engines: {node: '>=18'} @@ -2383,14 +2399,26 @@ packages: cpu: [x64] os: [win32] + '@noble/ciphers@1.3.0': + resolution: {integrity: 
sha512-2I0gnIVPtfnMw9ee9h1dJG7tp81+8Ob3OJb3Mv37rx5L40/b0i7djjCVvGOVqc9AEIQyvyu1i6ypKdFw8R8gQw==} + engines: {node: ^14.21.3 || >=16} + '@noble/ciphers@2.0.1': resolution: {integrity: sha512-xHK3XHPUW8DTAobU+G0XT+/w+JLM7/8k1UFdB5xg/zTFPnFCobhftzw8wl4Lw2aq/Rvir5pxfZV5fEazmeCJ2g==} engines: {node: '>= 20.19.0'} + '@noble/curves@1.9.7': + resolution: {integrity: sha512-gbKGcRUYIjA3/zCCNaWDciTMFI0dCkvou3TL8Zmy5Nc7sJ47a0jtOeZoTaMxkuqRo9cRhjOdZJXegxYE5FN/xw==} + engines: {node: ^14.21.3 || >=16} + '@noble/hashes@1.7.1': resolution: {integrity: sha512-B8XBPsn4vT/KJAGqDzbwztd+6Yte3P4V7iafm24bxgDe/mlRuK6xmWPuCNrKt2vDafZ8MfJLlchDG/vYafQEjQ==} engines: {node: ^14.21.3 || >=16} + '@noble/hashes@1.8.0': + resolution: {integrity: sha512-jCs9ldd7NwzpgXDIf6P3+NrHh9/sD6CQdxHyjQI+h/6rDNo88ypBxxz45UDuZHz9r3tNz7N/VInSVoVdtXEI4A==} + engines: {node: ^14.21.3 || >=16} + '@noble/hashes@2.0.1': resolution: {integrity: sha512-XlOlEbQcE9fmuXxrVTXCTlG2nlRXa9Rj3rr5Ue/+tX+nmkgbX720YHh0VR3hBF9xDvwnb8D2shVGOwNx+ulArw==} engines: {node: '>= 20.19.0'} @@ -5087,6 +5115,10 @@ packages: eastasianwidth@0.2.0: resolution: {integrity: sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==} + eciesjs@0.4.17: + resolution: {integrity: sha512-TOOURki4G7sD1wDCjj7NfLaXZZ49dFOeEb5y39IXpb8p0hRzVvfvzZHOi5JcT+PpyAbi/Y+lxPb8eTag2WYH8w==} + engines: {bun: '>=1', deno: '>=2', node: '>=16'} + ee-first@1.1.1: resolution: {integrity: sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==} @@ -5376,6 +5408,10 @@ packages: '@sinclair/typebox': optional: true + execa@5.1.1: + resolution: {integrity: sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==} + engines: {node: '>=10'} + execa@8.0.1: resolution: {integrity: sha512-VyhnebXciFV2DESc+p6B+y0LjSm0krU4OgJN44qFAhBY0TJ+1V61tYD2+wHusZ6F9n5K+vl8k0sTy7PEfV4qpg==} engines: {node: '>=16.17'} @@ -5593,6 +5629,10 @@ packages: resolution: {integrity: sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==} engines: {node: '>= 0.4'} + get-stream@6.0.1: + resolution: {integrity: sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==} + engines: {node: '>=10'} + get-stream@8.0.1: resolution: {integrity: sha512-VaUJspBffn/LMCJVoMvSAdmscJyS1auj5Zulnn5UoYcY531UWmdwhRWkcGKnGU93m5HSXP9LP2usOryrBtQowA==} engines: {node: '>=16'} @@ -5756,6 +5796,10 @@ packages: httpxy@0.1.7: resolution: {integrity: sha512-pXNx8gnANKAndgga5ahefxc++tJvNL87CXoRwxn1cJE2ZkWEojF3tNfQIEhZX/vfpt+wzeAzpUI4qkediX1MLQ==} + human-signals@2.1.0: + resolution: {integrity: sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==} + engines: {node: '>=10.17.0'} + human-signals@5.0.0: resolution: {integrity: sha512-AXcZb6vzzrFAUE61HnN4mpLqd/cSIwNQjtNWR0euPm6y0iqx3G4gOXaIDdtdDwZmhwe82LA6+zinmW4UBWVePQ==} engines: {node: '>=16.17.0'} @@ -6677,6 +6721,10 @@ packages: engines: {node: '>= 4'} hasBin: true + npm-run-path@4.0.1: + resolution: {integrity: sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==} + engines: {node: '>=8'} + npm-run-path@5.3.0: resolution: {integrity: sha512-ppwTtiJZq0O/ai0z7yfudtBpWIoxM8yE6nHi1X47eFR2EWORqfbu6CnPlNsjeN683eT0qG6H/Pyf9fCcvjnnnQ==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} @@ -6718,6 +6766,10 @@ packages: resolution: {integrity: 
sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==} engines: {node: '>= 0.4'} + object-treeify@1.1.33: + resolution: {integrity: sha512-EFVjAYfzWqWsBMRHPMAXLCDIJnpMhdWAqR7xG6M6a2cs6PMFpl/+Z20w9zDW4vkxOFfddegBKq9Rehd0bxWE7A==} + engines: {node: '>= 10'} + object.assign@4.1.7: resolution: {integrity: sha512-nK28WOo+QIjBkDduTINE4JkF/UJJKyf2EJxvJKfblDpyg0Q+pkOHNTL0Qwy6NP6FhE/EnzV73BxxqcJaXY9anw==} engines: {node: '>= 0.4'} @@ -7874,6 +7926,10 @@ packages: resolution: {integrity: sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==} engines: {node: '>=4'} + strip-final-newline@2.0.0: + resolution: {integrity: sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==} + engines: {node: '>=6'} + strip-final-newline@3.0.0: resolution: {integrity: sha512-dOESqjYr96iWYylGObzd39EuNTa5VJxyvVAEm5Jnh7KGo75V43Hk1odPQkNDyXNmUR6k+gEiDVXnjB8HJ3crXw==} engines: {node: '>=12'} @@ -8713,6 +8769,11 @@ packages: engines: {node: '>= 8'} hasBin: true + which@4.0.0: + resolution: {integrity: sha512-GlaYyEb07DPxYCKhKzplCWBJtvxZcZMrL+4UkrTSJHHPyZU4mYYTv3qaOe77H7EODLSSopAUFAc6W8U4yqvscg==} + engines: {node: ^16.13.0 || >=18.0.0} + hasBin: true + which@5.0.0: resolution: {integrity: sha512-JEdGzHwwkrbWoGOlIHqQ5gtprKGOenpDHpxE9zVR1bWbOtYRyPPHMe9FaP6x61CmNaTThSkb0DAJte5jD+DmzQ==} engines: {node: ^18.17.0 || >=20.5.0} @@ -9340,6 +9401,18 @@ snapshots: '@csstools/css-tokenizer@3.0.4': optional: true + '@dotenvx/dotenvx@1.52.0': + dependencies: + commander: 11.1.0 + dotenv: 17.2.3 + eciesjs: 0.4.17 + execa: 5.1.1 + fdir: 6.5.0(picomatch@4.0.3) + ignore: 5.3.2 + object-treeify: 1.1.33 + picomatch: 4.0.3 + which: 4.0.0 + '@dxup/nuxt@0.2.2(magicast@0.5.1)': dependencies: '@dxup/unimport': 0.1.2 @@ -9352,6 +9425,10 @@ snapshots: '@dxup/unimport@0.1.2': {} + '@ecies/ciphers@0.2.5(@noble/ciphers@1.3.0)': + dependencies: + '@noble/ciphers': 1.3.0 + '@edge-runtime/primitives@6.0.0': {} '@edge-runtime/vm@5.0.0': @@ -9923,10 +10000,18 @@ snapshots: '@next/swc-win32-x64-msvc@16.0.10': optional: true + '@noble/ciphers@1.3.0': {} + '@noble/ciphers@2.0.1': {} + '@noble/curves@1.9.7': + dependencies: + '@noble/hashes': 1.8.0 + '@noble/hashes@1.7.1': {} + '@noble/hashes@1.8.0': {} + '@noble/hashes@2.0.1': {} '@nodelib/fs.scandir@2.1.5': @@ -12720,6 +12805,13 @@ snapshots: eastasianwidth@0.2.0: {} + eciesjs@0.4.17: + dependencies: + '@ecies/ciphers': 0.2.5(@noble/ciphers@1.3.0) + '@noble/ciphers': 1.3.0 + '@noble/curves': 1.9.7 + '@noble/hashes': 1.8.0 + ee-first@1.1.1: {} effect@3.18.4: @@ -13206,6 +13298,18 @@ snapshots: optionalDependencies: '@sinclair/typebox': 0.34.41 + execa@5.1.1: + dependencies: + cross-spawn: 7.0.6 + get-stream: 6.0.1 + human-signals: 2.1.0 + is-stream: 2.0.1 + merge-stream: 2.0.0 + npm-run-path: 4.0.1 + onetime: 5.1.2 + signal-exit: 3.0.7 + strip-final-newline: 2.0.0 + execa@8.0.1: dependencies: cross-spawn: 7.0.6 @@ -13499,6 +13603,8 @@ snapshots: dunder-proto: 1.0.1 es-object-atoms: 1.1.1 + get-stream@6.0.1: {} + get-stream@8.0.1: {} get-stream@9.0.1: @@ -13689,6 +13795,8 @@ snapshots: httpxy@0.1.7: {} + human-signals@2.1.0: {} + human-signals@5.0.0: {} human-signals@8.0.1: {} @@ -14653,6 +14761,10 @@ snapshots: shell-quote: 1.8.3 string.prototype.padend: 3.1.6 + npm-run-path@4.0.1: + dependencies: + path-key: 3.1.1 + npm-run-path@5.3.0: dependencies: path-key: 4.0.0 @@ -14803,6 +14915,8 @@ snapshots: object-keys@1.1.1: {} + object-treeify@1.1.33: {} + object.assign@4.1.7: dependencies: 
call-bind: 1.0.8 @@ -16063,6 +16177,8 @@ snapshots: strip-bom@3.0.0: {} + strip-final-newline@2.0.0: {} + strip-final-newline@3.0.0: {} strip-final-newline@4.0.0: {} @@ -17046,6 +17162,10 @@ snapshots: dependencies: isexe: 2.0.0 + which@4.0.0: + dependencies: + isexe: 3.1.1 + which@5.0.0: dependencies: isexe: 3.1.1
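Finally, a sketch of how the `ZModelCodeGenerator` changes above surface to callers; `model` stands in for any loaded or factory-built `Model` node:

```ts
import { ZModelCodeGenerator } from '@zenstackhq/language';
import type { Model } from '@zenstackhq/language/ast';

declare const model: Model; // assume a loaded or factory-built Model

const gen = new ZModelCodeGenerator({ quote: 'double', indent: 2 });
// Imports are now emitted before declarations (enums sorted last), string
// literals use the configured quote, and config/plugin field names are
// column-aligned on their `=` signs.
const source = gen.generate(model);
console.log(source);
```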