diff --git a/graflo/architecture/transform.py b/graflo/architecture/transform.py index a10c3658..015f21f8 100644 --- a/graflo/architecture/transform.py +++ b/graflo/architecture/transform.py @@ -161,32 +161,68 @@ def __post_init__(self): ValueError: If transform configuration is invalid """ super().__post_init__() - self.functional_transform = False - if self._foo is not None: - self.functional_transform = True - - self.input = self._tuple_it(self.input) + self.functional_transform = self._foo is not None + # Normalize containers self.fields = self._tuple_it(self.fields) + self.input = self._tuple_it(self.input) + self.output = self._tuple_it(self.output) - self.input = self.fields if self.fields and not self.input else self.input + # Derive relationships between map, input, output, and fields. + self._init_input_from_fields() + self._init_io_from_map() + self._init_from_switch() + self._default_output_from_input() + self._init_map_from_io() + + self._validate_configuration() + + def _init_input_from_fields(self) -> None: + """Populate input from fields when provided.""" + if self.fields and not self.input: + self.input = self.fields + + def _init_io_from_map(self, force_init=False) -> None: + """Populate input/output tuples from an explicit map.""" + if not self.map: + return + if force_init or (not self.input and not self.output): + input_fields, output_fields = zip(*self.map.items()) + self.input = tuple(input_fields) + self.output = tuple(output_fields) + elif not self.input: + self.input = tuple(self.map.keys()) + elif not self.output: + self.output = tuple(self.map.values()) + + def _init_from_switch(self) -> None: + """Fallback initialization using switch definitions.""" + if self.switch and not self.input and not self.output: + self.input = tuple(self.switch) + # We rely on the first switch entry to infer the output shape. 
+ first_key = self.input[0] + self.output = self._tuple_it(self.switch[first_key]) + + def _default_output_from_input(self) -> None: + """Ensure output mirrors input when not explicitly provided.""" if not self.output: self.output = self.input - self.output = self._tuple_it(self.output) - if not self.input and not self.output: - if self.map: - items = list(self.map.items()) - self.input = tuple(x for x, _ in items) - self.output = tuple(x for _, x in items) - elif self.switch: - self.input = tuple([k for k in self.switch]) - self.output = tuple(self.switch[self.input[0]]) - elif not self.name: - raise ValueError( - "Either input and output, fields, map or name should be" - " provided in Transform constructor." - ) + def _init_map_from_io(self) -> None: + """Derive map from input/output when possible.""" + if self.map or not self.input or not self.output: + return + if len(self.input) != len(self.output): + return + self.map = {src: dst for src, dst in zip(self.input, self.output)} + + def _validate_configuration(self) -> None: + """Validate that the transform has enough information to operate.""" + if not self.input and not self.output and not self.name: + raise ValueError( + "Either input/output, fields, map or name must be provided in Transform " + "constructor." + ) def __call__(self, *nargs, **kwargs): """Execute the transform. 
@@ -198,9 +234,7 @@ def __call__(self, *nargs, **kwargs): Returns: dict: Transformed data """ - is_mapping = self._foo is None - - if is_mapping: + if self.is_mapping: input_doc = nargs[0] if isinstance(input_doc, dict): output_values = [input_doc[k] for k in self.input] @@ -219,7 +253,12 @@ def __call__(self, *nargs, **kwargs): r = output_values return r - def _dress_as_dict(self, transform_result): + @property + def is_mapping(self) -> bool: + """True when the transform is pure mapping (no function).""" + return self._foo is None + + def _dress_as_dict(self, transform_result) -> dict[str, Any]: """Convert transform result to dictionary format. Args: @@ -238,7 +277,7 @@ def _dress_as_dict(self, transform_result): return upd @property - def is_dummy(self): + def is_dummy(self) -> bool: """Check if this is a dummy transform. Returns: @@ -246,7 +285,7 @@ def is_dummy(self): """ return (self.name is not None) and (not self.map and self._foo is None) - def update(self, t: Transform): + def update(self, t: Transform) -> Transform: """Update this transform with another transform's configuration. 
Args: diff --git a/graflo/db/postgres/schema_inference.py b/graflo/db/postgres/schema_inference.py index 1c1d5fe9..73b3ee42 100644 --- a/graflo/db/postgres/schema_inference.py +++ b/graflo/db/postgres/schema_inference.py @@ -7,8 +7,10 @@ from __future__ import annotations import logging +from collections import Counter from typing import TYPE_CHECKING + from graflo.architecture.edge import Edge, EdgeConfig, WeightConfig from graflo.architecture.onto import Index, IndexType from graflo.architecture.schema import Schema, SchemaMetadata @@ -355,47 +357,17 @@ def _sanitize_schema_attributes(self, schema: Schema) -> Schema: # Track name mappings for attributes (fields/weights) attribute_mappings: dict[str, str] = {} # Track name mappings for vertex names (separate from attributes) - vertex_mappings: dict[str, str] = {} - # First pass: Sanitize vertex names + # First pass: Sanitize vertex dbnames for vertex in schema.vertex_config.vertices: - original_vertex_name = vertex.name - if original_vertex_name not in vertex_mappings: - sanitized_vertex_name = sanitize_attribute_name( - original_vertex_name, self.reserved_words, suffix="_vertex" + sanitized_vertex_name = sanitize_attribute_name( + vertex.dbname, self.reserved_words, suffix="_vertex" + ) + if sanitized_vertex_name != vertex.dbname: + logger.debug( + f"Sanitizing vertex name '{vertex.dbname}' -> '{sanitized_vertex_name}'" ) - if sanitized_vertex_name != original_vertex_name: - vertex_mappings[original_vertex_name] = sanitized_vertex_name - logger.debug( - f"Sanitizing vertex name '{original_vertex_name}' -> '{sanitized_vertex_name}'" - ) - else: - vertex_mappings[original_vertex_name] = original_vertex_name - else: - sanitized_vertex_name = vertex_mappings[original_vertex_name] - - # Update vertex name if it changed - if sanitized_vertex_name != original_vertex_name: - vertex.name = sanitized_vertex_name - # Also update dbname if it matches the original name (default behavior) - if vertex.dbname == 
original_vertex_name or vertex.dbname is None: - vertex.dbname = sanitized_vertex_name - - # Rebuild VertexConfig's internal _vertices_map after renaming vertices - schema.vertex_config._vertices_map = { - vertex.name: vertex for vertex in schema.vertex_config.vertices - } - - # Update blank_vertices references if they were sanitized - schema.vertex_config.blank_vertices = [ - vertex_mappings.get(v, v) for v in schema.vertex_config.blank_vertices - ] - - # Update force_types keys if they were sanitized - schema.vertex_config.force_types = { - vertex_mappings.get(k, k): v - for k, v in schema.vertex_config.force_types.items() - } + vertex.dbname = sanitized_vertex_name # Second pass: Sanitize vertex field names for vertex in schema.vertex_config.vertices: @@ -430,163 +402,354 @@ def _sanitize_schema_attributes(self, schema: Schema) -> Schema: updated_fields.append(sanitized_field_name) index.fields = updated_fields - # Third pass: Update edge references to sanitized vertex names - for edge in schema.edge_config.edges: - # Update source vertex reference - if edge.source in vertex_mappings: - edge.source = vertex_mappings[edge.source] - logger.debug( - f"Updated edge source reference '{edge.source}' (sanitized vertex name)" - ) - - # Update target vertex reference - if edge.target in vertex_mappings: - edge.target = vertex_mappings[edge.target] - logger.debug( - f"Updated edge target reference '{edge.target}' (sanitized vertex name)" - ) + # Third pass: Normalize edge indexes for TigerGraph + # TigerGraph requires that edges with the same relation have consistent source and target indexes + # 1) group edges by relation + # 2) check that for each group specified by relation the sources have the same index + # and separately the targets have the same index + # 3) if this is not the case, identify the most popular index + # 4) for vertices that don't comply with the chose source/target index, we want to prepare a mapping + # and rename relevant fields indexes + 
field_index_mappings: dict[ + str, dict[str, str] + ] = {} # vertex_name -> {old_field: new_field} + + if schema.vertex_config.db_flavor == DBFlavor.TIGERGRAPH: + # Group edges by relation + edges_by_relation: dict[str | None, list[Edge]] = {} + for edge in schema.edge_config.edges: + relation = edge.relation + if relation not in edges_by_relation: + edges_by_relation[relation] = [] + edges_by_relation[relation].append(edge) + + # Process each relation group + for relation, relation_edges in edges_by_relation.items(): + if len(relation_edges) <= 1: + # Only one edge with this relation, no normalization needed + continue + + # Collect all vertex/index pairs using a list to capture all occurrences + # This handles cases where a vertex appears multiple times in edges for the same relation + source_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] + target_vertex_indexes: list[tuple[str, tuple[str, ...]]] = [] + + for edge in relation_edges: + source_vertex = edge.source + target_vertex = edge.target + + # Get primary index for source vertex + source_index = schema.vertex_config.index(source_vertex) + source_vertex_indexes.append( + (source_vertex, tuple(source_index.fields)) + ) - # Update 'by' vertex reference for indirect edges - # Note: edge.by might be a vertex name or a dbname (if finish_init was already called) - # We check both the direct mapping and reverse lookup via dbname - if edge.by is not None: - if edge.by in vertex_mappings: - # edge.by is a vertex name that needs sanitization - edge.by = vertex_mappings[edge.by] - logger.debug( - f"Updated edge 'by' reference to '{edge.by}' (sanitized vertex name)" + # Get primary index for target vertex + target_index = schema.vertex_config.index(target_vertex) + target_vertex_indexes.append( + (target_vertex, tuple(target_index.fields)) ) - else: - # edge.by might be a dbname - try to find the vertex that has this dbname - # and check if its name was sanitized - try: - vertex = 
schema.vertex_config._get_vertex_by_name_or_dbname( - edge.by - ) - vertex_name = vertex.name - if vertex_name in vertex_mappings: - # This vertex was sanitized, update edge.by to use sanitized name - # (finish_init will convert it back to dbname) - edge.by = vertex_mappings[vertex_name] - logger.debug( - f"Updated edge 'by' reference from dbname '{edge.by}' " - f"to sanitized vertex name '{vertex_mappings[vertex_name]}'" - ) - except (KeyError, AttributeError): - # edge.by is neither a vertex name nor a dbname we recognize - # This shouldn't happen in normal operation, but we'll skip it - pass - - # Sanitize edge weight field names - if edge.weights and edge.weights.direct: - for weight_field in edge.weights.direct: - original_name = weight_field.name - if original_name not in attribute_mappings: - sanitized_name = sanitize_attribute_name( - original_name, self.reserved_words - ) - if sanitized_name != original_name: - attribute_mappings[original_name] = sanitized_name - logger.debug( - f"Sanitizing weight field name '{original_name}' -> " - f"'{sanitized_name}' in edge '{edge.source}' -> '{edge.target}'" - ) - else: - attribute_mappings[original_name] = original_name - else: - sanitized_name = attribute_mappings[original_name] - # Update weight field name if it changed - if sanitized_name != original_name: - weight_field.name = sanitized_name + # Normalize source indexes + self._normalize_vertex_indexes( + source_vertex_indexes, + relation, + schema, + field_index_mappings, + "source", + ) - # Fourth pass: Re-initialize edges after vertex name sanitization - # This ensures edge._source, edge._target, and edge.by are correctly set - # with the sanitized vertex names - schema.edge_config.finish_init(schema.vertex_config) + # Normalize target indexes + self._normalize_vertex_indexes( + target_vertex_indexes, + relation, + schema, + field_index_mappings, + "target", + ) - # Fifth pass: Update resource apply lists that reference vertices - for resource in 
schema.resources: - self._sanitize_resource_vertex_references(resource, vertex_mappings) + # Fourth pass: the field maps from edge/relation normalization should be applied to resources: + # new transforms should be added mapping old index names to those identified in the previous step + if field_index_mappings: + for resource in schema.resources: + self._apply_field_index_mappings_to_resource( + resource, field_index_mappings + ) return schema - def _sanitize_resource_vertex_references( - self, resource: Resource, vertex_mappings: dict[str, str] + def _normalize_vertex_indexes( + self, + vertex_indexes: list[tuple[str, tuple[str, ...]]], + relation: str | None, + schema: Schema, + field_index_mappings: dict[str, dict[str, str]], + role: str, # "source" or "target" for logging ) -> None: - """Sanitize vertex name references in a resource's apply list. + """Normalize vertex indexes to use the most popular index pattern. - Resources can reference vertices in their apply list through: - - {"vertex": vertex_name} for VertexActor - - {"target_vertex": vertex_name, ...} for mapping actors - - {"source": vertex_name, "target": vertex_name} for EdgeActor - - Nested structures in tree_likes resources + For vertices that don't match the most popular index, this method: + 1. Creates field mappings (old_field -> new_field) + 2. Updates vertex indexes to match the most popular pattern + 3. Adds new fields to vertices if needed + 4. 
Removes old fields that are being replaced Args: - resource: The resource to sanitize - vertex_mappings: Dictionary mapping original vertex names to sanitized names + vertex_indexes: List of (vertex_name, index_fields_tuple) pairs + relation: Relation name for logging + schema: Schema to update + field_index_mappings: Dictionary to update with field mappings + role: "source" or "target" for logging purposes """ - if not hasattr(resource, "apply") or not resource.apply: + if not vertex_indexes: return - def sanitize_apply_item(item): - """Recursively sanitize vertex references in apply items.""" - if isinstance(item, dict): - # Handle vertex references in dictionaries - sanitized_item = {} - for key, value in item.items(): - if key == "vertex" and isinstance(value, str): - # {"vertex": vertex_name} - sanitized_item[key] = vertex_mappings.get(value, value) - if value != sanitized_item[key]: - logger.debug( - f"Updated resource '{resource.resource_name}' apply item: " - f"'{key}': '{value}' -> '{sanitized_item[key]}'" - ) - elif key == "target_vertex" and isinstance(value, str): - # {"target_vertex": vertex_name, ...} - sanitized_item[key] = vertex_mappings.get(value, value) - if value != sanitized_item[key]: - logger.debug( - f"Updated resource '{resource.resource_name}' apply item: " - f"'{key}': '{value}' -> '{sanitized_item[key]}'" + # Extract unique vertex/index pairs (a vertex might appear multiple times) + vertex_index_dict: dict[str, tuple[str, ...]] = {} + for vertex_name, index_fields in vertex_indexes: + # Only store first occurrence - we'll normalize all vertices together + if vertex_name not in vertex_index_dict: + vertex_index_dict[vertex_name] = index_fields + + # Check if all indexes are consistent + indexes_list = list(vertex_index_dict.values()) + indexes_set = set(indexes_list) + indexes_consistent = len(indexes_set) == 1 + + if indexes_consistent: + # All indexes are the same, no normalization needed + return + + # Find most popular index + 
index_counter = Counter(indexes_list) + most_popular_index = index_counter.most_common(1)[0][0] + + # Normalize vertices that don't match + for vertex_name, index_fields in vertex_index_dict.items(): + if index_fields == most_popular_index: + continue + + # Initialize mappings for this vertex if needed + if vertex_name not in field_index_mappings: + field_index_mappings[vertex_name] = {} + + # Map old fields to new fields + old_fields = list(index_fields) + new_fields = list(most_popular_index) + + # Create field-to-field mapping + # If lengths match, map positionally; otherwise map first field to first field + if len(old_fields) == len(new_fields): + for old_field, new_field in zip(old_fields, new_fields): + if old_field != new_field: + # Update existing mapping if it exists, otherwise create new one + field_index_mappings[vertex_name][old_field] = new_field + else: + # If lengths don't match, map the first field + if old_fields and new_fields: + if old_fields[0] != new_fields[0]: + field_index_mappings[vertex_name][old_fields[0]] = new_fields[0] + + # Update vertex index and fields + vertex = schema.vertex_config[vertex_name] + existing_field_names = {f.name for f in vertex.fields} + + # Add new fields that don't exist + for new_field in most_popular_index: + if new_field not in existing_field_names: + vertex.fields.append(Field(name=new_field, type=None)) + existing_field_names.add(new_field) + + # Remove old fields that are being replaced (not in new index) + fields_to_remove = [ + f + for f in vertex.fields + if f.name in old_fields and f.name not in new_fields + ] + for field_to_remove in fields_to_remove: + vertex.fields.remove(field_to_remove) + + # Update vertex index to match the most popular one + vertex.indexes[0].fields = list(most_popular_index) + + logger.debug( + f"Normalizing {role} index for vertex '{vertex_name}' in relation '{relation}': " + f"{old_fields} -> {new_fields}" + ) + + def _apply_field_index_mappings_to_resource( + self, resource: 
Resource, field_index_mappings: dict[str, dict[str, str]] + ) -> None: + """Apply field index mappings to TransformActor instances in a resource. + + For vertices that had their indexes normalized, this method updates TransformActor + instances to map old field names to new field names in their Transform.map attribute. + Only updates TransformActors where the vertex is confirmed to be created at that level + (via VertexActor). + + Args: + resource: The resource to update + field_index_mappings: Dictionary mapping vertex names to field mappings + (old_field -> new_field) + """ + from graflo.architecture.actor import ( + ActorWrapper, + DescendActor, + TransformActor, + VertexActor, + ) + + def collect_vertices_at_level(wrappers: list[ActorWrapper]) -> set[str]: + """Collect vertices created by VertexActor instances at the current level only. + + Does not recurse into nested structures - only collects vertices from + the immediate level. + + Args: + wrappers: List of ActorWrapper instances + + Returns: + set[str]: Set of vertex names created at this level + """ + vertices = set() + for wrapper in wrappers: + if isinstance(wrapper.actor, VertexActor): + vertices.add(wrapper.actor.name) + return vertices + + def update_transform_actor_maps( + wrapper: ActorWrapper, parent_vertices: set[str] | None = None + ) -> set[str]: + """Recursively update TransformActor instances with field index mappings. 
+ + Args: + wrapper: ActorWrapper instance to process + parent_vertices: Set of vertices available from parent levels (for nested structures) + + Returns: + set[str]: Set of all vertices available at this level (including parent) + """ + if parent_vertices is None: + parent_vertices = set() + + # Collect vertices created at this level + current_level_vertices = set() + if isinstance(wrapper.actor, VertexActor): + current_level_vertices.add(wrapper.actor.name) + + # All available vertices = current level + parent levels + all_available_vertices = current_level_vertices | parent_vertices + + # Process TransformActor if present + if isinstance(wrapper.actor, TransformActor): + transform_actor: TransformActor = wrapper.actor + + def apply_mappings_to_transform( + mappings: dict[str, str], + vertex_name: str, + actor: TransformActor, + ) -> None: + """Apply field mappings to TransformActor's transform.map attribute. + + Args: + mappings: Dictionary mapping old field names to new field names + vertex_name: Name of the vertex these mappings belong to (for logging) + actor: The TransformActor instance to update + """ + transform = actor.t + if transform.map: + # Update existing map: replace values and keys that match old field names + # First, update values + for map_key, map_value in transform.map.items(): + if isinstance(map_value, str) and map_value in mappings: + transform.map[map_key] = mappings[map_value] + + # if the terminal attr not in the map - add it + for k, v in mappings.items(): + if v not in transform.map.values(): + transform.map[k] = v + else: + # Create new map with all mappings + transform.map = mappings.copy() + + # Update Transform object IO to reflect map edits + actor.t._init_io_from_map(force_init=True) + + logger.debug( + f"Updated TransformActor map in resource '{resource.resource_name}' " + f"for vertex '{vertex_name}': {mappings}" + ) + + target_vertex = transform_actor.vertex + + if isinstance(target_vertex, str): + # TransformActor has 
explicit target_vertex + if ( + target_vertex in field_index_mappings + and target_vertex in all_available_vertices + ): + mappings = field_index_mappings[target_vertex] + if mappings: + apply_mappings_to_transform( + mappings, target_vertex, transform_actor ) - elif key in ("source", "target") and isinstance(value, str): - # {"source": vertex_name, "target": vertex_name} for EdgeActor - sanitized_item[key] = vertex_mappings.get(value, value) - if value != sanitized_item[key]: + else: logger.debug( - f"Updated resource '{resource.resource_name}' apply item: " - f"'{key}': '{value}' -> '{sanitized_item[key]}'" + f"Skipping TransformActor for vertex '{target_vertex}' " + f"in resource '{resource.resource_name}': no mappings needed" ) - elif key == "name" and isinstance(value, str): - # Keep transform names as-is - sanitized_item[key] = value - elif key == "children" and isinstance(value, list): - # Recursively sanitize children in tree_likes resources - sanitized_item[key] = [ - sanitize_apply_item(child) for child in value - ] - elif isinstance(value, dict): - # Recursively sanitize nested dictionaries - sanitized_item[key] = sanitize_apply_item(value) - elif isinstance(value, list): - # Recursively sanitize lists - sanitized_item[key] = [ - sanitize_apply_item(subitem) for subitem in value - ] else: - sanitized_item[key] = value - return sanitized_item - elif isinstance(item, list): - # Recursively sanitize lists - return [sanitize_apply_item(subitem) for subitem in item] - else: - # Return non-dict/list items as-is - return item + logger.debug( + f"Skipping TransformActor for vertex '{target_vertex}' " + f"in resource '{resource.resource_name}': vertex not created at this level" + ) + else: + # TransformActor has no target_vertex + # Apply mappings from all available vertices (parent and current level) + # since transformed fields will be attributed to those vertices + applied_any = False + for vertex in all_available_vertices: + if vertex in field_index_mappings: 
+ mappings = field_index_mappings[vertex] + if mappings: + apply_mappings_to_transform( + mappings, vertex, transform_actor + ) + applied_any = True + + if not applied_any: + logger.debug( + f"Skipping TransformActor without target_vertex " + f"in resource '{resource.resource_name}': " + f"no mappings found for available vertices {all_available_vertices}" + ) - # Sanitize the entire apply list - resource.apply = [sanitize_apply_item(item) for item in resource.apply] + # Recursively process nested structures (DescendActor) + if isinstance(wrapper.actor, DescendActor): + # Collect vertices from all descendants at this level + descendant_vertices = collect_vertices_at_level( + wrapper.actor.descendants + ) + all_available_vertices |= descendant_vertices + + # Recursively process each descendant + for descendant_wrapper in wrapper.actor.descendants: + nested_vertices = update_transform_actor_maps( + descendant_wrapper, parent_vertices=all_available_vertices + ) + # Merge nested vertices into available vertices + all_available_vertices |= nested_vertices + + return all_available_vertices + + # Process the root ActorWrapper if it exists + if hasattr(resource, "root") and resource.root is not None: + update_transform_actor_maps(resource.root) + else: + logger.warning( + f"Resource '{resource.resource_name}' does not have a root ActorWrapper. " + f"Skipping field index mapping updates." 
+ ) def infer_schema( self, diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 15d06f7a..47e4a7d9 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -26,6 +26,8 @@ import contextlib import json import logging +import re +from collections import defaultdict from pathlib import Path from typing import Any @@ -863,29 +865,71 @@ def _get_vertex_types(self, graph_name: str | None = None) -> list[str]: logger.debug(f"Failed to get vertex types via GSQL: {e}") return [] - def _get_edge_types(self, graph_name: str | None = None) -> list[str]: + def _parse_show_edge_output_with_vertices( + self, output: str + ) -> dict[str, list[tuple[str, str]]]: """ - Get list of edge types using GSQL. + Parse SHOW EDGE * output (compact TigerGraph format). + + Returns: + dict mapping edge_name -> list of (source_vertex, target_vertex) + """ + edge_map: dict[str, list[tuple[str, str]]] = defaultdict(list) + + # Match lines like: + # - DIRECTED EDGE contains(FROM Author, TO ResearchField|FROM ResearchField, TO ResearchField) + edge_line_pattern = re.compile( + r"-\s+(?:DIRECTED|UNDIRECTED)\s+EDGE\s+(\w+)\(([^)]+)\)" + ) + + # Match FROM X, TO Y + from_to_pattern = re.compile(r"FROM\s+(\w+)\s*,\s*TO\s+(\w+)") + + for line in output.splitlines(): + line = line.strip() + if not line.startswith("-"): + continue + + edge_match = edge_line_pattern.search(line) + if not edge_match: + continue + + edge_name = edge_match.group(1) + endpoints_blob = edge_match.group(2) + + # Split multiple vertex pairs + for endpoint in endpoints_blob.split("|"): + ft_match = from_to_pattern.search(endpoint) + if ft_match: + source, target = ft_match.groups() + edge_map[edge_name].append((source, target)) + + return dict(edge_map) + + def _get_edge_types( + self, graph_name: str | None = None + ) -> dict[str, list[tuple[str, str]]]: + """ + Get edge types and their (source, target) vertex pairs using GSQL. 
Args: graph_name: Name of the graph (defaults to self.graphname) Returns: - List of edge type names + Dict mapping edge_type -> list of (source_vertex, target_vertex) """ graph_name = graph_name or self.graphname try: result = self._execute_gsql(f"USE GRAPH {graph_name}\nSHOW EDGE *") - # Parse GSQL output using the proper parser + if isinstance(result, str): - # _parse_show_edge_output returns list of tuples (edge_name, is_directed) - # Extract just the edge names - edge_tuples = self._parse_show_edge_output(result) - return [edge_name for edge_name, _ in edge_tuples] - return [] + return self._parse_show_edge_output_with_vertices(result) + + return {} + except Exception as e: - logger.debug(f"Failed to get edge types via GSQL: {e}") - return [] + logger.error(f"Failed to get edge types via GSQL: {e}") + return {} def _get_installed_queries(self, graph_name: str | None = None) -> list[str]: """ @@ -1880,6 +1924,102 @@ def _get_edge_add_statement(self, edge: Edge) -> str: # No trailing comma needed return f"ADD DIRECTED EDGE {edge.relation} (\n{from_to_line}\n )" + def _get_edge_group_create_statement(self, edges: list[Edge]) -> str: + """Generate ADD DIRECTED EDGE statement for a group of edges with the same relation. + + TigerGraph requires edges of the same type to be created in a single statement + with multiple FROM/TO pairs separated by |. 
+ + Args: + edges: List of Edge objects with the same relation (edge type) + + Returns: + str: GSQL ADD DIRECTED EDGE statement with multiple FROM/TO pairs + """ + if not edges: + raise ValueError("Cannot create edge statement from empty edge list") + + # Use the first edge to determine attributes and discriminator + # (all edges of the same relation should have the same schema) + first_edge = edges[0] + relation = first_edge.relation + + # Collect indexed fields for discriminator (same logic as _get_edge_add_statement) + indexed_field_names = set() + for index in first_edge.indexes: + for field_name in index.fields: + if field_name not in ["_from", "_to"]: + indexed_field_names.add(field_name) + + if ( + first_edge.relation_field + and first_edge.relation_field not in indexed_field_names + ): + indexed_field_names.add(first_edge.relation_field) + + # Ensure indexed fields are in weights (same logic as _get_edge_add_statement) + if first_edge.weights is None: + from graflo.architecture.edge import WeightConfig + + first_edge.weights = WeightConfig() + + assert first_edge.weights is not None, "weights should be initialized" + existing_weight_names = set() + if first_edge.weights.direct: + existing_weight_names = {field.name for field in first_edge.weights.direct} + + for field_name in indexed_field_names: + if field_name not in existing_weight_names: + from graflo.architecture.edge import Field + + first_edge.weights.direct.append( + Field(name=field_name, type=FieldType.STRING) + ) + + # Format edge attributes, excluding discriminator fields + edge_attrs = self._format_edge_attributes( + first_edge, exclude_fields=indexed_field_names + ) + + # Get field types for discriminator fields + field_types = {} + if first_edge.weights and first_edge.weights.direct: + for field in first_edge.weights.direct: + field_types[field.name] = self._get_tigergraph_type(field.type) + + # Build FROM/TO pairs for all edges, separated by | + from_to_lines = [] + for edge in edges: + # 
Build FROM/TO line: "FROM A, TO B" or "FROM A, TO B, DISCRIMINATOR(...)" + from_to_parts = [f"FROM {edge._source}", f"TO {edge._target}"] + + # Add discriminator if needed (same for all edges of the same relation) + if indexed_field_names: + discriminator_parts = [] + for field_name in sorted(indexed_field_names): + field_type = field_types.get(field_name, "STRING") + discriminator_parts.append(f"{field_name} {field_type}") + + discriminator_str = f"DISCRIMINATOR({', '.join(discriminator_parts)})" + from_to_parts.append(discriminator_str) + + # Combine FROM/TO and discriminator with commas on one line + from_to_line = ", ".join(from_to_parts) + from_to_lines.append(f" {from_to_line}") + + # Join all FROM/TO pairs with | + all_from_to = " |\n".join(from_to_lines) + + # Build the complete statement + if edge_attrs: + # Has attributes - add comma after FROM/TO section + return ( + f"ADD DIRECTED EDGE {relation} (\n{all_from_to},\n{edge_attrs}\n )" + ) + else: + # No attributes - FROM/TO section is the last thing + return f"ADD DIRECTED EDGE {relation} (\n{all_from_to}\n )" + @_wrap_tg_exception def _define_schema_local(self, schema: Schema) -> None: """Define TigerGraph schema locally for the current graph using a SCHEMA_CHANGE job. 
@@ -1906,13 +2046,22 @@ def _define_schema_local(self, schema: Schema) -> None: stmt = self._get_vertex_add_statement(vertex, vertex_config) schema_change_stmts.append(stmt) - # Edges + # Edges - group by relation since TigerGraph requires edges of the same type + # to be created in a single statement with multiple FROM/TO pairs edges_to_create = list(edge_config.edges_list(include_aux=True)) for edge in edges_to_create: edge.finish_init(vertex_config) # Validate edge name _validate_tigergraph_schema_name(edge.relation, "edge") - stmt = self._get_edge_add_statement(edge) + + # Group edges by relation + edges_by_relation: dict[str, list[Edge]] = defaultdict(list) + for edge in edges_to_create: + edges_by_relation[edge.relation].append(edge) + + # Create one statement per relation with all FROM/TO pairs + for relation, edge_group in edges_by_relation.items(): + stmt = self._get_edge_group_create_statement(edge_group) schema_change_stmts.append(stmt) if not schema_change_stmts: diff --git a/graflo/db/tigergraph/reserved_words.json b/graflo/db/tigergraph/reserved_words.json index efdb56ff..e4f27041 100644 --- a/graflo/db/tigergraph/reserved_words.json +++ b/graflo/db/tigergraph/reserved_words.json @@ -211,6 +211,7 @@ "OPERATOR", "OR", "OR_EQ", + "PACKAGE", "PRIVATE", "PROTECTED", "PUBLIC", diff --git a/test/config/schema/lake_odds.yaml b/test/config/schema/lake_odds.yaml deleted file mode 100644 index f24646ca..00000000 --- a/test/config/schema/lake_odds.yaml +++ /dev/null @@ -1,18 +0,0 @@ -general: - name: lake_odds -json: - ~maps: - - ~type: dict - ~name: chunk -vertex_collections: - collections: - chunk: - dbname: chunks - fields: - - kind - - data - - fetched_time - indexes: - - fields: - - kind - - fetched_time diff --git a/test/config/schema/review-tigergraph-edges.yaml b/test/config/schema/review-tigergraph-edges.yaml new file mode 100644 index 00000000..e9d1dd9e --- /dev/null +++ b/test/config/schema/review-tigergraph-edges.yaml @@ -0,0 +1,43 @@ +general: + 
name: review +resources: +- resource_name: authors + apply: + - vertex: author + - vertex: researchField + - map: + author_id: id + FullName: full_name + HIndex: hindex + - target_vertex: researchField + map: + research_sector: id +vertex_config: + db_flavor: tigergraph + vertices: + - name: author + dbname: Author + fields: + - id + - full_name + - hindex + indexes: + - fields: + - id + - name: researchField + dbname: ResearchField + fields: + - id + - name + - level + indexes: + - fields: + - id +edge_config: + edges: + - source: author + target: researchField + relation: contains + - source: researchField + target: researchField + relation: contains diff --git a/test/config/schema/review-tigergraph.yaml b/test/config/schema/review-tigergraph.yaml index ce5398e3..ff1471d6 100644 --- a/test/config/schema/review-tigergraph.yaml +++ b/test/config/schema/review-tigergraph.yaml @@ -13,7 +13,7 @@ resources: map: research_sector: id vertex_config: - db_flavor: neo4j + db_flavor: tigergraph vertices: - name: author dbname: Author diff --git a/test/config/schema/tigergraph-sanitize-edges.yaml b/test/config/schema/tigergraph-sanitize-edges.yaml new file mode 100644 index 00000000..b82ac7ce --- /dev/null +++ b/test/config/schema/tigergraph-sanitize-edges.yaml @@ -0,0 +1,53 @@ +general: + name: tigergraph-sanitize-edges +resources: +- resource_name: parcels + apply: + - vertex: parcel + - map: + parcel_id: id +- resource_name: boxes + apply: + - vertex: box + - map: + box_id: id +- resource_name: containers + apply: + - vertex: container + - map: + container_name: name +vertex_config: + db_flavor: tigergraph + vertices: + - name: parcel + fields: + - name: id + type: INT + indexes: + - fields: + - id + - name: box + fields: + - name: id + type: INT + indexes: + - fields: + - id + - name: container + fields: + - name + indexes: + - fields: + - name +edge_config: + edges: + - source: box + target: parcel + relation: contains + - source: box + target: box + relation: contains + 
- source: container + target: box + relation: contains + diff --git a/test/config/schema/tigergraph-sanitize.yaml b/test/config/schema/tigergraph-sanitize.yaml new file mode 100644 index 00000000..78c5dec7 --- /dev/null +++ b/test/config/schema/tigergraph-sanitize.yaml @@ -0,0 +1,40 @@ +general: + name: tigergraph-sanitize +resources: +- resource_name: package + apply: + - vertex: package +vertex_config: + db_flavor: tigergraph + vertices: + - name: package + dbname: Package + fields: + - name: id + type: INT + - name: SELECT + type: STRING + - name: FROM + type: STRING + - name: WHERE + type: STRING + - name: name + type: STRING + indexes: + - fields: + - id + - name: users + dbname: Users + fields: + - name: id + type: INT + - name: name + type: STRING + indexes: + - fields: + - id +edge_config: + edges: + - source: package + target: users + diff --git a/test/db/arangos/conftest.py b/test/db/arangos/conftest.py index 1c7daa6e..205e192d 100644 --- a/test/db/arangos/conftest.py +++ b/test/db/arangos/conftest.py @@ -5,8 +5,6 @@ from graflo.db import ConnectionManager from graflo.db.connection.onto import ArangoConfig -from graflo.filter.onto import ComparisonOperator -from graflo.onto import AggregationType from test.conftest import fetch_schema_obj @@ -98,72 +96,3 @@ def ingest(create_db, modes, conn_conf, current_path, test_db_name, reset, n_cor mode=m, reset=reset, ) - if m == "lake_odds": - conn_conf.database = test_db_name - with ConnectionManager(connection_config=conn_conf) as db_client: - r = db_client.fetch_docs("chunks") - assert len(r) == 2 - assert r[0]["data"] - r = db_client.fetch_docs("chunks", filters=["==", "odds", "kind"]) - assert len(r) == 1 - r = db_client.fetch_docs("chunks", limit=1) - assert len(r) == 1 - r = db_client.fetch_docs( - "chunks", - filters=["==", "odds", "kind"], - return_keys=["kind"], - ) - assert len(r[0]) == 1 - batch = [{"kind": "odds"}, {"kind": "strange"}] - with ConnectionManager(connection_config=conn_conf) as db_client: 
- r = db_client.fetch_present_documents( - batch, - "chunks", - match_keys=("kind",), - keep_keys=("_key",), - flatten=False, - ) - assert len(r) == 1 - - batch = [{"kind": "odds"}, {"kind": "scores"}, {"kind": "strange"}] - with ConnectionManager(connection_config=conn_conf) as db_client: - r = db_client.fetch_present_documents( - batch, - "chunks", - match_keys=("kind",), - keep_keys=("_key",), - flatten=False, - filters=[ComparisonOperator.NEQ, "odds", "kind"], - ) - assert len(r) == 1 - - with ConnectionManager(connection_config=conn_conf) as db_client: - r = db_client.keep_absent_documents( - batch, - "chunks", - match_keys=("kind",), - keep_keys=("_key",), - filters=[ComparisonOperator.EQ, None, "data"], - ) - assert len(r) == 3 - - with ConnectionManager(connection_config=conn_conf) as db_client: - r = db_client.aggregate( - "chunks", - aggregation_function=AggregationType.COUNT, - discriminant="kind", - ) - assert len(r) == 2 - assert r == [ - {"kind": "odds", "_value": 1}, - {"kind": "scores", "_value": 1}, - ] - - with ConnectionManager(connection_config=conn_conf) as db_client: - r = db_client.aggregate( - "chunks", - aggregation_function=AggregationType.COUNT, - discriminant="kind", - filters=[ComparisonOperator.NEQ, "odds", "kind"], - ) - assert len(r) == 1 diff --git a/test/db/arangos/test_db_index.py b/test/db/arangos/test_db_index.py index da66d671..0041fd1a 100644 --- a/test/db/arangos/test_db_index.py +++ b/test/db/arangos/test_db_index.py @@ -10,10 +10,6 @@ def modes(): return [ "kg", "ibes", - # "wos_json", - # "lake_odds", - # "wos_csv", - # "ticker", ] diff --git a/test/db/arangos/test_ingest_parallel.py b/test/db/arangos/test_ingest_parallel.py index bd2dc66a..5d25bbde 100644 --- a/test/db/arangos/test_ingest_parallel.py +++ b/test/db/arangos/test_ingest_parallel.py @@ -8,10 +8,6 @@ def modes(): return [ "kg", "ibes", - # "wos_json", - # "lake_odds", - # "wos_csv", - # "ticker", ] diff --git a/test/db/tigergraphs/test_db_creation.py 
b/test/db/tigergraphs/test_db_creation.py
index 79b65f05..a415d75c 100644
--- a/test/db/tigergraphs/test_db_creation.py
+++ b/test/db/tigergraphs/test_db_creation.py
@@ -101,3 +101,43 @@ def test_schema_creation(conn_conf, test_graph_name, schema_obj):
 
     print(f"Created vertex types: {vertex_types}")
     print(f"Created edge types: {edge_types}")
+
+
+def test_schema_creation_edges(conn_conf, test_graph_name, schema_obj):
+    """Test creating schema using init_db (follows ArangoDB pattern).
+
+    Pattern: init_db creates graph, defines schema, then defines indexes.
+    Uses schema.general.name as the graph name (from test_graph fixture).
+
+    Note: In TigerGraph, vertex and edge types are global and shared between graphs.
+    The test verifies that types are created and associated with the test graph.
+    """
+    schema_obj = schema_obj("review-tigergraph-edges")
+    # Set graph name in schema.general.name; conn_conf.database is set by fixture
+    schema_obj.general.name = test_graph_name
+
+    with ConnectionManager(connection_config=conn_conf) as db_client:
+        # init_db will: create graph, define schema, define indexes
+        # Graph name comes from schema.general.name
+        db_client.init_db(schema_obj, clean_start=True)
+
+    with ConnectionManager(connection_config=conn_conf) as db_client:
+        # Verify graph exists (using name from schema.general.name)
+        assert db_client.graph_exists(test_graph_name)
+
+        # Use the graph context to verify schema
+        # getVertexTypes() and getEdgeTypes() require graph context via _ensure_graph_context
+        with db_client._ensure_graph_context(test_graph_name):
+            # Verify schema was created
+            vertex_types = db_client._get_vertex_types()
+            edge_types = db_client._get_edge_types()
+
+            # Check expected types exist
+            assert len(vertex_types) > 0, "No vertex types created"
+            assert len(edge_types) == 1, "Expected exactly one edge type"
+            assert len(edge_types["contains"]) == 2, (
+                "Should have two edges for 
relation `contains`" + ) + + print(f"Created vertex types: {vertex_types}") + print(f"Created edge types: {edge_types}") diff --git a/test/db/tigergraphs/test_reserved_words.py b/test/db/tigergraphs/test_reserved_words.py index 12844c59..535d6dc7 100644 --- a/test/db/tigergraphs/test_reserved_words.py +++ b/test/db/tigergraphs/test_reserved_words.py @@ -20,71 +20,23 @@ import pytest -from graflo.architecture.edge import Edge, EdgeConfig, WeightConfig -from graflo.architecture.resource import Resource -from graflo.architecture.schema import Schema, SchemaMetadata -from graflo.architecture.vertex import Field, FieldType, Vertex, VertexConfig from graflo.db.postgres.schema_inference import PostgresSchemaInferencer from graflo.onto import DBFlavor +from test.conftest import fetch_schema_obj logger = logging.getLogger(__name__) @pytest.fixture def schema_with_reserved_words(): - """Create a Schema object with reserved words for testing.""" - # Create vertices with reserved words - version_vertex = Vertex( - name="version", # "version" is a TigerGraph reserved word - fields=[ - Field(name="id", type=FieldType.INT), - Field(name="SELECT", type=FieldType.STRING), # Reserved word - Field(name="FROM", type=FieldType.STRING), # Reserved word - Field(name="WHERE", type=FieldType.STRING), # Reserved word - Field(name="name", type=FieldType.STRING), - ], - ) - - users_vertex = Vertex( - name="users", - fields=[ - Field(name="id", type=FieldType.INT), - Field(name="name", type=FieldType.STRING), - ], - ) - - vertex_config = VertexConfig( - vertices=[version_vertex, users_vertex], db_flavor=DBFlavor.TIGERGRAPH - ) - - # Create edge with reserved word attributes - version_users_edge = Edge( - source="version", # Will be sanitized - target="users", - weights=WeightConfig( - direct=[ - Field(name="SELECT", type=FieldType.STRING), # Reserved word - Field(name="FROM", type=FieldType.STRING), # Reserved word - ] - ), - ) - - edge_config = EdgeConfig(edges=[version_users_edge]) + 
schema_o = fetch_schema_obj("tigergraph-sanitize") + return schema_o - # Create resource with vertex reference - version_resource = Resource( - resource_name="version", - apply=[{"vertex": "version"}], # Will be sanitized - ) - - schema = Schema( - general=SchemaMetadata(name="test_schema"), - vertex_config=vertex_config, - edge_config=edge_config, - resources=[version_resource], - ) - return schema +@pytest.fixture +def schema_with_incompatible_edges(): + schema_o = fetch_schema_obj("tigergraph-sanitize-edges") + return schema_o def test_vertex_name_sanitization_for_tigergraph(schema_with_reserved_words): @@ -97,350 +49,35 @@ def test_vertex_name_sanitization_for_tigergraph(schema_with_reserved_words): # Sanitize the schema sanitized_schema = inferencer._sanitize_schema_attributes(schema) - # Check that "version" vertex was sanitized to "version_vertex" - vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] - assert "version_vertex" in vertex_names, ( - f"Expected 'version_vertex' in vertices after sanitization, got {vertex_names}" - ) - assert "version" not in vertex_names, ( - f"Original reserved word 'version' should not be in vertices, got {vertex_names}" - ) - - # Verify the sanitized vertex exists and has correct fields - version_vertex = next( - v for v in sanitized_schema.vertex_config.vertices if v.name == "version_vertex" - ) - assert version_vertex is not None, "version_vertex should exist" - - # Check that attribute names were also sanitized - field_names = [f.name for f in version_vertex.fields] - assert "SELECT_attr" in field_names, ( - f"Expected 'SELECT_attr' in fields after sanitization, got {field_names}" - ) - assert "FROM_attr" in field_names, ( - f"Expected 'FROM_attr' in fields after sanitization, got {field_names}" - ) - assert "WHERE_attr" in field_names, ( - f"Expected 'WHERE_attr' in fields after sanitization, got {field_names}" - ) - assert "SELECT" not in field_names, ( - f"Original reserved word 'SELECT' should not 
be in fields, got {field_names}" - ) - - -def test_edge_references_updated_after_vertex_sanitization(schema_with_reserved_words): - """Test that edge source/target references are updated when vertex names are sanitized.""" - schema = schema_with_reserved_words - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Find edge connecting version_vertex to users - edges = list(sanitized_schema.edge_config.edges_list()) - version_users_edge = next( - ( - e - for e in edges - if "version_vertex" in (e.source, e.target) - and "users" in (e.source, e.target) - ), - None, - ) - - assert version_users_edge is not None, ( - "Edge between version_vertex and users should exist" - ) - assert ( - version_users_edge.source == "version_vertex" - or version_users_edge.target == "version_vertex" - ), ( - f"Edge should reference 'version_vertex', got source={version_users_edge.source}, " - f"target={version_users_edge.target}" - ) - assert "version" not in ( - version_users_edge.source, - version_users_edge.target, - ), ( - f"Edge should not reference original 'version' name, " - f"got source={version_users_edge.source}, target={version_users_edge.target}" - ) - - # Check that edge weight attributes were sanitized - assert version_users_edge.weights is not None, "Edge should have weights" - assert version_users_edge.weights.direct is not None, ( - "Edge should have direct weights" - ) - weight_field_names = [f.name for f in version_users_edge.weights.direct] - assert "SELECT_attr" in weight_field_names, ( - f"Expected 'SELECT_attr' in weight fields, got {weight_field_names}" - ) - assert "FROM_attr" in weight_field_names, ( - f"Expected 'FROM_attr' in weight fields, got {weight_field_names}" - ) - - -def test_resource_apply_lists_updated_after_vertex_sanitization( - schema_with_reserved_words, -): - """Test that resource apply lists reference sanitized vertex names.""" - schema = 
schema_with_reserved_words - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Find resource for version table - version_resource = next( - (r for r in sanitized_schema.resources if r.resource_name == "version"), None - ) - assert version_resource is not None, "version resource should exist" - - # Check that apply list references sanitized vertex name - apply_str = str(version_resource.apply) - assert "version_vertex" in apply_str, ( - f"Resource apply list should reference 'version_vertex', got {apply_str}" - ) - # Check the actual apply item - assert len(version_resource.apply) > 0, "Resource should have apply items" - apply_item = version_resource.apply[0] - assert isinstance(apply_item, dict), "Apply item should be a dict" - assert apply_item.get("vertex") == "version_vertex", ( - f"Apply item should reference 'version_vertex', got {apply_item}" - ) - - -def test_arango_no_sanitization(schema_with_reserved_words): - """Test that ArangoDB flavor does not sanitize names (no reserved words).""" - schema = schema_with_reserved_words - - # Change schema to ArangoDB flavor - schema.vertex_config.db_flavor = DBFlavor.ARANGO - schema.edge_config.db_flavor = DBFlavor.ARANGO - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.ARANGO) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Check that "version" vertex name is NOT sanitized for ArangoDB - vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] - assert "version" in vertex_names, ( - f"Expected 'version' in vertices for ArangoDB (no sanitization), got {vertex_names}" - ) - assert "version_vertex" not in vertex_names, ( - f"Should not have 'version_vertex' for ArangoDB, got {vertex_names}" - ) - - # Check that attribute names are NOT sanitized for ArangoDB - version_vertex = next( - v for v in sanitized_schema.vertex_config.vertices if v.name == "version" - ) - field_names 
= [f.name for f in version_vertex.fields] - assert "SELECT" in field_names, ( - f"Expected 'SELECT' in fields for ArangoDB (no sanitization), got {field_names}" - ) - assert "SELECT_attr" not in field_names, ( - f"Should not have 'SELECT_attr' for ArangoDB, got {field_names}" - ) - - -def test_multiple_reserved_words_sanitization(schema_with_reserved_words): - """Test that multiple reserved words are all sanitized correctly.""" - schema = schema_with_reserved_words - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Check vertex name sanitization - vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] - assert "version_vertex" in vertex_names, ( - f"Expected 'version_vertex' after sanitization, got {vertex_names}" - ) - - # Check attribute name sanitization - version_vertex = next( - v for v in sanitized_schema.vertex_config.vertices if v.name == "version_vertex" - ) - field_names = [f.name for f in version_vertex.fields] - - # All reserved word attributes should be sanitized - reserved_attrs = ["SELECT", "FROM", "WHERE"] - sanitized_attrs = ["SELECT_attr", "FROM_attr", "WHERE_attr"] - - for reserved, sanitized in zip(reserved_attrs, sanitized_attrs): - assert sanitized in field_names, ( - f"Expected '{sanitized}' in fields after sanitization, got {field_names}" - ) - assert reserved not in field_names, ( - f"Original reserved word '{reserved}' should not be in fields, got {field_names}" - ) - - -def test_vertex_config_internal_mappings_updated(schema_with_reserved_words): - """Test that VertexConfig internal mappings are updated after sanitization.""" - schema = schema_with_reserved_words - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Check that _vertices_map uses sanitized names - assert "version_vertex" in sanitized_schema.vertex_config._vertices_map, 
( - "VertexConfig._vertices_map should contain sanitized vertex name" - ) - assert "version" not in sanitized_schema.vertex_config._vertices_map, ( - "VertexConfig._vertices_map should not contain original reserved word" - ) - - # Check that vertex_set uses sanitized names - assert "version_vertex" in sanitized_schema.vertex_config.vertex_set, ( - "VertexConfig.vertex_set should contain sanitized vertex name" + vertex_dbnames = [v.dbname for v in sanitized_schema.vertex_config.vertices] + assert "Package_vertex" in vertex_dbnames, ( + f"Expected 'package_vertex' in vertices after sanitization, got {vertex_dbnames}" ) - assert "version" not in sanitized_schema.vertex_config.vertex_set, ( - "VertexConfig.vertex_set should not contain original reserved word" + assert "package" not in vertex_dbnames, ( + f"Original reserved word 'package' should not be in vertices, got {vertex_dbnames}" ) - # Verify we can look up the vertex by sanitized name - version_vertex = sanitized_schema.vertex_config._vertices_map["version_vertex"] - assert version_vertex is not None, ( - "Should be able to look up vertex by sanitized name" - ) - assert version_vertex.name == "version_vertex", ( - f"Vertex name should be 'version_vertex', got {version_vertex.name}" - ) +def test_edges_sanitization_for_tigergraph(schema_with_incompatible_edges): + """Test that vertex names with reserved words are sanitized for TigerGraph.""" + schema = schema_with_incompatible_edges -def test_edge_finish_init_after_sanitization(schema_with_reserved_words): - """Test that edges are properly re-initialized after vertex name sanitization.""" - schema = schema_with_reserved_words - + # Create inferencer with TigerGraph flavor inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Find edge connecting version_vertex to users - edges = list(sanitized_schema.edge_config.edges_list()) - version_users_edge = next( - ( - e - for e in 
edges - if "version_vertex" in (e.source, e.target) - and "users" in (e.source, e.target) - ), - None, - ) - - assert version_users_edge is not None, "Edge should exist" - # After finish_init (called in _sanitize_schema_attributes), _source and _target should be set correctly - assert version_users_edge._source is not None, ( - "Edge._source should be set after finish_init" - ) - assert version_users_edge._target is not None, ( - "Edge._target should be set after finish_init" - ) - - # The _source and _target should reference the sanitized vertex name (via dbname lookup) - # Since dbname defaults to name, they should be the sanitized names - assert "version_vertex" in ( - version_users_edge._source, - version_users_edge._target, - ), ( - f"Edge internal references should use sanitized vertex name, " - f"got _source={version_users_edge._source}, _target={version_users_edge._target}" - ) - - -def test_tigergraph_schema_validation_with_reserved_words(schema_with_reserved_words): - """Test that sanitized schema has no reserved words.""" - schema = schema_with_reserved_words - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + # Sanitize the schema sanitized_schema = inferencer._sanitize_schema_attributes(schema) - # Verify schema structure is valid - assert sanitized_schema is not None, "Schema should be sanitized" - assert sanitized_schema.vertex_config is not None, ( - "Schema should have vertex_config" - ) - assert sanitized_schema.edge_config is not None, "Schema should have edge_config" - - # Verify all vertex names are sanitized (no reserved words) - vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] - reserved_words = inferencer.reserved_words - for vertex_name in vertex_names: - assert vertex_name.upper() not in reserved_words, ( - f"Vertex name '{vertex_name}' should not be a reserved word" - ) - - # Verify all edge source/target names are sanitized - edges = list(sanitized_schema.edge_config.edges_list()) - for edge 
in edges: - assert edge.source.upper() not in reserved_words, ( - f"Edge source '{edge.source}' should not be a reserved word" - ) - assert edge.target.upper() not in reserved_words, ( - f"Edge target '{edge.target}' should not be a reserved word" - ) - - # Verify all attribute names are sanitized - for vertex in sanitized_schema.vertex_config.vertices: - for field in vertex.fields: - assert field.name.upper() not in reserved_words, ( - f"Field name '{field.name}' in vertex '{vertex.name}' should not be a reserved word" - ) + # sanitized_schema.to_yaml_file( + # os.path.join( + # os.path.dirname(__file__), + # "../../config/schema/tigergraph-sanitize-edges.corrected.yaml", + # ) + # ) - # Verify edge weight names are sanitized - for edge in edges: - if edge.weights and edge.weights.direct: - for weight_field in edge.weights.direct: - assert weight_field.name.upper() not in reserved_words, ( - f"Weight field name '{weight_field.name}' in edge " - f"'{edge.source}' -> '{edge.target}' should not be a reserved word" - ) - - logger.info("Schema validation passed - all names are sanitized for TigerGraph") - - -def test_indirect_edge_by_reference_sanitization(): - """Test that indirect edge 'by' references are sanitized.""" - # Create schema with indirect edge - version_vertex = Vertex( - name="version", - fields=[Field(name="id", type=FieldType.INT)], - ) + assert sanitized_schema.resources[-1].root.actor.descendants[0].actor.t.map == { + "container_name": "id" + } - users_vertex = Vertex( - name="users", - fields=[Field(name="id", type=FieldType.INT)], - ) - - vertex_config = VertexConfig( - vertices=[version_vertex, users_vertex], db_flavor=DBFlavor.TIGERGRAPH - ) - - # Create indirect edge with 'by' referencing reserved word vertex - indirect_edge = Edge( - source="users", - target="users", - by="version", # Will be sanitized - ) - - edge_config = EdgeConfig(edges=[indirect_edge]) - - schema = Schema( - general=SchemaMetadata(name="test_schema"), - 
vertex_config=vertex_config, - edge_config=edge_config, - resources=[], - ) - - inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) - sanitized_schema = inferencer._sanitize_schema_attributes(schema) - - # Find the indirect edge - edges = list(sanitized_schema.edge_config.edges_list()) - indirect_edge_sanitized = edges[0] - - # Check that 'by' reference was sanitized - assert indirect_edge_sanitized.by == "version_vertex", ( - f"Indirect edge 'by' should reference 'version_vertex', got {indirect_edge_sanitized.by}" - ) - assert indirect_edge_sanitized.by != "version", ( - f"Indirect edge 'by' should not reference original 'version', got {indirect_edge_sanitized.by}" - ) + assert sanitized_schema.vertex_config.vertices[-1].fields[0].name == "id" + assert sanitized_schema.vertex_config.vertices[-1].indexes[0].fields[0] == "id"