Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion graflo/architecture/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ class Edge(BaseDataclass):

# relation represents Class in neo4j, for arango it becomes a weight
relation: str | None = None
# field that contains Class or relation
_relation_dbname: str | None = dataclasses.field(
default=None,
repr=False,
metadata={"dump": False},
)

relation_field: str | None = None
relation_from_key: bool = False

Expand Down Expand Up @@ -219,6 +224,14 @@ def __post_init__(self):
self._source: str | None = None
self._target: str | None = None

@property
def relation_dbname(self) -> str | None:
return self._relation_dbname or self.relation

@relation_dbname.setter
def relation_dbname(self, value: str | None):
self._relation_dbname = value

def finish_init(self, vertex_config: VertexConfig):
"""Complete edge initialization with vertex configuration.

Expand Down Expand Up @@ -251,6 +264,9 @@ def finish_init(self, vertex_config: VertexConfig):
# Use default relation name for TigerGraph
# TigerGraph requires all edges to have a named type (relation)
self.relation = DEFAULT_TIGERGRAPH_RELATION
# Ensure dbname follows logical relation by default
if self.relation_dbname is None:
self.relation_dbname = self.relation

# TigerGraph: add relation field to weights if relation_field or relation_from_key is set
# This ensures the relation value is included as a typed property in the edge schema
Expand Down
1 change: 0 additions & 1 deletion graflo/db/postgres/resource_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def map_tables_to_resources(
edge_tables = introspection_result.edge_tables
for edge_table_info in edge_tables:
try:
# NB: use sanitizer sanitizer.relation_mappings
resource = self.create_edge_resource(
edge_table_info, vertex_config, match_cache
)
Expand Down
66 changes: 28 additions & 38 deletions graflo/db/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

logger = logging.getLogger(__name__)

VERTEX_SUFFIX = "vertex"
RELATION_SUFFIX = "relation"


class SchemaSanitizer:
"""Sanitizes schema attributes to avoid reserved words and normalize indexes.
Expand All @@ -43,8 +46,6 @@ def __init__(self, db_flavor: DBFlavor):
self.reserved_words = load_reserved_words(db_flavor)
self.attribute_mappings: dict[str, str] = {}
self.vertex_mappings: dict[str, str] = {}
# Track relation name mappings
self.relation_mappings: dict[str, str] = {}

