diff --git a/graflo/architecture/edge.py b/graflo/architecture/edge.py index c9161a2..c62046c 100644 --- a/graflo/architecture/edge.py +++ b/graflo/architecture/edge.py @@ -190,7 +190,12 @@ class Edge(BaseDataclass): # relation represents Class in neo4j, for arango it becomes a weight relation: str | None = None - # field that contains Class or relation + _relation_dbname: str | None = dataclasses.field( + default=None, + repr=False, + metadata={"dump": False}, + ) + relation_field: str | None = None relation_from_key: bool = False @@ -219,6 +224,14 @@ def __post_init__(self): self._source: str | None = None self._target: str | None = None + @property + def relation_dbname(self) -> str | None: + return self._relation_dbname or self.relation + + @relation_dbname.setter + def relation_dbname(self, value: str | None): + self._relation_dbname = value + def finish_init(self, vertex_config: VertexConfig): """Complete edge initialization with vertex configuration. @@ -251,6 +264,9 @@ def finish_init(self, vertex_config: VertexConfig): # Use default relation name for TigerGraph # TigerGraph requires all edges to have a named type (relation) self.relation = DEFAULT_TIGERGRAPH_RELATION + # Ensure dbname follows logical relation by default + if self.relation_dbname is None: + self.relation_dbname = self.relation # TigerGraph: add relation field to weights if relation_field or relation_from_key is set # This ensures the relation value is included as a typed property in the edge schema diff --git a/graflo/db/postgres/resource_mapping.py b/graflo/db/postgres/resource_mapping.py index 28ef9b7..0363de2 100644 --- a/graflo/db/postgres/resource_mapping.py +++ b/graflo/db/postgres/resource_mapping.py @@ -256,7 +256,6 @@ def map_tables_to_resources( edge_tables = introspection_result.edge_tables for edge_table_info in edge_tables: try: - # NB: use sanitizer sanitizer.relation_mappings resource = self.create_edge_resource( edge_table_info, vertex_config, match_cache ) diff --git a/graflo/db/sanitizer.py b/graflo/db/sanitizer.py index 18a31d6..8f93aa5 100644 --- a/graflo/db/sanitizer.py +++ b/graflo/db/sanitizer.py @@ -22,6 +22,9 @@ logger = logging.getLogger(__name__) +VERTEX_SUFFIX = "vertex" +RELATION_SUFFIX = "relation" + class SchemaSanitizer: """Sanitizes schema attributes to avoid reserved words and normalize indexes. @@ -43,8 +46,6 @@ def __init__(self, db_flavor: DBFlavor): self.reserved_words = load_reserved_words(db_flavor) self.attribute_mappings: dict[str, str] = {} self.vertex_mappings: dict[str, str] = {} - # Track relation name mappings - self.relation_mappings: dict[str, str] = {} def sanitize(self, schema: Schema) -> Schema: """Sanitize attribute names and vertex names in the schema to avoid reserved words. @@ -70,7 +71,7 @@ def sanitize(self, schema: Schema) -> Schema: # First pass: Sanitize vertex dbnames for vertex in schema.vertex_config.vertices: sanitized_vertex_name = sanitize_attribute_name( - vertex.dbname, self.reserved_words, suffix="_vertex" + vertex.dbname, self.reserved_words, suffix=f"_{VERTEX_SUFFIX}" ) if sanitized_vertex_name != vertex.dbname: logger.debug( @@ -115,49 +116,33 @@ def sanitize(self, schema: Schema) -> Schema: vertex_names = {vertex.dbname for vertex in schema.vertex_config.vertices} for edge in schema.edge_config.edges: - if edge.relation is None: + if not edge.relation: continue - original_relation = edge.relation - new_relation_name = original_relation + original = edge.relation_dbname - # First, sanitize for reserved words - sanitized_relation = sanitize_attribute_name( - original_relation, self.reserved_words, suffix="_relation" + # First pass: sanitize against reserved words + sanitized = sanitize_attribute_name( + original, + self.reserved_words, + suffix=f"_{RELATION_SUFFIX}", ) - if sanitized_relation != original_relation: - new_relation_name = sanitized_relation - logger.debug( - f"Sanitizing relation name '{original_relation}' -> '{sanitized_relation}' " - f"to avoid reserved word" - ) - # Then, check for collision with vertex names - if new_relation_name in vertex_names: - # Collision detected - rename relation - new_relation_name = f"{new_relation_name}_relation" - - # Ensure the new name doesn't collide either + # Second pass: avoid collision with vertex names + if sanitized in vertex_names: + base = f"{sanitized}_{RELATION_SUFFIX}" + candidate = base counter = 1 - while new_relation_name in vertex_names: - new_relation_name = f"{edge.relation}_relation{counter}" - counter += 1 - logger.debug( - f"Renaming relation '{sanitized_relation if sanitized_relation != original_relation else original_relation}' " - f"to '{new_relation_name}' to avoid collision with vertex name" - ) + while candidate in vertex_names: + candidate = f"{base}_{counter}" + counter += 1 - # Update the edge relation if it changed - if new_relation_name != original_relation: - self.relation_mappings[original_relation] = new_relation_name - edge.relation = new_relation_name + sanitized = candidate - if self.relation_mappings: - logger.info( - f"Renamed {len(self.relation_mappings)} relation(s) due to reserved words or vertex collisions: " - f"{self.relation_mappings}" - ) + # Update only if needed + if sanitized != original: + edge.relation_dbname = sanitized # Third pass: Normalize edge indexes for TigerGraph # TigerGraph requires that edges with the same relation have consistent source and target indexes @@ -175,7 +160,12 @@ def sanitize(self, schema: Schema) -> Schema: # Group edges by relation edges_by_relation: dict[str | None, list[Edge]] = {} for edge in schema.edge_config.edges: - relation = edge.relation + # Use sanitized dbname when grouping by relation for TigerGraph + relation = ( + edge.relation_dbname + if edge.relation_dbname is not None + else edge.relation + ) if relation not in edges_by_relation: edges_by_relation[relation] = [] edges_by_relation[relation].append(edge) diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index c8c1a59..6101d5b 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -1881,6 +1881,9 @@ def _get_edge_add_statement(self, edge: Edge) -> str: for field in edge.weights.direct: field_types[field.name] = self._get_tigergraph_type(field.type) + # Use sanitized dbname for schema names when available + relation_db = edge.relation_dbname + # Build FROM/TO line with discriminator from_to_parts = [ f" FROM {edge._source}", @@ -1897,11 +1900,11 @@ def _get_edge_add_statement(self, edge: Edge) -> str: discriminator_str = f"DISCRIMINATOR({', '.join(discriminator_parts)})" from_to_parts.append(f" {discriminator_str}") logger.info( - f"Added discriminator for edge {edge.relation}: {', '.join(discriminator_parts)}" + f"Added discriminator for edge {relation_db}: {', '.join(discriminator_parts)}" ) else: logger.debug( - f"No indexed fields found for edge {edge.relation}. " + f"No indexed fields found for edge {relation_db}. " f"Indexes: {[idx.fields for idx in edge.indexes]}, " f"relation_field: {edge.relation_field}" ) @@ -1914,7 +1917,7 @@ def _get_edge_add_statement(self, edge: Edge) -> str: # Has attributes - add comma after FROM/TO line (which may include discriminator) # edge_attrs already has proper indentation, so we just need to add it after a comma return ( - f"ADD DIRECTED EDGE {edge.relation} (\n" + f"ADD DIRECTED EDGE {relation_db} (\n" f"{from_to_line},\n" f"{edge_attrs}\n" f" )" @@ -1922,7 +1925,7 @@ def _get_edge_add_statement(self, edge: Edge) -> str: else: # No attributes - FROM/TO line (which may include discriminator) is the last thing # No trailing comma needed - return f"ADD DIRECTED EDGE {edge.relation} (\n{from_to_line}\n )" + return f"ADD DIRECTED EDGE {relation_db} (\n{from_to_line}\n )" def _get_edge_group_create_statement(self, edges: list[Edge]) -> str: """Generate ADD DIRECTED EDGE statement for a group of edges with the same relation. @@ -1942,7 +1945,7 @@ def _get_edge_group_create_statement(self, edges: list[Edge]) -> str: # Use the first edge to determine attributes and discriminator # (all edges of the same relation should have the same schema) first_edge = edges[0] - relation = first_edge.relation + relation = first_edge.relation_dbname # Collect indexed fields for discriminator (same logic as _get_edge_add_statement) indexed_field_names = set() @@ -2051,13 +2054,15 @@ def _define_schema_local(self, schema: Schema) -> None: edges_to_create = list(edge_config.edges_list(include_aux=True)) for edge in edges_to_create: edge.finish_init(vertex_config) - # Validate edge name - _validate_tigergraph_schema_name(edge.relation, "edge") + # Validate edge name using sanitized dbname when available + edge_dbname = edge.relation_dbname + _validate_tigergraph_schema_name(edge_dbname, "edge") # Group edges by relation edges_by_relation: dict[str, list[Edge]] = defaultdict(list) for edge in edges_to_create: - edges_by_relation[edge.relation].append(edge) + key = edge.relation_dbname + edges_by_relation[key].append(edge) # Create one statement per relation with all FROM/TO pairs for relation, edge_group in edges_by_relation.items(): @@ -2580,8 +2585,9 @@ def define_edge_indices(self, edges: list[Edge]): """ for edge in edges: if edge.indexes: + edge_db = edge.relation_dbname logger.info( - f"Skipping {len(edge.indexes)} index(es) on edge '{edge.relation}': " + f"Skipping {len(edge.indexes)} index(es) on edge '{edge_db}': " f"TigerGraph does not support indexes on edge attributes. " f"Only vertex indexes are supported." ) diff --git a/test/config/schema/tigergraph-sanitize-edges.yaml b/test/config/schema/tigergraph-sanitize-edges.yaml index b82ac7c..e01407c 100644 --- a/test/config/schema/tigergraph-sanitize-edges.yaml +++ b/test/config/schema/tigergraph-sanitize-edges.yaml @@ -50,4 +50,9 @@ edge_config: - source: container target: box relation: contains - + - source: container + target: container + relation: package + - source: container + target: container + relation: box diff --git a/test/db/tigergraphs/test_reserved_words.py b/test/db/tigergraphs/test_reserved_words.py index d8a58fc..da3b3f6 100644 --- a/test/db/tigergraphs/test_reserved_words.py +++ b/test/db/tigergraphs/test_reserved_words.py @@ -78,3 +78,5 @@ def test_edges_sanitization_for_tigergraph(schema_with_incompatible_edges): assert sanitized_schema.vertex_config.vertices[-1].fields[0].name == "id" assert sanitized_schema.vertex_config.vertices[-1].indexes[0].fields[0] == "id" + assert sanitized_schema.edge_config.edges[-2].relation_dbname == "package_relation" + assert sanitized_schema.edge_config.edges[-1].relation_dbname == "box_relation"