diff --git a/README.md b/README.md
index ec0828b..eb9830a 100644
--- a/README.md
+++ b/README.md
@@ -143,7 +143,7 @@ caster.ingest(
 
 ```python
 from graflo.db.postgres import PostgresConnection
-from graflo.db.postgres.heuristics import infer_schema_from_postgres
+from graflo.db.inferencer import infer_schema_from_postgres
 from graflo.db.connection.onto import PostgresConfig
 from graflo import Caster
 from graflo.onto import DBFlavor
diff --git a/docs/examples/example-5.md b/docs/examples/example-5.md
index 95acca2..81b6ab4 100644
--- a/docs/examples/example-5.md
+++ b/docs/examples/example-5.md
@@ -262,7 +262,8 @@ Automatically generate a graflo Schema from your PostgreSQL database. This is th
 5. **Creates Resources**: Resource definitions are generated for each table with appropriate actors (VertexActor for vertex tables, EdgeActor for edge tables). Foreign keys are mapped to vertex matching keys.
 
 ```python
-from graflo.db.postgres import infer_schema_from_postgres
+
+from graflo.db.inferencer import infer_schema_from_postgres
 from graflo.onto import DBFlavor
 from graflo.db.connection.onto import ArangoConfig, Neo4jConfig, TigergraphConfig, FalkordbConfig
 from graflo.db import DBType
@@ -283,7 +284,7 @@ db_flavor = (
 schema = infer_schema_from_postgres(
     postgres_conn,
     schema_name="public", # PostgreSQL schema name
-    db_flavor=db_flavor # Target graph database flavor
+    db_flavor=db_flavor # Target graph database flavor
 )
 ```
 
@@ -327,7 +328,8 @@ logger.info(f"Inferred schema saved to {schema_output_file}")
 Create `Patterns` that map PostgreSQL tables to resources:
 
 ```python
-from graflo.db.postgres import create_patterns_from_postgres
+
+from graflo.db.inferencer import create_patterns_from_postgres
 
 # Create patterns from PostgreSQL tables
 patterns = create_patterns_from_postgres(
@@ -402,9 +404,8 @@ from graflo.onto import DBFlavor
 from graflo.db import DBType
 from graflo.db.postgres import (
     PostgresConnection,
-    create_patterns_from_postgres,
-    infer_schema_from_postgres,
 )
+from graflo.db.inferencer import infer_schema_from_postgres, create_patterns_from_postgres
 from graflo.db.connection.onto import ArangoConfig, PostgresConfig
 
 logger = logging.getLogger(__name__)
@@ -419,6 +420,7 @@ postgres_conn = PostgresConnection(postgres_conf)
 # Step 3: Connect to target graph database
 # You can try different databases by uncommenting the desired config:
 from graflo.db.connection.onto import ArangoConfig, Neo4jConfig, TigergraphConfig, FalkordbConfig
+
 target_config = ArangoConfig.from_docker_env() # or Neo4jConfig, TigergraphConfig, FalkordbConfig
 
 # Step 4: Infer Schema from PostgreSQL database structure
diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md
index 5244952..36f3d81 100644
--- a/docs/getting_started/quickstart.md
+++ b/docs/getting_started/quickstart.md
@@ -108,7 +108,8 @@ The `ingest()` method takes:
 You can ingest data directly from PostgreSQL tables. First, infer the schema from your PostgreSQL database:
 
 ```python
-from graflo.db.postgres import PostgresConnection, infer_schema_from_postgres, create_patterns_from_postgres
+from graflo.db.postgres import PostgresConnection
+from graflo.db.inferencer import infer_schema_from_postgres, create_patterns_from_postgres
 from graflo.db.connection.onto import PostgresConfig
 
 # Connect to PostgreSQL
diff --git a/examples/5-ingest-postgres/generated-schema.yaml b/examples/5-ingest-postgres/generated-schema.yaml
index b896ded..cf96ab0 100644
--- a/examples/5-ingest-postgres/generated-schema.yaml
+++ b/examples/5-ingest-postgres/generated-schema.yaml
@@ -1,8 +1,53 @@
+edge_config:
+  edges:
+  - relation: follows
+    source: users
+    target: users
+    weights:
+      direct:
+      - name: created_at
+        type: DATETIME
+  - relation: purchases
+    source: users
+    target: products
+    weights:
+      direct:
+      - name: purchase_date
+        type: DATETIME
+      - name: quantity
+        type: INT
+      - name: total_amount
+        type: FLOAT
 general:
-  name: public
+  name: accounting
+resources:
+- apply:
+  - vertex: products
+  resource_name: products
+- apply:
+  - vertex: users
+  resource_name: users
+- apply:
+  - map:
+      follower_id: id
+    target_vertex: users
+  - map:
+      followed_id: id
+    target_vertex: users
+  resource_name: follows
+- apply:
+  - map:
+      user_id: id
+    target_vertex: users
+  - map:
+      product_id: id
+    target_vertex: products
+  resource_name: purchases
+transforms: {}
 vertex_config:
+  db_flavor: tigergraph
   vertices:
-  - name: products
+  - dbname: products
     fields:
     - name: id
       type: INT
@@ -17,8 +62,8 @@ vertex_config:
     indexes:
     - fields:
       - id
-    dbname: products
-  - name: users
+    name: products
+  - dbname: users
     fields:
     - name: id
       type: INT
@@ -31,48 +76,4 @@
     indexes:
     - fields:
      - id
-    dbname: users
-edge_config:
-  edges:
-  - source: users
-    target: users
-    weights:
-      direct:
-      - name: created_at
-        type: DATETIME
-    relation: follows
-  - source: users
-    target: products
-    weights:
-      direct:
-      - name: purchase_date
-        type: DATETIME
-      - name: quantity
-        type: INT
-      - name: total_amount
-        type: FLOAT
-    relation: purchases
-resources:
-- resource_name: products
-  apply:
-  - vertex: products
-- resource_name: users
-  apply:
-  - vertex: users
-- resource_name: follows
-  apply:
-  - target_vertex: users
-    map:
-      follower_id: id
-  - target_vertex: users
-    map:
-      followed_id: id
-- resource_name: purchases
-  apply:
-  - target_vertex: users
-    map:
-      user_id: id
-  - target_vertex: products
-    map:
-      product_id: id
-transforms: {}
+    name: users
diff --git a/examples/5-ingest-postgres/ingest.py b/examples/5-ingest-postgres/ingest.py
index 27a0f26..28bebc1 100644
--- a/examples/5-ingest-postgres/ingest.py
+++ b/examples/5-ingest-postgres/ingest.py
@@ -13,15 +13,17 @@
 import logging
 from pathlib import Path
 
-import yaml
+from suthing import FileHandle
 
 from graflo.onto import DBFlavor
 from graflo.db import DBType
 from graflo import Caster
 from graflo.db.postgres import (
     PostgresConnection,
-    create_patterns_from_postgres,
+)
+from graflo.db.inferencer import (
     infer_schema_from_postgres,
+    create_patterns_from_postgres,
 )
 from graflo.db.postgres.util import load_schema_from_sql_file
 from graflo.db.connection.onto import PostgresConfig, TigergraphConfig
@@ -103,15 +105,13 @@
     postgres_conn, schema_name="public", db_flavor=db_flavor
 )
+
+schema.general.name = "accounting"
 
 # Step 3.5: Dump inferred schema to YAML file
 schema_output_file = Path(__file__).parent / "generated-schema.yaml"
 
 # Convert schema to dict (enums are automatically converted to
strings by BaseDataclass.to_dict()) -schema_dict = schema.to_dict() - -# Write to YAML file -with open(schema_output_file, "w") as f: - yaml.safe_dump(schema_dict, f, default_flow_style=False, sort_keys=False) +FileHandle.dump(schema.to_dict(), schema_output_file) logger.info(f"Inferred schema saved to {schema_output_file}") diff --git a/graflo/db/__init__.py b/graflo/db/__init__.py index c1418be..9940b33 100644 --- a/graflo/db/__init__.py +++ b/graflo/db/__init__.py @@ -32,12 +32,16 @@ from .neo4j.conn import Neo4jConnection from .postgres.conn import PostgresConnection from .tigergraph.conn import TigerGraphConnection +from .inferencer import infer_schema_from_postgres, create_patterns_from_postgres + __all__ = [ "Connection", "ConnectionType", + "create_patterns_from_postgres", "DBType", "DBConfig", + "infer_schema_from_postgres", "ConnectionManager", "ArangoConnection", "FalkordbConnection", diff --git a/graflo/db/inferencer.py b/graflo/db/inferencer.py new file mode 100644 index 0000000..38c4df1 --- /dev/null +++ b/graflo/db/inferencer.py @@ -0,0 +1,206 @@ +from graflo import Schema +from graflo.util.onto import Patterns, TablePattern +from graflo.architecture import Resource +from graflo.db import PostgresConnection +from graflo.db.postgres import PostgresSchemaInferencer, PostgresResourceMapper +from graflo.db.sanitizer import SchemaSanitizer +from graflo.onto import DBFlavor +import logging + +logger = logging.getLogger(__name__) + + +class InferenceManager: + """Inference manager for PostgreSQL sources.""" + + def __init__( + self, + conn: PostgresConnection, + target_db_flavor: DBFlavor = DBFlavor.ARANGO, + ): + """Initialize the PostgreSQL inference manager. + + Args: + conn: PostgresConnection instance + target_db_flavor: Target database flavor for schema sanitization + """ + self.target_db_flavor = target_db_flavor + self.sanitizer = SchemaSanitizer(target_db_flavor) + self.conn = conn + self.inferencer = PostgresSchemaInferencer( + db_flavor=target_db_flavor, conn=conn + ) + self.mapper = PostgresResourceMapper() + + def introspect(self, schema_name: str | None = None): + """Introspect PostgreSQL schema. + + Args: + schema_name: Schema name to introspect + + Returns: + SchemaIntrospectionResult: PostgreSQL schema introspection result + """ + return self.conn.introspect_schema(schema_name=schema_name) + + def infer_schema( + self, introspection_result, schema_name: str | None = None + ) -> Schema: + """Infer graflo Schema from PostgreSQL introspection result. + + Args: + introspection_result: SchemaIntrospectionResult from PostgreSQL + schema_name: Schema name (optional, may be inferred from result) + + Returns: + Schema: Inferred schema with vertices and edges + """ + return self.inferencer.infer_schema( + introspection_result, schema_name=schema_name + ) + + def create_resources( + self, introspection_result, schema: Schema + ) -> list["Resource"]: + """Create Resources from PostgreSQL introspection result. + + Args: + introspection_result: SchemaIntrospectionResult from PostgreSQL + schema: Existing Schema object + + Returns: + list[Resource]: List of Resources for PostgreSQL tables + """ + return self.mapper.map_tables_to_resources( + introspection_result, schema.vertex_config, self.sanitizer + ) + + def infer_complete_schema(self, schema_name: str | None = None) -> Schema: + """Infer a complete Schema from source and sanitize for target. + + This is a convenience method that: + 1. Introspects the source schema + 2. Infers the graflo Schema + 3. 
Sanitizes for the target database flavor + 4. Creates and adds resources + 5. Re-initializes the schema + + Args: + schema_name: Schema name to introspect (source-specific) + + Returns: + Schema: Complete inferred schema with vertices, edges, and resources + """ + # Introspect the schema + introspection_result = self.introspect(schema_name=schema_name) + + # Infer schema + schema = self.infer_schema(introspection_result, schema_name=schema_name) + + # Sanitize for target database flavor + schema = self.sanitizer.sanitize(schema) + + # Create and add resources + resources = self.create_resources(introspection_result, schema) + schema.resources = resources + + # Re-initialize to set up resource mappings + schema.__post_init__() + + return schema + + def create_resources_for_schema( + self, schema: Schema, schema_name: str | None = None + ) -> list["Resource"]: + """Create Resources from source for an existing schema. + + Args: + schema: Existing Schema object + schema_name: Schema name to introspect (source-specific) + + Returns: + list[Resource]: List of Resources for the source + """ + # Introspect the schema + introspection_result = self.introspect(schema_name=schema_name) + + # Create resources + return self.create_resources(introspection_result, schema) + + +def infer_schema_from_postgres( + conn: PostgresConnection, + schema_name: str | None = None, + db_flavor: DBFlavor = DBFlavor.ARANGO, +) -> Schema: + """Convenience function to infer a graflo Schema from PostgreSQL database. + + Args: + conn: PostgresConnection instance + schema_name: Schema name to introspect (defaults to config schema_name or 'public') + db_flavor: Target database flavor (defaults to ARANGO) + + Returns: + Schema: Inferred schema with vertices, edges, and resources + """ + manager = InferenceManager(conn, target_db_flavor=db_flavor) + return manager.infer_complete_schema(schema_name=schema_name) + + +def create_patterns_from_postgres( + conn: PostgresConnection, schema_name: str | None = None +) -> Patterns: + """Create Patterns from PostgreSQL tables. 
+ + Args: + conn: PostgresConnection instance + schema_name: Schema name to introspect + + Returns: + Patterns: Patterns object with TablePattern instances for all tables + """ + + # Introspect the schema + introspection_result = conn.introspect_schema(schema_name=schema_name) + + # Create patterns + patterns = Patterns() + + # Get schema name + effective_schema = schema_name or introspection_result.schema_name + + # Store the connection config + config_key = "default" + patterns.postgres_configs[(config_key, effective_schema)] = conn.config + + # Add patterns for vertex tables + for table_info in introspection_result.vertex_tables: + table_name = table_info.name + table_pattern = TablePattern( + table_name=table_name, + schema_name=effective_schema, + resource_name=table_name, + ) + patterns.table_patterns[table_name] = table_pattern + patterns.postgres_table_configs[table_name] = ( + config_key, + effective_schema, + table_name, + ) + + # Add patterns for edge tables + for table_info in introspection_result.edge_tables: + table_name = table_info.name + table_pattern = TablePattern( + table_name=table_name, + schema_name=effective_schema, + resource_name=table_name, + ) + patterns.table_patterns[table_name] = table_pattern + patterns.postgres_table_configs[table_name] = ( + config_key, + effective_schema, + table_name, + ) + + return patterns diff --git a/graflo/db/postgres/__init__.py b/graflo/db/postgres/__init__.py index 8698b4f..d5f657d 100644 --- a/graflo/db/postgres/__init__.py +++ b/graflo/db/postgres/__init__.py @@ -10,7 +10,7 @@ - PostgresResourceMapper: Maps PostgreSQL tables to graflo Resources Example: - >>> from graflo.db.postgres.heuristics import infer_schema_from_postgres >>> from graflo.db.postgres import PostgresConnection + >>> from graflo.db.inferencer import infer_schema_from_postgres >>> >>> from graflo.db.postgres import PostgresConnection >>> from graflo.db.connection.onto import PostgresConfig >>> config = PostgresConfig.from_docker_env() >>> conn = PostgresConnection(config) @@ -19,11 +19,6 @@ """ from .conn import PostgresConnection -from .heuristics import ( - create_patterns_from_postgres, - create_resources_from_postgres, - infer_schema_from_postgres, -) from .resource_mapping import PostgresResourceMapper from .schema_inference import PostgresSchemaInferencer @@ -31,7 +26,4 @@ "PostgresConnection", "PostgresSchemaInferencer", "PostgresResourceMapper", - "infer_schema_from_postgres", - "create_resources_from_postgres", - "create_patterns_from_postgres", ] diff --git a/graflo/db/postgres/heuristics.py b/graflo/db/postgres/heuristics.py index e735c49..e69de29 100644 --- a/graflo/db/postgres/heuristics.py +++ b/graflo/db/postgres/heuristics.py @@ -1,133 +0,0 @@ -import logging - -from graflo.util.onto import Patterns, TablePattern -from graflo.db.postgres.conn import ( - PostgresConnection, -) -from graflo.db.postgres.resource_mapping import PostgresResourceMapper -from graflo.db.postgres.schema_inference import PostgresSchemaInferencer - -logger = logging.getLogger(__name__) - - -def create_patterns_from_postgres( - conn: PostgresConnection, schema_name: str | None = None -) -> Patterns: - """Create Patterns from PostgreSQL tables. 
- - Args: - conn: PostgresConnection instance - schema_name: Schema name to introspect - - Returns: - Patterns: Patterns object with TablePattern instances for all tables - """ - - # Introspect the schema - introspection_result = conn.introspect_schema(schema_name=schema_name) - - # Create patterns - patterns = Patterns() - - # Get schema name - effective_schema = schema_name or introspection_result.schema_name - - # Store the connection config - config_key = "default" - patterns.postgres_configs[(config_key, effective_schema)] = conn.config - - # Add patterns for vertex tables - for table_info in introspection_result.vertex_tables: - table_name = table_info.name - table_pattern = TablePattern( - table_name=table_name, - schema_name=effective_schema, - resource_name=table_name, - ) - patterns.table_patterns[table_name] = table_pattern - patterns.postgres_table_configs[table_name] = ( - config_key, - effective_schema, - table_name, - ) - - # Add patterns for edge tables - for table_info in introspection_result.edge_tables: - table_name = table_info.name - table_pattern = TablePattern( - table_name=table_name, - schema_name=effective_schema, - resource_name=table_name, - ) - patterns.table_patterns[table_name] = table_pattern - patterns.postgres_table_configs[table_name] = ( - config_key, - effective_schema, - table_name, - ) - - return patterns - - -def create_resources_from_postgres( - conn: PostgresConnection, schema, schema_name: str | None = None -): - """Create Resources from PostgreSQL tables for an existing schema. - - Args: - conn: PostgresConnection instance - schema: Existing Schema object - schema_name: Schema name to introspect - - Returns: - list[Resource]: List of Resources for PostgreSQL tables - """ - # Introspect the schema - introspection_result = conn.introspect_schema(schema_name=schema_name) - - # Map tables to resources - mapper = PostgresResourceMapper() - resources = mapper.map_tables_to_resources( - introspection_result, schema.vertex_config, schema.edge_config - ) - - return resources - - -def infer_schema_from_postgres( - conn: PostgresConnection, schema_name: str | None = None, db_flavor=None -): - """Convenience function to infer a graflo Schema from PostgreSQL database. 
- - Args: - conn: PostgresConnection instance - schema_name: Schema name to introspect (defaults to config schema_name or 'public') - db_flavor: Target database flavor (defaults to ARANGO) - - Returns: - Schema: Inferred schema with vertices, edges, and resources - """ - from graflo.onto import DBFlavor - - if db_flavor is None: - db_flavor = DBFlavor.ARANGO - - # Introspect the schema - introspection_result = conn.introspect_schema(schema_name=schema_name) - - # Infer schema (pass connection for type sampling) - inferencer = PostgresSchemaInferencer(db_flavor=db_flavor, conn=conn) - schema = inferencer.infer_schema(introspection_result, schema_name=schema_name) - - # Create and add resources - mapper = PostgresResourceMapper() - resources = mapper.map_tables_to_resources( - introspection_result, schema.vertex_config, schema.edge_config - ) - - # Update schema with resources - schema.resources = resources - # Re-initialize to set up resource mappings - schema.__post_init__() - - return schema diff --git a/graflo/db/postgres/resource_mapping.py b/graflo/db/postgres/resource_mapping.py index 6c4861d..28ef9b7 100644 --- a/graflo/db/postgres/resource_mapping.py +++ b/graflo/db/postgres/resource_mapping.py @@ -6,10 +6,9 @@ import logging -from graflo.architecture.edge import EdgeConfig from graflo.architecture.resource import Resource from graflo.architecture.vertex import VertexConfig - +from ..sanitizer import SchemaSanitizer from .conn import EdgeTableInfo, SchemaIntrospectionResult from .fuzzy_matcher import FuzzyMatchCache from .inference_utils import ( @@ -224,7 +223,7 @@ def map_tables_to_resources( self, introspection_result: SchemaIntrospectionResult, vertex_config: VertexConfig, - edge_config: EdgeConfig, + sanitizer: SchemaSanitizer, ) -> list[Resource]: """Map all PostgreSQL tables to Resources. 
@@ -234,7 +233,7 @@ def map_tables_to_resources( Args: introspection_result: Result from PostgresConnection.introspect_schema() vertex_config: Inferred vertex configuration - edge_config: Inferred edge configuration + sanitizer: carries mappiings Returns: list[Resource]: List of Resources for all tables @@ -257,6 +256,7 @@ def map_tables_to_resources( edge_tables = introspection_result.edge_tables for edge_table_info in edge_tables: try: + # NB: use sanitizer sanitizer.relation_mappings resource = self.create_edge_resource( edge_table_info, vertex_config, match_cache ) diff --git a/graflo/db/postgres/schema_inference.py b/graflo/db/postgres/schema_inference.py index 73b3ee4..67354cb 100644 --- a/graflo/db/postgres/schema_inference.py +++ b/graflo/db/postgres/schema_inference.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging -from collections import Counter from typing import TYPE_CHECKING @@ -18,12 +17,11 @@ from graflo.onto import DBFlavor from ...architecture.onto_sql import EdgeTableInfo, SchemaIntrospectionResult -from ..util import load_reserved_words, sanitize_attribute_name from .conn import PostgresConnection from .types import PostgresTypeMapper if TYPE_CHECKING: - from graflo.architecture.resource import Resource + pass logger = logging.getLogger(__name__) @@ -49,8 +47,6 @@ def __init__( self.db_flavor = db_flavor self.type_mapper = PostgresTypeMapper() self.conn = conn - # Load reserved words for the target database flavor - self.reserved_words = load_reserved_words(db_flavor) def infer_vertex_config( self, introspection_result: SchemaIntrospectionResult @@ -333,424 +329,6 @@ def infer_edge_config( return EdgeConfig(edges=edges) - def _sanitize_schema_attributes(self, schema: Schema) -> Schema: - """Sanitize attribute names and vertex names in the schema to avoid reserved words. - - This method modifies: - - Field names in vertices and edges - - Vertex names themselves - - Edge source/target/by references to vertices - - Resource apply lists that reference vertices - - The sanitization is deterministic: the same input always produces the same output. 
- - Args: - schema: The schema to sanitize - - Returns: - Schema with sanitized attribute names and vertex names - """ - if not self.reserved_words: - # No reserved words to check, return schema as-is - return schema - - # Track name mappings for attributes (fields/weights) - attribute_mappings: dict[str, str] = {} - # Track name mappings for vertex names (separate from attributes) - - # First pass: Sanitize vertex dbnames - for vertex in schema.vertex_config.vertices: - sanitized_vertex_name = sanitize_attribute_name( - vertex.dbname, self.reserved_words, suffix="_vertex" - ) - if sanitized_vertex_name != vertex.dbname: - logger.debug( - f"Sanitizing vertex name '{vertex.dbname}' -> '{sanitized_vertex_name}'" - ) - vertex.dbname = sanitized_vertex_name - - # Second pass: Sanitize vertex field names - for vertex in schema.vertex_config.vertices: - for field in vertex.fields: - original_name = field.name - if original_name not in attribute_mappings: - sanitized_name = sanitize_attribute_name( - original_name, self.reserved_words - ) - if sanitized_name != original_name: - attribute_mappings[original_name] = sanitized_name - logger.debug( - f"Sanitizing field name '{original_name}' -> '{sanitized_name}' " - f"in vertex '{vertex.name}'" - ) - else: - attribute_mappings[original_name] = original_name - else: - sanitized_name = attribute_mappings[original_name] - - # Update field name if it changed - if sanitized_name != original_name: - field.name = sanitized_name - - # Update index field references if they were sanitized - for index in vertex.indexes: - updated_fields = [] - for field_name in index.fields: - sanitized_field_name = attribute_mappings.get( - field_name, field_name - ) - updated_fields.append(sanitized_field_name) - index.fields = updated_fields - - # Third pass: Normalize edge indexes for TigerGraph - # TigerGraph requires that edges with the same relation have consistent source and target indexes - # 1) group edges by relation - # 2) check that for each group specified by relation the sources have the same index - # and separately the targets have the same index - # 3) if this is not the case, identify the most popular index - # 4) for vertices that don't comply with the chose source/target index, we want to prepare a mapping - # and rename relevant fields indexes - field_index_mappings: dict[ - str, dict[str, str] - ] = {} # vertex_name -> {old_field: new_field} - - if schema.vertex_config.db_flavor == DBFlavor.TIGERGRAPH: - # Group edges by relation - edges_by_relation: dict[str | None, list[Edge]] = {} - for edge in schema.edge_config.edges: - relation = edge.relation - if relation not in edges_by_relation: - edges_by_relation[relation] = [] - edges_by_relation[relation].append(edge) - - # Process each relation group - for relation, relation_edges in edges_by_relation.items(): - if len(relation_edges) <= 1: - # Only one edge with this relation, no normalization needed - continue - - # Collect all vertex/index pairs using a list to capture all occurrences - # This handles cases where a vertex appears multiple times in edges for the same relation - source_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] - target_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] - - for edge in relation_edges: - source_vertex = edge.source - target_vertex = edge.target - - # Get primary index for source vertex - source_index = schema.vertex_config.index(source_vertex) - source_vertex_indexes.append( - (source_vertex, tuple(source_index.fields)) - ) - - # Get primary index for target 
vertex - target_index = schema.vertex_config.index(target_vertex) - target_vertex_indexes.append( - (target_vertex, tuple(target_index.fields)) - ) - - # Normalize source indexes - self._normalize_vertex_indexes( - source_vertex_indexes, - relation, - schema, - field_index_mappings, - "source", - ) - - # Normalize target indexes - self._normalize_vertex_indexes( - target_vertex_indexes, - relation, - schema, - field_index_mappings, - "target", - ) - - # Fourth pass: the field maps from edge/relation normalization should be applied to resources: - # new transforms should be added mapping old index names to those identified in the previous step - if field_index_mappings: - for resource in schema.resources: - self._apply_field_index_mappings_to_resource( - resource, field_index_mappings - ) - - return schema - - def _normalize_vertex_indexes( - self, - vertex_indexes: list[tuple[str, tuple[str, ...]]], - relation: str | None, - schema: Schema, - field_index_mappings: dict[str, dict[str, str]], - role: str, # "source" or "target" for logging - ) -> None: - """Normalize vertex indexes to use the most popular index pattern. - - For vertices that don't match the most popular index, this method: - 1. Creates field mappings (old_field -> new_field) - 2. Updates vertex indexes to match the most popular pattern - 3. Adds new fields to vertices if needed - 4. Removes old fields that are being replaced - - Args: - vertex_indexes: List of (vertex_name, index_fields_tuple) pairs - relation: Relation name for logging - schema: Schema to update - field_index_mappings: Dictionary to update with field mappings - role: "source" or "target" for logging purposes - """ - if not vertex_indexes: - return - - # Extract unique vertex/index pairs (a vertex might appear multiple times) - vertex_index_dict: dict[str, tuple[str, ...]] = {} - for vertex_name, index_fields in vertex_indexes: - # Only store first occurrence - we'll normalize all vertices together - if vertex_name not in vertex_index_dict: - vertex_index_dict[vertex_name] = index_fields - - # Check if all indexes are consistent - indexes_list = list(vertex_index_dict.values()) - indexes_set = set(indexes_list) - indexes_consistent = len(indexes_set) == 1 - - if indexes_consistent: - # All indexes are the same, no normalization needed - return - - # Find most popular index - index_counter = Counter(indexes_list) - most_popular_index = index_counter.most_common(1)[0][0] - - # Normalize vertices that don't match - for vertex_name, index_fields in vertex_index_dict.items(): - if index_fields == most_popular_index: - continue - - # Initialize mappings for this vertex if needed - if vertex_name not in field_index_mappings: - field_index_mappings[vertex_name] = {} - - # Map old fields to new fields - old_fields = list(index_fields) - new_fields = list(most_popular_index) - - # Create field-to-field mapping - # If lengths match, map positionally; otherwise map first field to first field - if len(old_fields) == len(new_fields): - for old_field, new_field in zip(old_fields, new_fields): - if old_field != new_field: - # Update existing mapping if it exists, otherwise create new one - field_index_mappings[vertex_name][old_field] = new_field - else: - # If lengths don't match, map the first field - if old_fields and new_fields: - if old_fields[0] != new_fields[0]: - field_index_mappings[vertex_name][old_fields[0]] = new_fields[0] - - # Update vertex index and fields - vertex = schema.vertex_config[vertex_name] - existing_field_names = {f.name for f in vertex.fields} - - 
# Add new fields that don't exist - for new_field in most_popular_index: - if new_field not in existing_field_names: - vertex.fields.append(Field(name=new_field, type=None)) - existing_field_names.add(new_field) - - # Remove old fields that are being replaced (not in new index) - fields_to_remove = [ - f - for f in vertex.fields - if f.name in old_fields and f.name not in new_fields - ] - for field_to_remove in fields_to_remove: - vertex.fields.remove(field_to_remove) - - # Update vertex index to match the most popular one - vertex.indexes[0].fields = list(most_popular_index) - - logger.debug( - f"Normalizing {role} index for vertex '{vertex_name}' in relation '{relation}': " - f"{old_fields} -> {new_fields}" - ) - - def _apply_field_index_mappings_to_resource( - self, resource: Resource, field_index_mappings: dict[str, dict[str, str]] - ) -> None: - """Apply field index mappings to TransformActor instances in a resource. - - For vertices that had their indexes normalized, this method updates TransformActor - instances to map old field names to new field names in their Transform.map attribute. - Only updates TransformActors where the vertex is confirmed to be created at that level - (via VertexActor). - - Args: - resource: The resource to update - field_index_mappings: Dictionary mapping vertex names to field mappings - (old_field -> new_field) - """ - from graflo.architecture.actor import ( - ActorWrapper, - DescendActor, - TransformActor, - VertexActor, - ) - - def collect_vertices_at_level(wrappers: list[ActorWrapper]) -> set[str]: - """Collect vertices created by VertexActor instances at the current level only. - - Does not recurse into nested structures - only collects vertices from - the immediate level. - - Args: - wrappers: List of ActorWrapper instances - - Returns: - set[str]: Set of vertex names created at this level - """ - vertices = set() - for wrapper in wrappers: - if isinstance(wrapper.actor, VertexActor): - vertices.add(wrapper.actor.name) - return vertices - - def update_transform_actor_maps( - wrapper: ActorWrapper, parent_vertices: set[str] | None = None - ) -> set[str]: - """Recursively update TransformActor instances with field index mappings. - - Args: - wrapper: ActorWrapper instance to process - parent_vertices: Set of vertices available from parent levels (for nested structures) - - Returns: - set[str]: Set of all vertices available at this level (including parent) - """ - if parent_vertices is None: - parent_vertices = set() - - # Collect vertices created at this level - current_level_vertices = set() - if isinstance(wrapper.actor, VertexActor): - current_level_vertices.add(wrapper.actor.name) - - # All available vertices = current level + parent levels - all_available_vertices = current_level_vertices | parent_vertices - - # Process TransformActor if present - if isinstance(wrapper.actor, TransformActor): - transform_actor: TransformActor = wrapper.actor - - def apply_mappings_to_transform( - mappings: dict[str, str], - vertex_name: str, - actor: TransformActor, - ) -> None: - """Apply field mappings to TransformActor's transform.map attribute. 
- - Args: - mappings: Dictionary mapping old field names to new field names - vertex_name: Name of the vertex these mappings belong to (for logging) - actor: The TransformActor instance to update - """ - transform = actor.t - if transform.map: - # Update existing map: replace values and keys that match old field names - # First, update values - for map_key, map_value in transform.map.items(): - if isinstance(map_value, str) and map_value in mappings: - transform.map[map_key] = mappings[map_value] - - # if the terminal attr not in the map - add it - for k, v in mappings.items(): - if v not in transform.map.values(): - transform.map[k] = v - else: - # Create new map with all mappings - transform.map = mappings.copy() - - # Update Transform object IO to reflect map edits - actor.t._init_io_from_map(force_init=True) - - logger.debug( - f"Updated TransformActor map in resource '{resource.resource_name}' " - f"for vertex '{vertex_name}': {mappings}" - ) - - target_vertex = transform_actor.vertex - - if isinstance(target_vertex, str): - # TransformActor has explicit target_vertex - if ( - target_vertex in field_index_mappings - and target_vertex in all_available_vertices - ): - mappings = field_index_mappings[target_vertex] - if mappings: - apply_mappings_to_transform( - mappings, target_vertex, transform_actor - ) - else: - logger.debug( - f"Skipping TransformActor for vertex '{target_vertex}' " - f"in resource '{resource.resource_name}': no mappings needed" - ) - else: - logger.debug( - f"Skipping TransformActor for vertex '{target_vertex}' " - f"in resource '{resource.resource_name}': vertex not created at this level" - ) - else: - # TransformActor has no target_vertex - # Apply mappings from all available vertices (parent and current level) - # since transformed fields will be attributed to those vertices - applied_any = False - for vertex in all_available_vertices: - if vertex in field_index_mappings: - mappings = field_index_mappings[vertex] - if mappings: - apply_mappings_to_transform( - mappings, vertex, transform_actor - ) - applied_any = True - - if not applied_any: - logger.debug( - f"Skipping TransformActor without target_vertex " - f"in resource '{resource.resource_name}': " - f"no mappings found for available vertices {all_available_vertices}" - ) - - # Recursively process nested structures (DescendActor) - if isinstance(wrapper.actor, DescendActor): - # Collect vertices from all descendants at this level - descendant_vertices = collect_vertices_at_level( - wrapper.actor.descendants - ) - all_available_vertices |= descendant_vertices - - # Recursively process each descendant - for descendant_wrapper in wrapper.actor.descendants: - nested_vertices = update_transform_actor_maps( - descendant_wrapper, parent_vertices=all_available_vertices - ) - # Merge nested vertices into available vertices - all_available_vertices |= nested_vertices - - return all_available_vertices - - # Process the root ActorWrapper if it exists - if hasattr(resource, "root") and resource.root is not None: - update_transform_actor_maps(resource.root) - else: - logger.warning( - f"Resource '{resource.resource_name}' does not have a root ActorWrapper. " - f"Skipping field index mapping updates." 
- ) - def infer_schema( self, introspection_result: SchemaIntrospectionResult, @@ -790,9 +368,6 @@ def infer_schema( resources=[], # Resources will be created separately ) - # Sanitize attribute names to avoid reserved words - schema = self._sanitize_schema_attributes(schema) - logger.info( f"Successfully inferred schema '{schema_name}' with " f"{len(vertex_config.vertices)} vertices and " diff --git a/graflo/db/postgres/util.py b/graflo/db/postgres/util.py index 5c99e46..9e66f11 100644 --- a/graflo/db/postgres/util.py +++ b/graflo/db/postgres/util.py @@ -2,7 +2,9 @@ from graflo.db import PostgresConnection from graflo.db.connection.onto import PostgresConfig -from graflo.db.postgres.heuristics import logger +import logging + +logger = logging.getLogger(__name__) def load_schema_from_sql_file( diff --git a/graflo/db/sanitizer.py b/graflo/db/sanitizer.py new file mode 100644 index 0000000..18a31d6 --- /dev/null +++ b/graflo/db/sanitizer.py @@ -0,0 +1,509 @@ +"""Schema sanitization for PostgreSQL schema inference. + +This module provides functionality to sanitize schema attributes to avoid +reserved words and normalize vertex indexes for specific database flavors. +""" + +from __future__ import annotations + +import logging +from collections import Counter +from typing import TYPE_CHECKING + +from graflo.architecture.edge import Edge +from graflo.architecture.schema import Schema +from graflo.architecture.vertex import Field +from graflo.onto import DBFlavor + +from graflo.db.util import load_reserved_words, sanitize_attribute_name + +if TYPE_CHECKING: + from graflo.architecture.resource import Resource + +logger = logging.getLogger(__name__) + + +class SchemaSanitizer: + """Sanitizes schema attributes to avoid reserved words and normalize indexes. + + This class handles: + - Sanitizing vertex names and field names to avoid reserved words + - Normalizing vertex indexes for TigerGraph (ensuring consistent indexes + for edges with the same relation) + - Applying field index mappings to resources + """ + + def __init__(self, db_flavor: DBFlavor): + """Initialize the schema sanitizer. + + Args: + db_flavor: Target database flavor to load reserved words for + """ + self.db_flavor = db_flavor + self.reserved_words = load_reserved_words(db_flavor) + self.attribute_mappings: dict[str, str] = {} + self.vertex_mappings: dict[str, str] = {} + # Track relation name mappings + self.relation_mappings: dict[str, str] = {} + + def sanitize(self, schema: Schema) -> Schema: + """Sanitize attribute names and vertex names in the schema to avoid reserved words. + + This method modifies: + - Field names in vertices and edges + - Vertex names themselves + - Edge source/target/by references to vertices + - Resource apply lists that reference vertices + + The sanitization is deterministic: the same input always produces the same output. 
+ + Args: + schema: The schema to sanitize + + Returns: + Schema with sanitized attribute names and vertex names + """ + if not self.reserved_words: + # No reserved words to check, return schema as-is + return schema + + # First pass: Sanitize vertex dbnames + for vertex in schema.vertex_config.vertices: + sanitized_vertex_name = sanitize_attribute_name( + vertex.dbname, self.reserved_words, suffix="_vertex" + ) + if sanitized_vertex_name != vertex.dbname: + logger.debug( + f"Sanitizing vertex name '{vertex.dbname}' -> '{sanitized_vertex_name}'" + ) + self.vertex_mappings[vertex.dbname] = sanitized_vertex_name + vertex.dbname = sanitized_vertex_name + + # Second pass: Sanitize vertex field names + for vertex in schema.vertex_config.vertices: + for field in vertex.fields: + original_name = field.name + if original_name not in self.attribute_mappings: + sanitized_name = sanitize_attribute_name( + original_name, self.reserved_words + ) + if sanitized_name != original_name: + self.attribute_mappings[original_name] = sanitized_name + logger.debug( + f"Sanitizing field name '{original_name}' -> '{sanitized_name}' " + f"in vertex '{vertex.name}'" + ) + else: + self.attribute_mappings[original_name] = original_name + else: + sanitized_name = self.attribute_mappings[original_name] + + # Update field name if it changed + if sanitized_name != original_name: + field.name = sanitized_name + + # Update index field references if they were sanitized + for index in vertex.indexes: + updated_fields = [] + for field_name in index.fields: + sanitized_field_name = self.attribute_mappings.get( + field_name, field_name + ) + updated_fields.append(sanitized_field_name) + index.fields = updated_fields + + vertex_names = {vertex.dbname for vertex in schema.vertex_config.vertices} + + for edge in schema.edge_config.edges: + if edge.relation is None: + continue + + original_relation = edge.relation + new_relation_name = original_relation + + # First, sanitize for reserved words + sanitized_relation = sanitize_attribute_name( + original_relation, self.reserved_words, suffix="_relation" + ) + if sanitized_relation != original_relation: + new_relation_name = sanitized_relation + logger.debug( + f"Sanitizing relation name '{original_relation}' -> '{sanitized_relation}' " + f"to avoid reserved word" + ) + + # Then, check for collision with vertex names + if new_relation_name in vertex_names: + # Collision detected - rename relation + new_relation_name = f"{new_relation_name}_relation" + + # Ensure the new name doesn't collide either + counter = 1 + while new_relation_name in vertex_names: + new_relation_name = f"{edge.relation}_relation{counter}" + counter += 1 + + logger.debug( + f"Renaming relation '{sanitized_relation if sanitized_relation != original_relation else original_relation}' " + f"to '{new_relation_name}' to avoid collision with vertex name" + ) + + # Update the edge relation if it changed + if new_relation_name != original_relation: + self.relation_mappings[original_relation] = new_relation_name + edge.relation = new_relation_name + + if self.relation_mappings: + logger.info( + f"Renamed {len(self.relation_mappings)} relation(s) due to reserved words or vertex collisions: " + f"{self.relation_mappings}" + ) + + # Third pass: Normalize edge indexes for TigerGraph + # TigerGraph requires that edges with the same relation have consistent source and target indexes + # 1) group edges by relation + # 2) check that for each group specified by relation the sources have the same index + # and separately the targets 
have the same index + # 3) if this is not the case, identify the most popular index + # 4) for vertices that don't comply with the chose source/target index, we want to prepare a mapping + # and rename relevant fields indexes + field_index_mappings: dict[ + str, dict[str, str] + ] = {} # vertex_name -> {old_field: new_field} + + if schema.vertex_config.db_flavor == DBFlavor.TIGERGRAPH: + # Group edges by relation + edges_by_relation: dict[str | None, list[Edge]] = {} + for edge in schema.edge_config.edges: + relation = edge.relation + if relation not in edges_by_relation: + edges_by_relation[relation] = [] + edges_by_relation[relation].append(edge) + + # Process each relation group + for relation, relation_edges in edges_by_relation.items(): + if len(relation_edges) <= 1: + # Only one edge with this relation, no normalization needed + continue + + # Collect all vertex/index pairs using a list to capture all occurrences + # This handles cases where a vertex appears multiple times in edges for the same relation + source_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] + target_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] + + for edge in relation_edges: + source_vertex = edge.source + target_vertex = edge.target + + # Get primary index for source vertex + source_index = schema.vertex_config.index(source_vertex) + source_vertex_indexes.append( + (source_vertex, tuple(source_index.fields)) + ) + + # Get primary index for target vertex + target_index = schema.vertex_config.index(target_vertex) + target_vertex_indexes.append( + (target_vertex, tuple(target_index.fields)) + ) + + # Normalize source indexes + self._normalize_vertex_indexes( + source_vertex_indexes, + relation, + schema, + field_index_mappings, + "source", + ) + + # Normalize target indexes + self._normalize_vertex_indexes( + target_vertex_indexes, + relation, + schema, + field_index_mappings, + "target", + ) + + # Fourth pass: the field maps from edge/relation normalization should be applied to resources: + # new transforms should be added mapping old index names to those identified in the previous step + if field_index_mappings: + for resource in schema.resources: + self._apply_field_index_mappings_to_resource( + resource, field_index_mappings + ) + + return schema + + def _normalize_vertex_indexes( + self, + vertex_indexes: list[tuple[str, tuple[str, ...]]], + relation: str | None, + schema: Schema, + field_index_mappings: dict[str, dict[str, str]], + role: str, # "source" or "target" for logging + ) -> None: + """Normalize vertex indexes to use the most popular index pattern. + + For vertices that don't match the most popular index, this method: + 1. Creates field mappings (old_field -> new_field) + 2. Updates vertex indexes to match the most popular pattern + 3. Adds new fields to vertices if needed + 4. 
Removes old fields that are being replaced + + Args: + vertex_indexes: List of (vertex_name, index_fields_tuple) pairs + relation: Relation name for logging + schema: Schema to update + field_index_mappings: Dictionary to update with field mappings + role: "source" or "target" for logging purposes + """ + if not vertex_indexes: + return + + # Extract unique vertex/index pairs (a vertex might appear multiple times) + vertex_index_dict: dict[str, tuple[str, ...]] = {} + for vertex_name, index_fields in vertex_indexes: + # Only store first occurrence - we'll normalize all vertices together + if vertex_name not in vertex_index_dict: + vertex_index_dict[vertex_name] = index_fields + + # Check if all indexes are consistent + indexes_list = list(vertex_index_dict.values()) + indexes_set = set(indexes_list) + indexes_consistent = len(indexes_set) == 1 + + if indexes_consistent: + # All indexes are the same, no normalization needed + return + + # Find most popular index + index_counter = Counter(indexes_list) + most_popular_index = index_counter.most_common(1)[0][0] + + # Normalize vertices that don't match + for vertex_name, index_fields in vertex_index_dict.items(): + if index_fields == most_popular_index: + continue + + # Initialize mappings for this vertex if needed + if vertex_name not in field_index_mappings: + field_index_mappings[vertex_name] = {} + + # Map old fields to new fields + old_fields = list(index_fields) + new_fields = list(most_popular_index) + + # Create field-to-field mapping + # If lengths match, map positionally; otherwise map first field to first field + if len(old_fields) == len(new_fields): + for old_field, new_field in zip(old_fields, new_fields): + if old_field != new_field: + # Update existing mapping if it exists, otherwise create new one + field_index_mappings[vertex_name][old_field] = new_field + else: + # If lengths don't match, map the first field + if old_fields and new_fields: + if old_fields[0] != new_fields[0]: + field_index_mappings[vertex_name][old_fields[0]] = new_fields[0] + + # Update vertex index and fields + vertex = schema.vertex_config[vertex_name] + existing_field_names = {f.name for f in vertex.fields} + + # Add new fields that don't exist + for new_field in most_popular_index: + if new_field not in existing_field_names: + vertex.fields.append(Field(name=new_field, type=None)) + existing_field_names.add(new_field) + + # Remove old fields that are being replaced (not in new index) + fields_to_remove = [ + f + for f in vertex.fields + if f.name in old_fields and f.name not in new_fields + ] + for field_to_remove in fields_to_remove: + vertex.fields.remove(field_to_remove) + + # Update vertex index to match the most popular one + vertex.indexes[0].fields = list(most_popular_index) + + logger.debug( + f"Normalizing {role} index for vertex '{vertex_name}' in relation '{relation}': " + f"{old_fields} -> {new_fields}" + ) + + def _apply_field_index_mappings_to_resource( + self, resource: Resource, field_index_mappings: dict[str, dict[str, str]] + ) -> None: + """Apply field index mappings to TransformActor instances in a resource. + + For vertices that had their indexes normalized, this method updates TransformActor + instances to map old field names to new field names in their Transform.map attribute. + Only updates TransformActors where the vertex is confirmed to be created at that level + (via VertexActor). 
+ + Args: + resource: The resource to update + field_index_mappings: Dictionary mapping vertex names to field mappings + (old_field -> new_field) + """ + from graflo.architecture.actor import ( + ActorWrapper, + DescendActor, + TransformActor, + VertexActor, + ) + + def collect_vertices_at_level(wrappers: list[ActorWrapper]) -> set[str]: + """Collect vertices created by VertexActor instances at the current level only. + + Does not recurse into nested structures - only collects vertices from + the immediate level. + + Args: + wrappers: List of ActorWrapper instances + + Returns: + set[str]: Set of vertex names created at this level + """ + vertices = set() + for wrapper in wrappers: + if isinstance(wrapper.actor, VertexActor): + vertices.add(wrapper.actor.name) + return vertices + + def update_transform_actor_maps( + wrapper: ActorWrapper, parent_vertices: set[str] | None = None + ) -> set[str]: + """Recursively update TransformActor instances with field index mappings. + + Args: + wrapper: ActorWrapper instance to process + parent_vertices: Set of vertices available from parent levels (for nested structures) + + Returns: + set[str]: Set of all vertices available at this level (including parent) + """ + if parent_vertices is None: + parent_vertices = set() + + # Collect vertices created at this level + current_level_vertices = set() + if isinstance(wrapper.actor, VertexActor): + current_level_vertices.add(wrapper.actor.name) + + # All available vertices = current level + parent levels + all_available_vertices = current_level_vertices | parent_vertices + + # Process TransformActor if present + if isinstance(wrapper.actor, TransformActor): + transform_actor: TransformActor = wrapper.actor + + def apply_mappings_to_transform( + mappings: dict[str, str], + vertex_name: str, + actor: TransformActor, + ) -> None: + """Apply field mappings to TransformActor's transform.map attribute. 
+ + Args: + mappings: Dictionary mapping old field names to new field names + vertex_name: Name of the vertex these mappings belong to (for logging) + actor: The TransformActor instance to update + """ + transform = actor.t + if transform.map: + # Update existing map: replace values and keys that match old field names + # First, update values + for map_key, map_value in transform.map.items(): + if isinstance(map_value, str) and map_value in mappings: + transform.map[map_key] = mappings[map_value] + + # if the terminal attr not in the map - add it + for k, v in mappings.items(): + if v not in transform.map.values(): + transform.map[k] = v + else: + # Create new map with all mappings + transform.map = mappings.copy() + + # Update Transform object IO to reflect map edits + actor.t._init_io_from_map(force_init=True) + + logger.debug( + f"Updated TransformActor map in resource '{resource.resource_name}' " + f"for vertex '{vertex_name}': {mappings}" + ) + + target_vertex = transform_actor.vertex + + if isinstance(target_vertex, str): + # TransformActor has explicit target_vertex + if ( + target_vertex in field_index_mappings + and target_vertex in all_available_vertices + ): + mappings = field_index_mappings[target_vertex] + if mappings: + apply_mappings_to_transform( + mappings, target_vertex, transform_actor + ) + else: + logger.debug( + f"Skipping TransformActor for vertex '{target_vertex}' " + f"in resource '{resource.resource_name}': no mappings needed" + ) + else: + logger.debug( + f"Skipping TransformActor for vertex '{target_vertex}' " + f"in resource '{resource.resource_name}': vertex not created at this level" + ) + else: + # TransformActor has no target_vertex + # Apply mappings from all available vertices (parent and current level) + # since transformed fields will be attributed to those vertices + applied_any = False + for vertex in all_available_vertices: + if vertex in field_index_mappings: + mappings = field_index_mappings[vertex] + if mappings: + apply_mappings_to_transform( + mappings, vertex, transform_actor + ) + applied_any = True + + if not applied_any: + logger.debug( + f"Skipping TransformActor without target_vertex " + f"in resource '{resource.resource_name}': " + f"no mappings found for available vertices {all_available_vertices}" + ) + + # Recursively process nested structures (DescendActor) + if isinstance(wrapper.actor, DescendActor): + # Collect vertices from all descendants at this level + descendant_vertices = collect_vertices_at_level( + wrapper.actor.descendants + ) + all_available_vertices |= descendant_vertices + + # Recursively process each descendant + for descendant_wrapper in wrapper.actor.descendants: + nested_vertices = update_transform_actor_maps( + descendant_wrapper, parent_vertices=all_available_vertices + ) + # Merge nested vertices into available vertices + all_available_vertices |= nested_vertices + + return all_available_vertices + + # Process the root ActorWrapper if it exists + if hasattr(resource, "root") and resource.root is not None: + update_transform_actor_maps(resource.root) + else: + logger.warning( + f"Resource '{resource.resource_name}' does not have a root ActorWrapper. " + f"Skipping field index mapping updates." 
+ ) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 47e4a7d..c8c1a59 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -2042,7 +2042,7 @@ def _define_schema_local(self, schema: Schema) -> None: # Vertices for vertex in vertex_config.vertices: # Validate vertex name - _validate_tigergraph_schema_name(vertex.name, "vertex") + _validate_tigergraph_schema_name(vertex.dbname, "vertex") stmt = self._get_vertex_add_statement(vertex, vertex_config) schema_change_stmts.append(stmt) diff --git a/test/db/postgres/test_schema_inference.py b/test/db/postgres/test_schema_inference.py index dae14d9..4e05b78 100644 --- a/test/db/postgres/test_schema_inference.py +++ b/test/db/postgres/test_schema_inference.py @@ -7,7 +7,7 @@ - Resource creation """ -from graflo.db.postgres import infer_schema_from_postgres +from graflo.db import infer_schema_from_postgres def test_infer_schema_from_postgres(postgres_conn, load_mock_schema): diff --git a/test/db/tigergraphs/test_reserved_words.py b/test/db/tigergraphs/test_reserved_words.py index 535d6dc..d8a58fc 100644 --- a/test/db/tigergraphs/test_reserved_words.py +++ b/test/db/tigergraphs/test_reserved_words.py @@ -20,9 +20,10 @@ import pytest -from graflo.db.postgres.schema_inference import PostgresSchemaInferencer from graflo.onto import DBFlavor from test.conftest import fetch_schema_obj +from graflo.db.sanitizer import SchemaSanitizer + logger = logging.getLogger(__name__) @@ -43,11 +44,9 @@ def test_vertex_name_sanitization_for_tigergraph(schema_with_reserved_words): """Test that vertex names with reserved words are sanitized for TigerGraph.""" schema = schema_with_reserved_words - # Create inferencer with TigerGraph flavor - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitizer = SchemaSanitizer(DBFlavor.TIGERGRAPH) - # Sanitize the schema - sanitized_schema = inferencer._sanitize_schema_attributes(schema) + sanitized_schema = sanitizer.sanitize(schema) vertex_dbnames = [v.dbname for v in sanitized_schema.vertex_config.vertices] assert "Package_vertex" in vertex_dbnames, ( @@ -62,11 +61,9 @@ def test_edges_sanitization_for_tigergraph(schema_with_incompatible_edges): """Test that vertex names with reserved words are sanitized for TigerGraph.""" schema = schema_with_incompatible_edges - # Create inferencer with TigerGraph flavor - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitizer = SchemaSanitizer(DBFlavor.TIGERGRAPH) - # Sanitize the schema - sanitized_schema = inferencer._sanitize_schema_attributes(schema) + sanitized_schema = sanitizer.sanitize(schema) # sanitized_schema.to_yaml_file( # os.path.join(