def sanitize(self, schema: Schema) -> Schema:
"""Sanitize attribute names and vertex names in the schema to avoid reserved words.
Expand All @@ -70,7 +71,7 @@ def sanitize(self, schema: Schema) -> Schema:
# First pass: Sanitize vertex dbnames
for vertex in schema.vertex_config.vertices:
sanitized_vertex_name = sanitize_attribute_name(
vertex.dbname, self.reserved_words, suffix="_vertex"
vertex.dbname, self.reserved_words, suffix=f"_{VERTEX_SUFFIX}"
)
if sanitized_vertex_name != vertex.dbname:
logger.debug(
Expand Down Expand Up @@ -115,49 +116,33 @@ def sanitize(self, schema: Schema) -> Schema:
vertex_names = {vertex.dbname for vertex in schema.vertex_config.vertices}

for edge in schema.edge_config.edges:
if edge.relation is None:
if not edge.relation:
continue

original_relation = edge.relation
new_relation_name = original_relation
original = edge.relation_dbname

# First, sanitize for reserved words
sanitized_relation = sanitize_attribute_name(
original_relation, self.reserved_words, suffix="_relation"
# First pass: sanitize against reserved words
sanitized = sanitize_attribute_name(
original,
self.reserved_words,
suffix=f"_{RELATION_SUFFIX}",
)
if sanitized_relation != original_relation:
new_relation_name = sanitized_relation
logger.debug(
f"Sanitizing relation name '{original_relation}' -> '{sanitized_relation}' "
f"to avoid reserved word"
)

# Then, check for collision with vertex names
if new_relation_name in vertex_names:
# Collision detected - rename relation
new_relation_name = f"{new_relation_name}_relation"

# Ensure the new name doesn't collide either
# Second pass: avoid collision with vertex names
if sanitized in vertex_names:
base = f"{sanitized}_{RELATION_SUFFIX}"
candidate = base
counter = 1
while new_relation_name in vertex_names:
new_relation_name = f"{edge.relation}_relation{counter}"
counter += 1

logger.debug(
f"Renaming relation '{sanitized_relation if sanitized_relation != original_relation else original_relation}' "
f"to '{new_relation_name}' to avoid collision with vertex name"
)
while candidate in vertex_names:
candidate = f"{base}_{counter}"
counter += 1

# Update the edge relation if it changed
if new_relation_name != original_relation:
self.relation_mappings[original_relation] = new_relation_name
edge.relation = new_relation_name
sanitized = candidate

if self.relation_mappings:
logger.info(
f"Renamed {len(self.relation_mappings)} relation(s) due to reserved words or vertex collisions: "
f"{self.relation_mappings}"
)
# Update only if needed
if sanitized != original:
edge.relation_dbname = sanitized

# Third pass: Normalize edge indexes for TigerGraph
# TigerGraph requires that edges with the same relation have consistent source and target indexes
Expand All @@ -175,7 +160,12 @@ def sanitize(self, schema: Schema) -> Schema:
# Group edges by relation
edges_by_relation: dict[str | None, list[Edge]] = {}
for edge in schema.edge_config.edges:
relation = edge.relation
# Use sanitized dbname when grouping by relation for TigerGraph
relation = (
edge.relation_dbname
if edge.relation_dbname is not None
else edge.relation
)
if relation not in edges_by_relation:
edges_by_relation[relation] = []
edges_by_relation[relation].append(edge)
Expand Down
24 changes: 15 additions & 9 deletions graflo/db/tigergraph/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,9 @@ def _get_edge_add_statement(self, edge: Edge) -> str:
for field in edge.weights.direct:
field_types[field.name] = self._get_tigergraph_type(field.type)

# Use sanitized dbname for schema names when available
relation_db = edge.relation_dbname

# Build FROM/TO line with discriminator
from_to_parts = [
f" FROM {edge._source}",
Expand All @@ -1897,11 +1900,11 @@ def _get_edge_add_statement(self, edge: Edge) -> str:
discriminator_str = f"DISCRIMINATOR({', '.join(discriminator_parts)})"
from_to_parts.append(f" {discriminator_str}")
logger.info(
f"Added discriminator for edge {edge.relation}: {', '.join(discriminator_parts)}"
f"Added discriminator for edge {relation_db}: {', '.join(discriminator_parts)}"
)
else:
logger.debug(
f"No indexed fields found for edge {edge.relation}. "
f"No indexed fields found for edge {relation_db}. "
f"Indexes: {[idx.fields for idx in edge.indexes]}, "
f"relation_field: {edge.relation_field}"
)
Expand All @@ -1914,15 +1917,15 @@ def _get_edge_add_statement(self, edge: Edge) -> str:
# Has attributes - add comma after FROM/TO line (which may include discriminator)
# edge_attrs already has proper indentation, so we just need to add it after a comma
return (
f"ADD DIRECTED EDGE {edge.relation} (\n"
f"ADD DIRECTED EDGE {relation_db} (\n"
f"{from_to_line},\n"
f"{edge_attrs}\n"
f" )"
)
else:
# No attributes - FROM/TO line (which may include discriminator) is the last thing
# No trailing comma needed
return f"ADD DIRECTED EDGE {edge.relation} (\n{from_to_line}\n )"
return f"ADD DIRECTED EDGE {relation_db} (\n{from_to_line}\n )"

def _get_edge_group_create_statement(self, edges: list[Edge]) -> str:
"""Generate ADD DIRECTED EDGE statement for a group of edges with the same relation.
Expand All @@ -1942,7 +1945,7 @@ def _get_edge_group_create_statement(self, edges: list[Edge]) -> str:
# Use the first edge to determine attributes and discriminator
# (all edges of the same relation should have the same schema)
first_edge = edges[0]
relation = first_edge.relation
relation = first_edge.relation_dbname

# Collect indexed fields for discriminator (same logic as _get_edge_add_statement)
indexed_field_names = set()
Expand Down Expand Up @@ -2051,13 +2054,15 @@ def _define_schema_local(self, schema: Schema) -> None:
edges_to_create = list(edge_config.edges_list(include_aux=True))
for edge in edges_to_create:
edge.finish_init(vertex_config)
# Validate edge name
_validate_tigergraph_schema_name(edge.relation, "edge")
# Validate edge name using sanitized dbname when available
edge_dbname = edge.relation_dbname
_validate_tigergraph_schema_name(edge_dbname, "edge")

# Group edges by relation
edges_by_relation: dict[str, list[Edge]] = defaultdict(list)
for edge in edges_to_create:
edges_by_relation[edge.relation].append(edge)
key = edge.relation_dbname
edges_by_relation[key].append(edge)

# Create one statement per relation with all FROM/TO pairs
for relation, edge_group in edges_by_relation.items():
Expand Down Expand Up @@ -2580,8 +2585,9 @@ def define_edge_indices(self, edges: list[Edge]):
"""
for edge in edges:
if edge.indexes:
edge_db = edge.relation_dbname
logger.info(
f"Skipping {len(edge.indexes)} index(es) on edge '{edge.relation}': "
f"Skipping {len(edge.indexes)} index(es) on edge '{edge_db}': "
f"TigerGraph does not support indexes on edge attributes. "
f"Only vertex indexes are supported."
)
Expand Down
7 changes: 6 additions & 1 deletion test/config/schema/tigergraph-sanitize-edges.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ edge_config:
- source: container
target: box
relation: contains

- source: container
target: container
relation: package
- source: container
target: container
relation: box
2 changes: 2 additions & 0 deletions test/db/tigergraphs/test_reserved_words.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@ def test_edges_sanitization_for_tigergraph(schema_with_incompatible_edges):

assert sanitized_schema.vertex_config.vertices[-1].fields[0].name == "id"
assert sanitized_schema.vertex_config.vertices[-1].indexes[0].fields[0] == "id"
assert sanitized_schema.edge_config.edges[-2].relation_dbname == "package_relation"
assert sanitized_schema.edge_config.edges[-1].relation_dbname == "box_relation"