diff --git a/examples/3-ingest-csv-edge-weights/ingest.py b/examples/3-ingest-csv-edge-weights/ingest.py index 393a2bf..33f28e6 100644 --- a/examples/3-ingest-csv-edge-weights/ingest.py +++ b/examples/3-ingest-csv-edge-weights/ingest.py @@ -17,6 +17,7 @@ conn_conf = Neo4jConfig.from_docker_env() # from graflo.db.connection.onto import TigergraphConfig +# # conn_conf = TigergraphConfig.from_docker_env() # Alternative: Create config directly or use environment variables diff --git a/graflo/db/connection/onto.py b/graflo/db/connection/onto.py index deebd35..4d8053c 100644 --- a/graflo/db/connection/onto.py +++ b/graflo/db/connection/onto.py @@ -1,4 +1,6 @@ import abc +import logging +import warnings from pathlib import Path from strenum import StrEnum from typing import Any, Dict, Type, TypeVar @@ -10,6 +12,8 @@ from graflo.onto import MetaEnum +logger = logging.getLogger(__name__) + # Type variable for DBConfig subclasses T = TypeVar("T", bound="DBConfig") @@ -126,24 +130,174 @@ def effective_schema(self) -> str | None: return self._get_effective_schema() @model_validator(mode="after") - def _add_default_port_to_uri(self): - """Add default port to URI if missing.""" + def _normalize_uri(self): + """Normalize URI: handle URIs without scheme and add default port if missing.""" if self.uri is None: return self + # Valid URL schemes (common database protocols) + valid_schemes = { + "http", + "https", + "bolt", + "bolt+s", + "bolt+ssc", + "neo4j", + "neo4j+s", + "neo4j+ssc", + "mongodb", + "postgresql", + "postgres", + "mysql", + "nebula", + "redis", # FalkorDB uses redis:// protocol + "rediss", # Redis with SSL + } + + # Try to parse as-is first parsed = urlparse(self.uri) - if parsed.port is not None: + + # Check if parsed scheme is actually a valid scheme or if it's a hostname + # urlparse treats "localhost:14240" as scheme="localhost", path="14240" + # We need to detect this case + has_valid_scheme = parsed.scheme.lower() in valid_schemes + has_netloc = bool(parsed.netloc) + + # If scheme doesn't look like a valid scheme and we have a colon, treat as host:port + if not has_valid_scheme and ":" in self.uri and not self.uri.startswith("//"): + # Check if it looks like host:port format + parts = self.uri.split(":", 1) + if len(parts) == 2: + potential_host = parts[0] + port_and_rest = parts[1] + # Extract port (may have path/query after it) + port_part = port_and_rest.split("/")[0].split("?")[0].split("#")[0] + try: + # Validate port is numeric + int(port_part) + # If hostname doesn't look like a scheme (contains dots, is localhost, etc.) + # or if the parsed scheme is not in valid schemes, treat as host:port + if ( + "." 
in potential_host + or potential_host.lower() in {"localhost", "127.0.0.1"} + or not has_valid_scheme + ): + # Reconstruct as proper URI with default scheme + default_scheme = "http" # Default to http for most DBs + rest = port_and_rest[len(port_part) :] # Everything after port + self.uri = ( + f"{default_scheme}://{potential_host}:{port_part}{rest}" + ) + parsed = urlparse(self.uri) + except ValueError: + # Not a valid port, treat as regular URI - add scheme if needed + if not has_valid_scheme: + default_scheme = "http" + self.uri = f"{default_scheme}://{self.uri}" + parsed = urlparse(self.uri) + elif not has_valid_scheme and not has_netloc: + # No valid scheme and no netloc - add default scheme + default_scheme = "http" + self.uri = f"{default_scheme}://{self.uri}" + parsed = urlparse(self.uri) + + # Add default port if missing + if parsed.port is None: + default_port = self._get_default_port() + if parsed.scheme and parsed.hostname: + # Reconstruct URI with port + port_part = f":{default_port}" if default_port else "" + path_part = parsed.path or "" + query_part = f"?{parsed.query}" if parsed.query else "" + fragment_part = f"#{parsed.fragment}" if parsed.fragment else "" + self.uri = f"{parsed.scheme}://{parsed.hostname}{port_part}{path_part}{query_part}{fragment_part}" + + return self + + @model_validator(mode="after") + def _extract_port_from_uri(self): + """Extract port from URI and set it as gs_port for TigerGraph (if applicable). + + For TigerGraph 4+, gs_port is the primary port. If URI has a port but gs_port + is not set, automatically extract and set gs_port from URI port. + This simplifies configuration - users can just provide URI with port. + """ + # Only apply to configs that have gs_port field (TigerGraph) + if not hasattr(self, "gs_port"): + return self + + if self.uri and self.gs_port is None: + uri_port = self.port # Get port from URI (property from base class) + if uri_port: + try: + self.gs_port = int(uri_port) + logger.debug( + f"Automatically set gs_port={self.gs_port} from URI port" + ) + except (ValueError, TypeError): + # Port couldn't be converted to int, skip auto-setting + pass + + return self + + @model_validator(mode="after") + def _check_port_conflicts(self): + """Check for port conflicts between URI and separate port fields. + + If port is provided both in URI and as a separate field, warn and prefer URI port. + This ensures consistency and avoids confusion. 
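# NOTE (editor): illustrative sketch, not part of the patch. The _normalize_uri
# validator above works around a stdlib parsing quirk: without a scheme,
# urlparse() reads "host:port" as scheme + path. A minimal demonstration using
# only the standard library:
from urllib.parse import urlparse

bare = urlparse("localhost:14240")
assert bare.scheme == "localhost" and bare.path == "14240"
assert bare.hostname is None  # no netloc was recognized

# After the validator prepends a default scheme, parsing behaves as expected:
fixed = urlparse("http://localhost:14240")
assert fixed.hostname == "localhost" and fixed.port == 14240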
+ """ + if self.uri is None: + return self + + uri_port = self.port # Get port from URI + if uri_port is None: return self - # Add default port - default_port = self._get_default_port() - if parsed.scheme and parsed.hostname: - # Reconstruct URI with port - port_part = f":{default_port}" if default_port else "" - path_part = parsed.path or "" - query_part = f"?{parsed.query}" if parsed.query else "" - fragment_part = f"#{parsed.fragment}" if parsed.fragment else "" - self.uri = f"{parsed.scheme}://{parsed.hostname}{port_part}{path_part}{query_part}{fragment_part}" + # Check for port fields in subclasses + # Get model fields to check for port-related fields + port_fields = [] + + # Check for specific port fields that might exist in subclasses + # Use getattr with None default to avoid AttributeError + if hasattr(self, "gs_port"): + gs_port_val = getattr(self, "gs_port", None) + if gs_port_val is not None: + port_fields.append(("gs_port", gs_port_val)) + + if hasattr(self, "bolt_port"): + bolt_port_val = getattr(self, "bolt_port", None) + if bolt_port_val is not None: + port_fields.append(("bolt_port", bolt_port_val)) + + # Check each port field for conflicts + port_conflicts = [] + for field_name, field_port in port_fields: + # Compare as strings to handle int vs str differences + if str(field_port) != str(uri_port): + port_conflicts.append((field_name, field_port, uri_port)) + + # Warn about conflicts and prefer URI port + if port_conflicts: + conflict_msgs = [ + f"{field_name}={field_port} (URI has port={uri_port})" + for field_name, field_port, _ in port_conflicts + ] + warning_msg = ( + f"Port conflict detected: Port specified both in URI ({uri_port}) " + f"and as separate field(s): {', '.join(conflict_msgs)}. " + f"Using port from URI ({uri_port}). Consider removing the separate port field(s)." + ) + warnings.warn(warning_msg, UserWarning, stacklevel=2) + logger.warning(warning_msg) + + # Update port fields to match URI port (prefer URI) + for field_name, _, _ in port_conflicts: + try: + setattr(self, field_name, int(uri_port)) + except (ValueError, AttributeError): + # Field might be read-only or not settable, that's okay + pass return self @@ -546,14 +700,13 @@ class TigergraphConfig(DBConfig): TigerGraph 4.1+ uses port 14240 (GSQL server) as the primary interface. Port 9000 (REST++) is for internal use only in TG 4.1+. - For vanilla TigerGraph 4+ installations, you typically only need port 14240. - Both restppPort and gsPort default to 14240 for TG 4+ compatibility. + Standard ports: + - Port 14240: GSQL server (primary interface for all API requests) + - Port 9000: REST++ (internal-only in TG 4.1+) - For custom Docker deployments with port mapping, override the ports: - >>> config = TigergraphConfig( - ... uri="http://localhost:9001", # Custom mapped REST++ port - ... gs_port=14241, # Custom mapped GSQL port - ... ) + For custom Docker deployments with port mapping, ports are configured via + environment variables (e.g., TG_WEB, TG_REST) and loaded automatically + when using TigergraphConfig.from_docker_env(). 
""" model_config = SettingsConfigDict( @@ -562,7 +715,7 @@ class TigergraphConfig(DBConfig): ) gs_port: int | None = Field( - default=None, description="TigerGraph GSQL port (default: 14240 for TG 4+)" + default=None, description="TigerGraph GSQL port (standard: 14240 for TG 4+)" ) secret: str | None = Field( default=None, @@ -580,19 +733,27 @@ class TigergraphConfig(DBConfig): "for cases where certificate hostname doesn't match (e.g., internal deployments with self-signed certs). " "WARNING: Disabling SSL verification reduces security and should only be used in trusted environments.", ) + max_job_size: int = Field( + default=1000, + description="Maximum size (in characters) for a single SCHEMA_CHANGE JOB. " + "Large jobs (>30k chars) can cause parser failures. The schema change will be split " + "into multiple batches if the estimated size exceeds this limit. Default: 1000.", + ) def _get_default_port(self) -> int: """Get default TigerGraph REST++ port. Note: TigerGraph 4.1+ uses port 14240 (GSQL server) as the primary interface. Port 9000 (REST++) is for internal use only in TG 4.1+. - However, pyTigerGraph's connection object still needs this port configured - for backward compatibility with older TG versions. - For TigerGraph 4+, it's recommended to explicitly set both port and gs_port - to the publicly accessible GSQL port (typically 14240). + Standard ports: + - Port 14240: GSQL server (primary interface) + - Port 9000: REST++ (internal-only in TG 4.1+) + + This method is kept for backward compatibility but should not be relied upon. + Ports should be explicitly configured in TigergraphConfig. """ - return 14240 # Default to GSQL port for TG 4+ compatibility + return 14240 # Standard GSQL port for TG 4+ def _get_effective_database(self) -> str | None: """TigerGraph doesn't have a database level (connection -> schema -> vertices/edges).""" @@ -607,11 +768,21 @@ def _get_effective_schema(self) -> str | None: return self.schema_name def __init__(self, **data): - """Initialize TigerGraph config.""" + """Initialize TigerGraph config. + + Note: For TigerGraph 4+, gs_port is the primary port (14240). + If URI is provided with a port, it will be automatically set as gs_port + by the _extract_port_from_uri validator. + Standard ports: + - 14240: GSQL server (primary interface) + - 9000: REST++ (internal-only in TG 4.1+) + + If port is provided both in URI and as gs_port, the port from URI will be used + and a warning will be issued. 
+ """ super().__init__(**data) - # Set default gs_port if not provided - if self.gs_port is None: - self.gs_port = 14240 + # Port extraction from URI is handled by _extract_port_from_uri validator + # Port conflicts are handled by _check_port_conflicts validator in base class @classmethod def from_docker_env( @@ -640,22 +811,52 @@ def from_docker_env( # Map environment variables to config config_data: Dict[str, Any] = {} - if "TG_REST" in env_vars or "TIGERGRAPH_PORT" in env_vars: - port = env_vars.get("TG_REST") or env_vars.get("TIGERGRAPH_PORT") - hostname = env_vars.get("TIGERGRAPH_HOSTNAME", "localhost") - protocol = env_vars.get("TIGERGRAPH_PROTOCOL", "http") - config_data["uri"] = f"{protocol}://{hostname}:{port}" - if "TG_WEB" in env_vars or "TIGERGRAPH_GS_PORT" in env_vars: - gs_port = env_vars.get("TG_WEB") or env_vars.get("TIGERGRAPH_GS_PORT") - config_data["gs_port"] = int(gs_port) if gs_port else None + # For TigerGraph 4+, use GSQL port (TG_WEB) for both REST++ and GSQL + # TG_REST (port 9000) is internal-only in TG 4.1+ + gs_port = env_vars.get("TG_WEB") or env_vars.get("TIGERGRAPH_GS_PORT") + rest_port = env_vars.get("TG_REST") or env_vars.get("TIGERGRAPH_PORT") + + # Prefer GSQL port for TigerGraph 4+ compatibility + # Standard ports: 14240 (GSQL), 9000 (REST++) + # Docker may map these to different external ports (e.g., 14241, 9001) + if gs_port: + port = gs_port + config_data["gs_port"] = int(gs_port) + elif rest_port: + port = rest_port + # If only REST port is provided, use it for both (Docker mapping scenario) + config_data["gs_port"] = int(rest_port) + else: + raise ValueError( + "Either TG_WEB or TG_REST must be set in .env file. " + "Standard ports: 14240 (GSQL), 9000 (REST++)." + ) + hostname = env_vars.get("TIGERGRAPH_HOSTNAME", "localhost") + protocol = env_vars.get("TIGERGRAPH_PROTOCOL", "http") + config_data["uri"] = f"{protocol}://{hostname}:{port}" + + # Set default username if not provided if "TIGERGRAPH_USERNAME" in env_vars: config_data["username"] = env_vars["TIGERGRAPH_USERNAME"] + else: + config_data["username"] = "tigergraph" # Default username + + # Set password from env vars or use default if "TIGERGRAPH_PASSWORD" in env_vars or "GSQL_PASSWORD" in env_vars: config_data["password"] = env_vars.get( "TIGERGRAPH_PASSWORD" ) or env_vars.get("GSQL_PASSWORD") + else: + # Check environment variable as fallback, default to "tigergraph" + import os + + config_data["password"] = ( + os.environ.get("GSQL_PASSWORD") + or os.environ.get("TIGERGRAPH_PASSWORD") + or "tigergraph" + ) if "TIGERGRAPH_DATABASE" in env_vars: config_data["database"] = env_vars["TIGERGRAPH_DATABASE"] diff --git a/graflo/db/postgres/schema_inference.py b/graflo/db/postgres/schema_inference.py index 2064823..1c1d5fe 100644 --- a/graflo/db/postgres/schema_inference.py +++ b/graflo/db/postgres/schema_inference.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +from typing import TYPE_CHECKING from graflo.architecture.edge import Edge, EdgeConfig, WeightConfig from graflo.architecture.onto import Index, IndexType @@ -15,9 +16,13 @@ from graflo.onto import DBFlavor from ...architecture.onto_sql import EdgeTableInfo, SchemaIntrospectionResult +from ..util import load_reserved_words, sanitize_attribute_name from .conn import PostgresConnection from .types import PostgresTypeMapper +if TYPE_CHECKING: + from graflo.architecture.resource import Resource + logger = logging.getLogger(__name__) @@ -42,6 +47,8 @@ def __init__( self.db_flavor = db_flavor self.type_mapper = 
PostgresTypeMapper() self.conn = conn + # Load reserved words for the target database flavor + self.reserved_words = load_reserved_words(db_flavor) def infer_vertex_config( self, introspection_result: SchemaIntrospectionResult @@ -324,6 +331,263 @@ def infer_edge_config( return EdgeConfig(edges=edges) + def _sanitize_schema_attributes(self, schema: Schema) -> Schema: + """Sanitize attribute names and vertex names in the schema to avoid reserved words. + + This method modifies: + - Field names in vertices and edges + - Vertex names themselves + - Edge source/target/by references to vertices + - Resource apply lists that reference vertices + + The sanitization is deterministic: the same input always produces the same output. + + Args: + schema: The schema to sanitize + + Returns: + Schema with sanitized attribute names and vertex names + """ + if not self.reserved_words: + # No reserved words to check, return schema as-is + return schema + + # Track name mappings for attributes (fields/weights) + attribute_mappings: dict[str, str] = {} + # Track name mappings for vertex names (separate from attributes) + vertex_mappings: dict[str, str] = {} + + # First pass: Sanitize vertex names + for vertex in schema.vertex_config.vertices: + original_vertex_name = vertex.name + if original_vertex_name not in vertex_mappings: + sanitized_vertex_name = sanitize_attribute_name( + original_vertex_name, self.reserved_words, suffix="_vertex" + ) + if sanitized_vertex_name != original_vertex_name: + vertex_mappings[original_vertex_name] = sanitized_vertex_name + logger.debug( + f"Sanitizing vertex name '{original_vertex_name}' -> '{sanitized_vertex_name}'" + ) + else: + vertex_mappings[original_vertex_name] = original_vertex_name + else: + sanitized_vertex_name = vertex_mappings[original_vertex_name] + + # Update vertex name if it changed + if sanitized_vertex_name != original_vertex_name: + vertex.name = sanitized_vertex_name + # Also update dbname if it matches the original name (default behavior) + if vertex.dbname == original_vertex_name or vertex.dbname is None: + vertex.dbname = sanitized_vertex_name + + # Rebuild VertexConfig's internal _vertices_map after renaming vertices + schema.vertex_config._vertices_map = { + vertex.name: vertex for vertex in schema.vertex_config.vertices + } + + # Update blank_vertices references if they were sanitized + schema.vertex_config.blank_vertices = [ + vertex_mappings.get(v, v) for v in schema.vertex_config.blank_vertices + ] + + # Update force_types keys if they were sanitized + schema.vertex_config.force_types = { + vertex_mappings.get(k, k): v + for k, v in schema.vertex_config.force_types.items() + } + + # Second pass: Sanitize vertex field names + for vertex in schema.vertex_config.vertices: + for field in vertex.fields: + original_name = field.name + if original_name not in attribute_mappings: + sanitized_name = sanitize_attribute_name( + original_name, self.reserved_words + ) + if sanitized_name != original_name: + attribute_mappings[original_name] = sanitized_name + logger.debug( + f"Sanitizing field name '{original_name}' -> '{sanitized_name}' " + f"in vertex '{vertex.name}'" + ) + else: + attribute_mappings[original_name] = original_name + else: + sanitized_name = attribute_mappings[original_name] + + # Update field name if it changed + if sanitized_name != original_name: + field.name = sanitized_name + + # Update index field references if they were sanitized + for index in vertex.indexes: + updated_fields = [] + for field_name in index.fields: + 
sanitized_field_name = attribute_mappings.get( + field_name, field_name + ) + updated_fields.append(sanitized_field_name) + index.fields = updated_fields + + # Third pass: Update edge references to sanitized vertex names + for edge in schema.edge_config.edges: + # Update source vertex reference + if edge.source in vertex_mappings: + edge.source = vertex_mappings[edge.source] + logger.debug( + f"Updated edge source reference '{edge.source}' (sanitized vertex name)" + ) + + # Update target vertex reference + if edge.target in vertex_mappings: + edge.target = vertex_mappings[edge.target] + logger.debug( + f"Updated edge target reference '{edge.target}' (sanitized vertex name)" + ) + + # Update 'by' vertex reference for indirect edges + # Note: edge.by might be a vertex name or a dbname (if finish_init was already called) + # We check both the direct mapping and reverse lookup via dbname + if edge.by is not None: + if edge.by in vertex_mappings: + # edge.by is a vertex name that needs sanitization + edge.by = vertex_mappings[edge.by] + logger.debug( + f"Updated edge 'by' reference to '{edge.by}' (sanitized vertex name)" + ) + else: + # edge.by might be a dbname - try to find the vertex that has this dbname + # and check if its name was sanitized + try: + vertex = schema.vertex_config._get_vertex_by_name_or_dbname( + edge.by + ) + vertex_name = vertex.name + if vertex_name in vertex_mappings: + # This vertex was sanitized, update edge.by to use sanitized name + # (finish_init will convert it back to dbname) + edge.by = vertex_mappings[vertex_name] + logger.debug( + f"Updated edge 'by' reference from dbname '{edge.by}' " + f"to sanitized vertex name '{vertex_mappings[vertex_name]}'" + ) + except (KeyError, AttributeError): + # edge.by is neither a vertex name nor a dbname we recognize + # This shouldn't happen in normal operation, but we'll skip it + pass + + # Sanitize edge weight field names + if edge.weights and edge.weights.direct: + for weight_field in edge.weights.direct: + original_name = weight_field.name + if original_name not in attribute_mappings: + sanitized_name = sanitize_attribute_name( + original_name, self.reserved_words + ) + if sanitized_name != original_name: + attribute_mappings[original_name] = sanitized_name + logger.debug( + f"Sanitizing weight field name '{original_name}' -> " + f"'{sanitized_name}' in edge '{edge.source}' -> '{edge.target}'" + ) + else: + attribute_mappings[original_name] = original_name + else: + sanitized_name = attribute_mappings[original_name] + + # Update weight field name if it changed + if sanitized_name != original_name: + weight_field.name = sanitized_name + + # Fourth pass: Re-initialize edges after vertex name sanitization + # This ensures edge._source, edge._target, and edge.by are correctly set + # with the sanitized vertex names + schema.edge_config.finish_init(schema.vertex_config) + + # Fifth pass: Update resource apply lists that reference vertices + for resource in schema.resources: + self._sanitize_resource_vertex_references(resource, vertex_mappings) + + return schema + + def _sanitize_resource_vertex_references( + self, resource: Resource, vertex_mappings: dict[str, str] + ) -> None: + """Sanitize vertex name references in a resource's apply list. 
+ + Resources can reference vertices in their apply list through: + - {"vertex": vertex_name} for VertexActor + - {"target_vertex": vertex_name, ...} for mapping actors + - {"source": vertex_name, "target": vertex_name} for EdgeActor + - Nested structures in tree_likes resources + + Args: + resource: The resource to sanitize + vertex_mappings: Dictionary mapping original vertex names to sanitized names + """ + if not hasattr(resource, "apply") or not resource.apply: + return + + def sanitize_apply_item(item): + """Recursively sanitize vertex references in apply items.""" + if isinstance(item, dict): + # Handle vertex references in dictionaries + sanitized_item = {} + for key, value in item.items(): + if key == "vertex" and isinstance(value, str): + # {"vertex": vertex_name} + sanitized_item[key] = vertex_mappings.get(value, value) + if value != sanitized_item[key]: + logger.debug( + f"Updated resource '{resource.resource_name}' apply item: " + f"'{key}': '{value}' -> '{sanitized_item[key]}'" + ) + elif key == "target_vertex" and isinstance(value, str): + # {"target_vertex": vertex_name, ...} + sanitized_item[key] = vertex_mappings.get(value, value) + if value != sanitized_item[key]: + logger.debug( + f"Updated resource '{resource.resource_name}' apply item: " + f"'{key}': '{value}' -> '{sanitized_item[key]}'" + ) + elif key in ("source", "target") and isinstance(value, str): + # {"source": vertex_name, "target": vertex_name} for EdgeActor + sanitized_item[key] = vertex_mappings.get(value, value) + if value != sanitized_item[key]: + logger.debug( + f"Updated resource '{resource.resource_name}' apply item: " + f"'{key}': '{value}' -> '{sanitized_item[key]}'" + ) + elif key == "name" and isinstance(value, str): + # Keep transform names as-is + sanitized_item[key] = value + elif key == "children" and isinstance(value, list): + # Recursively sanitize children in tree_likes resources + sanitized_item[key] = [ + sanitize_apply_item(child) for child in value + ] + elif isinstance(value, dict): + # Recursively sanitize nested dictionaries + sanitized_item[key] = sanitize_apply_item(value) + elif isinstance(value, list): + # Recursively sanitize lists + sanitized_item[key] = [ + sanitize_apply_item(subitem) for subitem in value + ] + else: + sanitized_item[key] = value + return sanitized_item + elif isinstance(item, list): + # Recursively sanitize lists + return [sanitize_apply_item(subitem) for subitem in item] + else: + # Return non-dict/list items as-is + return item + + # Sanitize the entire apply list + resource.apply = [sanitize_apply_item(item) for item in resource.apply] + def infer_schema( self, introspection_result: SchemaIntrospectionResult, @@ -355,7 +619,7 @@ def infer_schema( # Create schema metadata metadata = SchemaMetadata(name=schema_name) - # Create schema (resources will be added separately) + # Create schema (resources will be created separately) schema = Schema( general=metadata, vertex_config=vertex_config, @@ -363,6 +627,9 @@ def infer_schema( resources=[], # Resources will be created separately ) + # Sanitize attribute names to avoid reserved words + schema = self._sanitize_schema_attributes(schema) + logger.info( f"Successfully inferred schema '{schema_name}' with " f"{len(vertex_config.vertices)} vertices and " diff --git a/graflo/db/tigergraph/conn.py b/graflo/db/tigergraph/conn.py index 4ac65dd..a37d45e 100644 --- a/graflo/db/tigergraph/conn.py +++ b/graflo/db/tigergraph/conn.py @@ -26,13 +26,14 @@ import contextlib import json import logging -from typing import Any, 
cast +from pathlib import Path +from typing import Any import requests from requests import exceptions as requests_exceptions -from pyTigerGraph import TigerGraphConnection as PyTigerGraphConnection +# Removed pyTigerGraph dependency - using direct REST API calls instead from graflo.architecture.edge import Edge @@ -90,13 +91,7 @@ def _patch_exception_class(cls: type[Exception]) -> None: except (ImportError, AttributeError): pass -# Patch TigerGraphException -try: - from pyTigerGraph import TigerGraphException - - _patch_exception_class(TigerGraphException) -except (ImportError, AttributeError): - pass +# Removed pyTigerGraph dependency - no longer need TigerGraphException patching def _wrap_tg_exception(func): @@ -116,6 +111,73 @@ def wrapper(*args, **kwargs): return wrapper +def _validate_tigergraph_schema_name(name: str, name_type: str) -> None: + """ + Validate a TigerGraph schema name (graph, vertex, or edge) against reserved words + and invalid characters. + + Args: + name: The schema name to validate + name_type: Type of schema name ("graph", "vertex", or "edge") + + Raises: + ValueError: If the name contains reserved words, forbidden prefixes, or invalid characters + """ + if not name: + raise ValueError(f"{name_type.capitalize()} name cannot be empty") + + # Load reserved words from JSON file + json_path = Path(__file__).parent / "reserved_words.json" + try: + with open(json_path, "r") as f: + reserved_data = json.load(f) + except FileNotFoundError: + logger.warning( + f"Could not find reserved_words.json at {json_path}, skipping validation" + ) + return + except json.JSONDecodeError as e: + logger.warning(f"Could not parse reserved_words.json: {e}, skipping validation") + return + + reserved_words = set() + reserved_words.update( + reserved_data.get("reserved_words", {}).get("gsql_keywords", []) + ) + reserved_words.update( + reserved_data.get("reserved_words", {}).get("cpp_keywords", []) + ) + + # Check for reserved words (case-insensitive) + name_upper = name.upper() + if name_upper in reserved_words: + raise ValueError( + f"{name_type.capitalize()} name '{name}' is a TigerGraph reserved word. " + f"Reserved words cannot be used as identifiers. " + f"Please choose a different name." + ) + + # Check for forbidden prefixes + forbidden_prefixes = reserved_data.get("forbidden_prefixes", []) + for prefix in forbidden_prefixes: + if name.startswith(prefix): + raise ValueError( + f"{name_type.capitalize()} name '{name}' starts with forbidden prefix '{prefix}'. " + f"Please choose a different name." + ) + + # Check for invalid characters + invalid_chars = reserved_data.get("invalid_characters", {}).get("characters", []) + found_chars = [char for char in invalid_chars if char in name] + if found_chars: + raise ValueError( + f"{name_type.capitalize()} name '{name}' contains invalid characters: {found_chars}. " + f"TigerGraph identifiers should use alphanumeric characters and underscores only. " + f"Special characters (especially hyphens and dots) are problematic for REST API endpoints. " + f"Please choose a different name." + ) + + class TigerGraphConnection(Connection): """ TigerGraph database connection implementation. @@ -136,7 +198,7 @@ class TigerGraphConnection(Connection): approach for TigerGraph 4+. The connection will: 1. Use username/password for initial connection 2. Generate a token from the secret - 3. Use the token for both GSQL operations (via pyTigerGraph) and REST API calls + 3. 
Use the token for both GSQL operations (via REST API) and REST API calls Example: >>> config = TigergraphConfig( @@ -152,17 +214,16 @@ class TigerGraphConnection(Connection): TigerGraph 4.1+ uses port 14240 (GSQL server) as the primary interface. Port 9000 (REST++) is for internal use only in TG 4.1+. - Default behavior: Both restppPort and gsPort default to 14240 for TG 4+ compatibility. + Standard ports: + - Port 14240: GSQL server (primary interface for all API requests) + - Port 9000: REST++ (internal-only in TG 4.1+) - For custom Docker deployments with port mapping, explicitly set both ports: - >>> config = TigergraphConfig( - ... uri="http://localhost:9001", # Custom REST++ port - ... gs_port=14241, # Custom GSQL port - ... ) + For custom Docker deployments with port mapping, ports are configured via + environment variables (e.g., TG_WEB, TG_REST) and loaded automatically + when using TigergraphConfig.from_docker_env(). Version Compatibility: - - TigerGraph 4.2.2+: Direct REST API endpoints (no /restpp prefix) - - TigerGraph 4.2.1 and older: REST API with /restpp prefix + - All TigerGraph versions use /restpp prefix for REST++ endpoints - Version is auto-detected, or can be manually specified in config """ @@ -173,83 +234,36 @@ def __init__(self, config: TigergraphConfig): self.config = config self.ssl_verify = getattr(config, "ssl_verify", True) - # Initialize pyTigerGraph connection for most operations - # Use type narrowing to help type checker understand non-None values - # For TigerGraph 4+, both ports typically route through the GSQL server (14240) + # Store connection configuration (no longer using pyTigerGraph) + # For TigerGraph 4+, both ports typically route through the GSQL server # Port 9000 (REST++) is internal-only in TG 4.1+ - restpp_port: int | str = config.port if config.port is not None else "14240" - gs_port: int | str = config.gs_port if config.gs_port is not None else "14240" - graphname: str = ( + self.graphname: str = ( config.database if config.database is not None else "DefaultGraph" ) - username: str = config.username if config.username is not None else "tigergraph" - password: str = config.password if config.password is not None else "tigergraph" - cert_path: str | None = getattr(config, "certPath", None) - - # Build connection kwargs, only include certPath if it's not None - conn_kwargs: dict[str, Any] = { - "host": config.url_without_port, - "restppPort": restpp_port, - "gsPort": gs_port, - "graphname": graphname, - "username": username, - "password": password, - } - if cert_path is not None: - conn_kwargs["certPath"] = cert_path - self.conn = PyTigerGraphConnection(**conn_kwargs) + # Initialize URLs (ports come from config, no hardcoded defaults) + # Set GSQL URL first as it's needed for token generation + # For TigerGraph 4+, gs_port is the primary port (extracted from URI if not explicitly set) + # Fall back to port from URI if gs_port is not set + gs_port: int | str | None = config.gs_port + if gs_port is None: + # Try to get port from URI + uri_port = config.port + if uri_port: + try: + gs_port = int(uri_port) + logger.debug(f"Using port {gs_port} from URI for GSQL endpoint") + except (ValueError, TypeError): + pass - # Get authentication token if secret is provided - # Token-based auth is the recommended approach for TigerGraph 4+ - # IMPORTANT: You should provide BOTH username/password AND secret: - # - username/password: Used for initial connection and GSQL operations - # - secret: Generates token that works for both GSQL and REST API operations 
- self.api_token: str | None = None - if config.secret: - try: - # Explicitly set setToken=True for TigerGraph 4.2.1+ compatibility - # This ensures the token is set on the connection object before any operations - token = self.conn.getToken(config.secret, setToken=True) - # getToken returns tuple (token, expiration) or just token - if isinstance(token, tuple): - self.api_token = token[0] - logger.info( - f"Successfully obtained API token (expires: {token[1]})" - ) - else: - self.api_token = token - logger.info("Successfully obtained API token") - # Explicitly set token on connection object for TigerGraph 4.2.1 compatibility - # This ensures pyTigerGraph internal calls (including GSQL) use the same token - if self.api_token: - # Set token via all available methods to ensure compatibility - if hasattr(self.conn, "apiToken"): - self.conn.apiToken = self.api_token # type: ignore[attr-defined] - if hasattr(self.conn, "token"): - self.conn.token = self.api_token # type: ignore[attr-defined] - if hasattr(self.conn, "setToken"): - self.conn.setToken(self.api_token) # type: ignore[attr-defined] - # Verify token is set (for debugging) - if hasattr(self.conn, "apiToken"): - actual_token = getattr(self.conn, "apiToken", None) - if actual_token != self.api_token: - logger.warning( - "Token mismatch detected. Expected token set, " - "but connection has different token." - ) - else: - logger.debug("Token successfully set on connection object") - except Exception as e: - # Log and fall back to username/password authentication - logger.warning(f"Failed to get authentication token: {e}") - logger.warning("Falling back to username/password authentication") - logger.warning( - "Note: For best results, provide both username/password AND secret. " - "Username/password is used for GSQL operations, secret generates token for REST API." - ) + if gs_port is None: + raise ValueError( + "gs_port or URI with port must be set in TigergraphConfig. " + "Standard ports: 14240 (GSQL), 9000 (REST++)." + ) + self.gsql_url = f"{config.url_without_port}:{gs_port}" - # Detect TigerGraph version for compatibility + # Detect TigerGraph version for compatibility (needed before token generation) self.tg_version: str | None = None self._use_restpp_prefix = False # Default for 4.2.2+ @@ -258,35 +272,9 @@ def __init__(self, config: TigergraphConfig): version_str = config.version logger.info(f"Using manually configured TigerGraph version: {version_str}") else: - # Auto-detect version + # Auto-detect version using REST API try: - version_info = self.conn.getVersion() - # getVersion() can return different formats: - # - list: [{"version": "release_4.2.2_..."}, ...] 
- # - dict: {"version": "release_4.2.2_..."} or {"api": [{"version": "4.2.2"}]} - # - str: "4.2.2" - version_str = None - - if isinstance(version_info, list) and len(version_info) > 0: - first_item = version_info[0] - if isinstance(first_item, dict) and "version" in first_item: - version_str = str(first_item["version"]) - else: - version_str = str(first_item) - elif isinstance(version_info, dict): - # Try different dict structures - if "version" in version_info: - version_str = str(version_info["version"]) - elif "api" in version_info and isinstance( - version_info["api"], list - ): - if ( - len(version_info["api"]) > 0 - and "version" in version_info["api"][0] - ): - version_str = str(version_info["api"][0]["version"]) - elif isinstance(version_info, str): - version_str = version_info + version_str = self._get_version() except Exception as e: logger.warning( f"Failed to detect TigerGraph version: {e}. " @@ -306,58 +294,848 @@ def __init__(self, config: TigergraphConfig): patch = int(version_match.group(3)) self.tg_version = f"{major}.{minor}.{patch}" - # Version 4.2.1 and older need /restpp prefix - if (major, minor, patch) < (4, 2, 2): - self._use_restpp_prefix = True + # All TigerGraph versions use /restpp prefix for REST++ endpoints + # Even 4.2.2+ requires /restpp prefix (despite some documentation suggesting otherwise) + self._use_restpp_prefix = True + logger.info( + f"TigerGraph version {self.tg_version} detected, " + f"using /restpp prefix for REST API" + ) + else: + logger.warning( + f"Could not extract version number from '{version_str}'. " + f"Defaulting to using /restpp prefix for REST API" + ) + self._use_restpp_prefix = True + + # Store base URLs for REST++ and GSQL endpoints + # For TigerGraph 4.1+, REST++ endpoints use the GSQL port with /restpp prefix + # Port 9000 is internal-only in TG 4.1+, so we use the same port as GSQL + # Use the GSQL port we already determined to ensure consistency + base_url = f"{config.url_without_port}:{gs_port}" + # Always use /restpp prefix for REST++ endpoints (required for all TG versions) + self.restpp_url = f"{base_url}/restpp" + + # Get authentication token if secret is provided + # Token-based auth is the recommended approach for TigerGraph 4+ + # IMPORTANT: You should provide BOTH username/password AND secret: + # - username/password: Used for initial connection and GSQL operations + # - secret: Generates token that works for both GSQL and REST API operations + # Use graph-specific token (is_global=False) for better security + self.api_token: str | None = None + if config.secret: + try: + token, expiration = self._get_token_from_secret( + config.secret, + self.graphname, # Pass graph name for graph-specific token + ) + self.api_token = token + if expiration: logger.info( - f"TigerGraph version {self.tg_version} detected, " - f"using /restpp prefix for REST API" + f"Successfully obtained API token for graph '{self.graphname}' " + f"(expires: {expiration})" ) else: logger.info( - f"TigerGraph version {self.tg_version} detected, " - f"using direct REST API endpoints" + f"Successfully obtained API token for graph '{self.graphname}'" ) - else: + except Exception as e: + # Log and fall back to username/password authentication + logger.warning(f"Failed to get authentication token: {e}") + logger.warning("Falling back to username/password authentication") logger.warning( - f"Could not extract version number from '{version_str}'. 
" - f"Defaulting to 4.2.2+ behavior (no /restpp prefix)" + "Note: For best results, provide both username/password AND secret. " + "Username/password is used for GSQL operations, secret generates token for REST API." ) - # Store base URLs for REST++ and GSQL endpoints - # For version 4.2.1 and older, include /restpp in the path - base_url = f"{config.url_without_port}:{config.port}" - if self._use_restpp_prefix: - self.restpp_url = f"{base_url}/restpp" - else: - self.restpp_url = base_url - self.gsql_url = f"{config.url_without_port}:{config.gs_port}" - - def _get_auth_headers(self) -> dict[str, str]: + def _get_auth_headers(self, use_basic_auth: bool = False) -> dict[str, str]: """Get authentication headers for REST API calls. - Prioritizes token-based authentication over Basic Auth: + Args: + use_basic_auth: If True, always use Basic Auth (required for GSQL endpoints). + If False, prioritize token-based auth for REST++ endpoints. + + Prioritizes token-based authentication over Basic Auth for REST++ endpoints: 1. If API token is available (from secret), use Bearer token (recommended for TG 4+) 2. Otherwise, fall back to HTTP Basic Auth with username/password + For GSQL endpoints, always use Basic Auth as they don't support Bearer tokens. + Returns: Dictionary with Authorization header """ headers = {} - # Prefer token-based authentication (recommended for TigerGraph 4+) - if self.api_token: - headers["Authorization"] = f"Bearer {self.api_token}" - elif self.config.username and self.config.password: - # Fallback to HTTP Basic Auth - import base64 + # GSQL endpoints require Basic Auth, not Bearer tokens + if use_basic_auth or not self.api_token: + # Use default username "tigergraph" if username is None but password is set + username = self.config.username if self.config.username else "tigergraph" + password = self.config.password - credentials = f"{self.config.username}:{self.config.password}" - encoded_credentials = base64.b64encode(credentials.encode()).decode() - headers["Authorization"] = f"Basic {encoded_credentials}" + if password: + import base64 + + credentials = f"{username}:{password}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + else: + logger.warning( + f"No password configured for Basic Auth. " + f"Username: {username}, Password: {password}" + ) + else: + # Use Bearer token for REST++ endpoints + headers["Authorization"] = f"Bearer {self.api_token}" return headers + def _get_token_from_secret( + self, secret: str, graph_name: str | None = None, lifetime: int = 3600 * 24 * 30 + ) -> tuple[str, str | None]: + """ + Generate authentication token from secret using TigerGraph REST API. + + Implements robust token generation with fallback logic for different TG 4.x versions: + - TigerGraph 4.2.2+: POST /gsql/v1/tokens (lifetime in milliseconds) + - TigerGraph 4.0-4.2.1: POST /gsql/v1/auth/token (lifetime in seconds) + + Based on pyTigerGraph's token generation mechanism with version-specific endpoint handling. 
+ + Args: + secret: Secret string created via CREATE SECRET in GSQL + graph_name: Name of the graph (None for global token) + lifetime: Token lifetime in seconds (default: 30 days) + + Returns: + Tuple of (token, expiration_timestamp) or (token, None) if expiration not provided + + Raises: + RuntimeError: If token generation fails after all retry attempts + """ + auth_headers = self._get_auth_headers(use_basic_auth=True) + headers = { + "Content-Type": "application/json", + **auth_headers, + } + + # Determine which endpoint to try based on version + # For TG 4.2.2+, use /gsql/v1/tokens (lifetime in milliseconds) + # For TG 4.0-4.2.1, use /gsql/v1/auth/token (lifetime in seconds) + use_new_endpoint = False + if self.tg_version: + import re + + version_match = re.search(r"(\d+)\.(\d+)\.(\d+)", self.tg_version) + if version_match: + major = int(version_match.group(1)) + minor = int(version_match.group(2)) + patch = int(version_match.group(3)) + # Use new endpoint for 4.2.2+ + use_new_endpoint = (major, minor, patch) >= (4, 2, 2) + + # Try endpoints in order: new endpoint first (if version >= 4.2.2), then fallback + endpoints_to_try = [] + if use_new_endpoint: + # Try new endpoint first for 4.2.2+ + endpoints_to_try.append( + ( + f"{self.gsql_url}/gsql/v1/tokens", + { + "secret": secret, + "graph": graph_name, + "lifetime": lifetime * 1000, # Convert to milliseconds + }, + True, # lifetime in milliseconds + ) + ) + # Fallback to old endpoint if new one fails + endpoints_to_try.append( + ( + f"{self.gsql_url}/gsql/v1/auth/token", + { + "secret": secret, + "graph": graph_name, + "lifetime": lifetime, # In seconds + }, + False, # lifetime in seconds + ) + ) + else: + # For older versions or unknown version, try old endpoint first + endpoints_to_try.append( + ( + f"{self.gsql_url}/gsql/v1/auth/token", + { + "secret": secret, + "graph": graph_name, + "lifetime": lifetime, # In seconds + }, + False, # lifetime in seconds + ) + ) + # Fallback to new endpoint (in case version detection was wrong) + endpoints_to_try.append( + ( + f"{self.gsql_url}/gsql/v1/tokens", + { + "secret": secret, + "graph": graph_name, + "lifetime": lifetime * 1000, # Convert to milliseconds + }, + True, # lifetime in milliseconds + ) + ) + + last_error: Exception | None = None + all_404_errors = True # Track if all failures were 404 errors + + for url, payload, _is_milliseconds in endpoints_to_try: + try: + # Remove None values from payload + clean_payload = {k: v for k, v in payload.items() if v is not None} + + response = requests.post( + url, + headers=headers, + json=clean_payload, # Use json parameter instead of data + timeout=30, + verify=self.ssl_verify, + ) + + # Check for 404 - might indicate wrong endpoint or port issue + if response.status_code == 404: + # Try port fallback (similar to pyTigerGraph's _req method) + # If using wrong port, try GSQL port + if ( + "/gsql" in url + and self.config.port is not None + and self.config.gs_port is not None + and self.config.port != self.config.gs_port + ): + logger.debug(f"404 on {url}, trying GSQL port fallback...") + # Replace port in URL with GSQL port + fallback_url = url.replace( + f":{self.config.port}", f":{self.config.gs_port}" + ) + try: + response = requests.post( + fallback_url, + headers=headers, + json=clean_payload, + timeout=30, + verify=self.ssl_verify, + ) + if response.status_code == 200: + url = fallback_url # Update URL for logging + except Exception: + pass # Continue to next endpoint + + response.raise_for_status() + result = response.json() + + # Parse 
response (both endpoints return similar format) + # Format: {"token": "...", "expiration": "...", "error": false, "message": "..."} + # or {"token": "..."} for older versions + if result.get("error") is True: + error_msg = result.get("message", "Unknown error") + raise RuntimeError(f"Token generation failed: {error_msg}") + + token = result.get("token") + expiration = result.get("expiration") + + if token: + logger.debug( + f"Successfully obtained token from {url} " + f"(expiration: {expiration or 'not provided'})" + ) + return (token, expiration) + else: + raise ValueError(f"No token in response: {result}") + + except requests.exceptions.HTTPError as e: + # Track if this was a 404 error + if e.response.status_code != 404: + all_404_errors = False + + # If 404 and we have more endpoints to try, continue + if e.response.status_code == 404 and len(endpoints_to_try) > 1: + logger.debug( + f"Endpoint {url} returned 404, trying next endpoint..." + ) + last_error = e + continue + # For other HTTP errors, log and try next endpoint if available + logger.debug( + f"HTTP error {e.response.status_code} on {url}: {e.response.text}" + ) + last_error = e + continue + except Exception as e: + all_404_errors = False # Non-HTTP errors are not 404s + logger.debug(f"Error trying {url}: {e}") + last_error = e + continue + + # All graph-specific endpoints failed + # If all failures were 404 errors and we have a graph_name, try generating a global token + # This handles cases where the graph doesn't exist yet (e.g., "DefaultGraph" at init time) + # For TigerGraph 4.2.1, /gsql/v1/tokens requires the graph to exist, but /gsql/v1/auth/token + # can generate a global token without a graph parameter + if all_404_errors and graph_name is not None and last_error: + logger.debug( + f"All graph-specific token attempts failed with 404. " + f"Graph '{graph_name}' may not exist yet. " + f"Trying to generate a global token (without graph parameter)..." + ) + + # Try generating a global token using /gsql/v1/auth/token (works for TG 4.0-4.2.1) + global_token_endpoints = [ + ( + f"{self.gsql_url}/gsql/v1/auth/token", + { + "secret": secret, + "lifetime": lifetime, # In seconds + # No graph parameter = global token + }, + False, # lifetime in seconds + ) + ] + + # Also try /gsql/v1/tokens without graph parameter (for TG 4.2.2+) + global_token_endpoints.append( + ( + f"{self.gsql_url}/gsql/v1/tokens", + { + "secret": secret, + "lifetime": lifetime * 1000, # In milliseconds + # No graph parameter = global token + }, + True, # lifetime in milliseconds + ) + ) + + for url, payload, _is_milliseconds in global_token_endpoints: + try: + clean_payload = {k: v for k, v in payload.items() if v is not None} + + response = requests.post( + url, + headers=headers, + json=clean_payload, + timeout=30, + verify=self.ssl_verify, + ) + + response.raise_for_status() + result = response.json() + + if result.get("error") is True: + error_msg = result.get("message", "Unknown error") + logger.debug(f"Global token generation failed: {error_msg}") + continue + + token = result.get("token") + expiration = result.get("expiration") + + if token: + logger.info( + f"Successfully obtained global token from {url} " + f"(graph '{graph_name}' may not exist yet, using global token). 
" + f"Expiration: {expiration or 'not provided'}" + ) + return (token, expiration) + + except Exception as e: + logger.debug(f"Error trying global token endpoint {url}: {e}") + continue + + # All endpoints failed (including global token fallback) + error_msg = f"Failed to get token from secret after trying {len(endpoints_to_try)} endpoint(s)" + if all_404_errors and graph_name: + error_msg += f" (all returned 404, graph '{graph_name}' may not exist yet)" + if last_error: + error_msg += f": {last_error}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + def _get_version(self) -> str | None: + """ + Get TigerGraph version using REST API. + + Tries multiple endpoints in order: + 1. GET /gsql/v1/version (GSQL server, port 14240) - primary for TG 4+ + 2. GET /version (REST++ server, port 9000) - fallback for older versions + + Note: The /version endpoint does NOT exist on GSQL port (14240). + It only exists on REST++ port (9000) for older versions. + + Returns: + Version string (e.g., "4.2.1") or None if detection fails + """ + import re + + if self.config.gs_port is None: + raise ValueError("gs_port must be set in config for version detection") + + # Try GSQL endpoint first (primary for TigerGraph 4+) + # Note: /gsql/v1/version exists on GSQL port, but /version does NOT + # Response format: plain text like "GSQL version: 4.2.2\n" + gsql_url = f"{self.gsql_url}/gsql/v1/version" + headers = self._get_auth_headers(use_basic_auth=True) + + try: + response = requests.get( + gsql_url, headers=headers, timeout=10, verify=self.ssl_verify + ) + response.raise_for_status() + + if not response.text.strip(): + # Empty response + logger.debug("GSQL version endpoint returned empty response") + raise ValueError("Empty response from GSQL version endpoint") + + # GSQL /gsql/v1/version returns plain text, not JSON + # Format: "GSQL version: 4.2.2\n" or similar + response_text = response.text.strip() + + # Try to parse version from text response + # Format: "GSQL version: 4.2.2" or "version: 4.2.2" or "4.2.2" + version_match = re.search( + r"version:\s*(\d+)\.(\d+)\.(\d+)", response_text, re.IGNORECASE + ) + if version_match: + version_str = f"{version_match.group(1)}.{version_match.group(2)}.{version_match.group(3)}" + logger.debug( + f"Detected TigerGraph version: {version_str} from GSQL endpoint (text format)" + ) + return version_str + + # Try alternative: just look for version number pattern + version_match = re.search(r"(\d+)\.(\d+)\.(\d+)", response_text) + if version_match: + version_str = f"{version_match.group(1)}.{version_match.group(2)}.{version_match.group(3)}" + logger.debug( + f"Detected TigerGraph version: {version_str} from GSQL endpoint (text format)" + ) + return version_str + + # If text parsing failed, try JSON as fallback (some versions might return JSON) + try: + result = response.json() + message = result.get("message", "") + if message: + version_match = re.search(r"release_(\d+)\.(\d+)\.(\d+)", message) + if version_match: + version_str = f"{version_match.group(1)}.{version_match.group(2)}.{version_match.group(3)}" + logger.debug( + f"Detected TigerGraph version: {version_str} from GSQL endpoint (JSON format)" + ) + return version_str + except ValueError: + # Not JSON, that's fine - we already tried text parsing + pass + + except Exception as e: + logger.debug(f"Failed to get version from GSQL endpoint: {e}") + + # Fallback: Try REST++ /version endpoint (for older versions or if GSQL endpoint fails) + # Note: /version only exists on REST++ port (9000), not GSQL port (14240) 
+ try: + # Use REST++ port if different from GSQL port + restpp_port = self.config.port if self.config.port else self.config.gs_port + if restpp_port is None: + return None + + restpp_url = f"{self.config.url_without_port}:{restpp_port}/version" + headers = self._get_auth_headers(use_basic_auth=True) + + response = requests.get( + restpp_url, headers=headers, timeout=10, verify=self.ssl_verify + ) + response.raise_for_status() + + # Check content type and response + if not response.text.strip(): + logger.debug("REST++ version endpoint returned empty response") + return None + + try: + result = response.json() + except ValueError: + logger.debug( + f"REST++ version endpoint returned non-JSON response: " + f"status={response.status_code}, text={response.text[:200]}" + ) + return None + + # Parse version from REST++ response + message = result.get("message", "") + if message: + version_match = re.search(r"release_(\d+)\.(\d+)\.(\d+)", message) + if version_match: + version_str = f"{version_match.group(1)}.{version_match.group(2)}.{version_match.group(3)}" + logger.debug( + f"Detected TigerGraph version: {version_str} from REST++ endpoint" + ) + return version_str + + except Exception as e: + logger.debug(f"Failed to get version from REST++ endpoint: {e}") + + return None + + def _execute_gsql(self, gsql_command: str) -> str: + """ + Execute GSQL command using REST API. + + For TigerGraph 4.0-4.2.1, uses POST /gsql/v1/statements endpoint. + + Note: GSQL endpoints require Basic Auth (username/password), not Bearer tokens. + + Args: + gsql_command: GSQL command string to execute + + Returns: + Response string from GSQL execution + """ + url = f"{self.gsql_url}/gsql/v1/statements" + auth_headers = self._get_auth_headers(use_basic_auth=True) + headers = { + "Content-Type": "text/plain", + **auth_headers, + } + + # Debug: Log if Authorization header is missing + if "Authorization" not in headers: + logger.error( + f"No Authorization header generated. " + f"Username: {self.config.username}, Password: {'***' if self.config.password else None}" + ) + + try: + response = requests.post( + url, + headers=headers, + data=gsql_command, + timeout=120, + verify=self.ssl_verify, + ) + response.raise_for_status() + + # Try to parse JSON response, fallback to text + try: + result = response.json() + # Extract message or result from JSON response + if isinstance(result, dict): + return result.get("message", str(result)) + return str(result) + except ValueError: + # Not JSON, return text + return response.text + except requests_exceptions.HTTPError as e: + error_msg = str(e) + # Try to extract error message from response + try: + error_details = e.response.json() if e.response else {} + error_msg = error_details.get("message", error_msg) + except Exception: + pass + raise RuntimeError(f"GSQL execution failed: {error_msg}") from e + + def _get_vertex_types(self, graph_name: str | None = None) -> list[str]: + """ + Get list of vertex types using GSQL. 
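# NOTE (editor): illustrative sketch, not part of the patch. _parse_show_output
# is a helper referenced here but not shown in this hunk; assuming SHOW VERTEX *
# emits listing lines such as "- VERTEX Person(PRIMARY_ID id STRING, ...)",
# a rough standalone equivalent would be:
import re

def parse_show_vertices(gsql_output: str) -> list[str]:
    # Capture the identifier following "VERTEX" on each listing line.
    return re.findall(r"-\s*VERTEX\s+(\w+)\s*\(", gsql_output)

sample = "- VERTEX Person(PRIMARY_ID id STRING, name STRING)"
assert parse_show_vertices(sample) == ["Person"]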
+ + Args: + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + List of vertex type names + """ + graph_name = graph_name or self.graphname + try: + result = self._execute_gsql(f"USE GRAPH {graph_name}\nSHOW VERTEX *") + # Parse GSQL output using the proper parser + if isinstance(result, str): + return self._parse_show_output(result, "VERTEX") + return [] + except Exception as e: + logger.debug(f"Failed to get vertex types via GSQL: {e}") + return [] + + def _get_edge_types(self, graph_name: str | None = None) -> list[str]: + """ + Get list of edge types using GSQL. + + Args: + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + List of edge type names + """ + graph_name = graph_name or self.graphname + try: + result = self._execute_gsql(f"USE GRAPH {graph_name}\nSHOW EDGE *") + # Parse GSQL output using the proper parser + if isinstance(result, str): + # _parse_show_edge_output returns list of tuples (edge_name, is_directed) + # Extract just the edge names + edge_tuples = self._parse_show_edge_output(result) + return [edge_name for edge_name, _ in edge_tuples] + return [] + except Exception as e: + logger.debug(f"Failed to get edge types via GSQL: {e}") + return [] + + def _get_installed_queries(self, graph_name: str | None = None) -> list[str]: + """ + Get list of installed queries using GSQL. + + Args: + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + List of query names + """ + graph_name = graph_name or self.graphname + try: + result = self._execute_gsql(f"USE GRAPH {graph_name}\nSHOW QUERY *") + # Parse GSQL output to extract query names + queries = [] + if isinstance(result, str): + lines = result.split("\n") + for line in lines: + line = line.strip() + if line and not line.startswith("#") and not line.startswith("USE"): + # Query names are typically on their own lines + if line and not line.startswith("---"): + queries.append(line) + return queries if queries else [] + except Exception as e: + logger.debug(f"Failed to get installed queries via GSQL: {e}") + return [] + + def _run_installed_query( + self, query_name: str, graph_name: str | None = None, **kwargs: Any + ) -> dict[str, Any] | list[dict]: + """ + Run an installed query using REST API. + + Args: + query_name: Name of the installed query + graph_name: Name of the graph (defaults to self.graphname) + **kwargs: Query parameters + + Returns: + Query result (dict or list) + """ + graph_name = graph_name or self.graphname + endpoint = f"/query/{graph_name}/{query_name}" + return self._call_restpp_api(endpoint, method="POST", data=kwargs) + + def _upsert_vertex( + self, + vertex_type: str, + vertex_id: str, + attributes: dict[str, Any], + graph_name: str | None = None, + ) -> dict[str, Any] | list[dict]: + """ + Upsert a single vertex using REST API. + + Args: + vertex_type: Vertex type name + vertex_id: Vertex ID + attributes: Vertex attributes + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + Response from API + """ + graph_name = graph_name or self.graphname + endpoint = f"/graph/{graph_name}/vertices/{vertex_type}/{quote(str(vertex_id))}" + return self._call_restpp_api(endpoint, method="POST", data=attributes) + + def _upsert_edge( + self, + source_type: str, + source_id: str, + edge_type: str, + target_type: str, + target_id: str, + attributes: dict[str, Any] | None = None, + graph_name: str | None = None, + ) -> dict[str, Any] | list[dict]: + """ + Upsert a single edge using REST API. 
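# NOTE (editor): illustrative sketch, not part of the patch. The endpoint built
# below URL-encodes both vertex IDs with urllib.parse.quote; for example
# (note that quote() leaves "/" unescaped by default):
from urllib.parse import quote

graph, edge_type = "MyGraph", "KNOWS"
src_t, src_id, tgt_t, tgt_id = "Person", "alice:1", "Person", "bob 2"
endpoint = (f"/graph/{graph}/edges/{edge_type}/"
            f"{src_t}/{quote(str(src_id))}/{tgt_t}/{quote(str(tgt_id))}")
assert endpoint == "/graph/MyGraph/edges/KNOWS/Person/alice%3A1/Person/bob%202"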
+
+ Args:
+ source_type: Source vertex type
+ source_id: Source vertex ID
+ edge_type: Edge type name
+ target_type: Target vertex type
+ target_id: Target vertex ID
+ attributes: Edge attributes (optional)
+ graph_name: Name of the graph (defaults to self.graphname)
+
+ Returns:
+ Response from API
+ """
+ graph_name = graph_name or self.graphname
+ endpoint = (
+ f"/graph/{graph_name}/edges/{edge_type}/"
+ f"{source_type}/{quote(str(source_id))}/"
+ f"{target_type}/{quote(str(target_id))}"
+ )
+ data = attributes if attributes else {}
+ return self._call_restpp_api(endpoint, method="POST", data=data)
+
+ def _get_edges(
+ self,
+ source_type: str,
+ source_id: str,
+ edge_type: str | None = None,
+ graph_name: str | None = None,
+ ) -> list[dict[str, Any]]:
+ """
+ Get edges from a vertex using REST API.
+
+ Based on pyTigerGraph's getEdges() implementation.
+ Uses GET /graph/{graph}/edges/{source_vertex_type}/{source_vertex_id} endpoint.
+
+ Args:
+ source_type: Source vertex type
+ source_id: Source vertex ID
+ edge_type: Edge type to filter by (optional; appended to the endpoint path and re-checked client-side)
+ graph_name: Name of the graph (defaults to self.graphname)
+
+ Returns:
+ List of edge dictionaries
+ """
+ graph_name = graph_name or self.graphname
+
+ # Use the correct endpoint format matching pyTigerGraph's _prep_get_edges:
+ # GET /graph/{graph}/edges/{source_type}/{source_id}
+ # If edge_type is specified, append it: /graph/{graph}/edges/{source_type}/{source_id}/{edge_type}
+ if edge_type:
+ endpoint = f"/graph/{graph_name}/edges/{source_type}/{quote(str(source_id))}/{edge_type}"
+ else:
+ endpoint = (
+ f"/graph/{graph_name}/edges/{source_type}/{quote(str(source_id))}"
+ )
+
+ result = self._call_restpp_api(endpoint, method="GET")
+
+ # Parse REST++ API response format
+ # Response format: {"version": {...}, "error": false, "message": "", "results": [...]}
+ if isinstance(result, dict):
+ # Check for error first
+ if result.get("error") is True:
+ error_msg = result.get("message", "Unknown error")
+ logger.error(f"Error fetching edges: {error_msg}")
+ return []
+
+ # Extract results array
+ if "results" in result:
+ edges = result["results"]
+ else:
+ logger.debug(
+ f"Unexpected response format from edges endpoint: {result.keys()}"
+ )
+ return []
+ elif isinstance(result, list):
+ edges = result
+ else:
+ logger.debug(
+ f"Unexpected response type from edges endpoint: {type(result)}"
+ )
+ return []
+
+ # Re-apply the edge_type filter client-side as a safety net, in case the
+ # endpoint ignores the edge-type path segment
+ if edge_type and isinstance(edges, list):
+ edges = [
+ e for e in edges if isinstance(e, dict) and e.get("e_type") == edge_type
+ ]
+
+ return edges
+
+ def _get_vertices_by_id(
+ self, vertex_type: str, vertex_id: str, graph_name: str | None = None
+ ) -> dict[str, dict[str, Any]]:
+ """
+ Get vertex by ID using REST API.
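+
+ A hypothetical call and return shape (ID and attributes are illustrative):
+
+ >>> client._get_vertices_by_id("Author", "a1")  # doctest: +SKIP
+ {'a1': {'attributes': {'name': 'Ada'}}}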
+ + Args: + vertex_type: Vertex type name + vertex_id: Vertex ID + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + Dictionary mapping vertex_id to vertex data + """ + graph_name = graph_name or self.graphname + endpoint = f"/graph/{graph_name}/vertices/{vertex_type}/{quote(str(vertex_id))}" + result = self._call_restpp_api(endpoint, method="GET") + # Parse response format to match expected format + # Returns {vertex_id: {"attributes": {...}}} + if isinstance(result, dict): + if "results" in result: + # REST API format + results = result["results"] + if results and isinstance(results, list) and len(results) > 0: + vertex_data = results[0] + return { + vertex_id: {"attributes": vertex_data.get("attributes", {})} + } + elif vertex_id in result: + return {vertex_id: result[vertex_id]} + else: + # Try to extract vertex data + return {vertex_id: {"attributes": result.get("attributes", {})}} + return {} + + def _get_vertex_count(self, vertex_type: str, graph_name: str | None = None) -> int: + """ + Get vertex count using REST API. + + Args: + vertex_type: Vertex type name + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + Number of vertices + """ + graph_name = graph_name or self.graphname + endpoint = f"/graph/{graph_name}/vertices/{vertex_type}" + params = {"limit": "1", "count": "true"} + result = self._call_restpp_api(endpoint, method="GET", params=params) + # Parse count from response + if isinstance(result, dict): + return result.get("count", 0) + return 0 + + def _delete_vertices( + self, vertex_type: str, where: str | None = None, graph_name: str | None = None + ) -> dict[str, Any] | list[dict]: + """ + Delete vertices using REST API. + + Args: + vertex_type: Vertex type name + where: WHERE clause for filtering (optional) + graph_name: Name of the graph (defaults to self.graphname) + + Returns: + Response from API + """ + graph_name = graph_name or self.graphname + endpoint = f"/graph/{graph_name}/vertices/{vertex_type}" + params = {} + if where: + params["filter"] = where + return self._call_restpp_api(endpoint, method="DELETE", params=params) + def _call_restpp_api( self, endpoint: str, @@ -509,8 +1287,7 @@ def _ensure_graph_context(self, graph_name: str | None = None): """ Context manager that ensures graph context for metadata operations. - Updates conn.graphname for PyTigerGraph metadata operations that rely on it - (e.g., getVertexTypes(), getEdgeTypes()). + Stores graph name for operations that need it. Args: graph_name: Name of the graph to use. If None, uses self.config.database. @@ -524,14 +1301,14 @@ def _ensure_graph_context(self, graph_name: str | None = None): "Graph name must be provided via graph_name parameter or config.database" ) - old_graphname = self.conn.graphname - self.conn.graphname = graph_name + old_graphname = self.graphname + self.graphname = graph_name try: yield graph_name finally: # Restore original graphname - self.conn.graphname = old_graphname + self.graphname = old_graphname def graph_exists(self, name: str) -> bool: """ @@ -548,7 +1325,7 @@ def graph_exists(self, name: str) -> bool: bool: True if the graph exists, False otherwise """ try: - result = self.conn.gsql(f"USE GRAPH {name}") + result = self._execute_gsql(f"USE GRAPH {name}") result_str = str(result).lower() # If the graph doesn't exist, USE GRAPH returns an error message @@ -591,7 +1368,7 @@ def create_database( This method creates a graph with explicitly attached vertices and edges. 
Example: CREATE GRAPH researchGraph (author, paper, wrote) - This method uses the pyTigerGraph gsql() method to execute GSQL commands + This method uses direct REST API calls to execute GSQL commands that create and use the graph. Supported in TigerGraph version 4.2.2+. Args: @@ -600,8 +1377,12 @@ def create_database( edge_names: Optional list of edge type names to attach to the graph Raises: - Exception: If graph creation fails + RuntimeError: If graph already exists or creation fails """ + # Check if graph already exists first + if self.graph_exists(name): + raise RuntimeError(f"Graph '{name}' already exists") + try: # Build the list of types to include in CREATE GRAPH all_types = [] @@ -618,23 +1399,43 @@ def create_database( # Fallback to empty graph if no types provided gsql_commands = f"CREATE GRAPH {name}()\nUSE GRAPH {name}" - # Execute using pyTigerGraph's gsql method which handles authentication + # Execute using direct GSQL REST API which handles authentication logger.debug(f"Creating graph '{name}' via GSQL: {gsql_commands}") try: - result = self.conn.gsql(gsql_commands) + result = self._execute_gsql(gsql_commands) logger.info( f"Successfully created graph '{name}' with types {all_types}: {result}" ) + # Verify the result doesn't indicate the graph already existed + result_str = str(result).lower() + if ( + "already exists" in result_str + or "duplicate" in result_str + or "graph already exists" in result_str + ): + raise RuntimeError(f"Graph '{name}' already exists") return result + except RuntimeError: + # Re-raise RuntimeError as-is (already handled) + raise except Exception as e: error_msg = str(e).lower() - # Check if graph already exists (might be acceptable) - if "already exists" in error_msg or "duplicate" in error_msg: - logger.info(f"Graph '{name}' may already exist: {e}") - return str(e) + # Check if graph already exists - raise exception in this case + # TigerGraph may return various error messages for existing graphs + if ( + "already exists" in error_msg + or "duplicate" in error_msg + or "graph already exists" in error_msg + or "already exist" in error_msg + ): + logger.warning(f"Graph '{name}' already exists: {e}") + raise RuntimeError(f"Graph '{name}' already exists") from e logger.error(f"Failed to create graph '{name}': {e}") raise + except RuntimeError: + # Re-raise RuntimeError as-is + raise except Exception as e: logger.error(f"Error creating graph '{name}' via GSQL: {e}") raise @@ -667,7 +1468,7 @@ def delete_database(self, name: str): with self._ensure_graph_context(name): # Get all installed queries for this graph try: - queries = self.conn.getInstalledQueries() + queries = self._get_installed_queries() if queries: logger.info( f"Dropping {len(queries)} queries from graph '{name}'" @@ -676,7 +1477,7 @@ def delete_database(self, name: str): try: # Try DROP QUERY with IF EXISTS to avoid errors drop_query_cmd = f"USE GRAPH {name}\nDROP QUERY {query_name} IF EXISTS" - self.conn.gsql(drop_query_cmd) + self._execute_gsql(drop_query_cmd) logger.debug( f"Dropped query '{query_name}' from graph '{name}'" ) @@ -687,7 +1488,7 @@ def delete_database(self, name: str): drop_query_cmd = ( f"USE GRAPH {name}\nDROP QUERY {query_name}" ) - self.conn.gsql(drop_query_cmd) + self._execute_gsql(drop_query_cmd) logger.debug( f"Dropped query '{query_name}' from graph '{name}'" ) @@ -709,7 +1510,7 @@ def delete_database(self, name: str): try: # Try to drop queries using GSQL directly list_queries_cmd = f"USE GRAPH {name}\nSHOW QUERY *" - result = 
self.conn.gsql(list_queries_cmd) + result = self._execute_gsql(list_queries_cmd) # Parse result to get query names and drop them # This is a fallback if getInstalledQueries() doesn't work except Exception as e: @@ -723,10 +1524,10 @@ def delete_database(self, name: str): with self._ensure_graph_context(name): # Clear all vertices to remove dependencies try: - vertex_types = self.conn.getVertexTypes(force=True) + vertex_types = self._get_vertex_types() for v_type in vertex_types: try: - self.conn.delVertices(v_type) + self._delete_vertices(v_type) logger.debug( f"Cleared vertices of type '{v_type}' from graph '{name}'" ) @@ -742,7 +1543,7 @@ def delete_database(self, name: str): try: # Use the graph first to ensure we're working with the right graph drop_command = f"USE GRAPH {name}\nDROP GRAPH {name}" - result = self.conn.gsql(drop_command) + result = self._execute_gsql(drop_command) logger.info(f"Successfully dropped graph '{name}': {result}") return result except Exception as e: @@ -768,7 +1569,7 @@ def delete_database(self, name: str): with self._ensure_graph_context(name): # Disassociate edge types from graph (but don't drop them globally) try: - edge_types = self.conn.getEdgeTypes(force=True) + edge_types = self._get_edge_types() except Exception: edge_types = [] @@ -777,7 +1578,7 @@ def delete_database(self, name: str): # ALTER GRAPH requires USE GRAPH context try: drop_edge_cmd = f"USE GRAPH {name}\nALTER GRAPH {name} DROP DIRECTED EDGE {e_type}" - self.conn.gsql(drop_edge_cmd) + self._execute_gsql(drop_edge_cmd) logger.debug( f"Disassociated edge type '{e_type}' from graph '{name}'" ) @@ -789,7 +1590,7 @@ def delete_database(self, name: str): # Disassociate vertex types from graph (but don't drop them globally) try: - vertex_types = self.conn.getVertexTypes(force=True) + vertex_types = self._get_vertex_types() except Exception: vertex_types = [] @@ -797,7 +1598,7 @@ def delete_database(self, name: str): # Only clear data from this graph's vertices, don't drop vertex type globally # Clear data first to avoid dependency issues try: - self.conn.delVertices(v_type) + self._delete_vertices(v_type) logger.debug( f"Cleared vertices of type '{v_type}' from graph '{name}'" ) @@ -809,7 +1610,7 @@ def delete_database(self, name: str): # ALTER GRAPH requires USE GRAPH context try: drop_vertex_cmd = f"USE GRAPH {name}\nALTER GRAPH {name} DROP VERTEX {v_type}" - self.conn.gsql(drop_vertex_cmd) + self._execute_gsql(drop_vertex_cmd) logger.debug( f"Disassociated vertex type '{v_type}' from graph '{name}'" ) @@ -826,9 +1627,9 @@ def delete_database(self, name: str): # Fallback 2: Clear all data (if any remain) try: with self._ensure_graph_context(name): - vertex_types = self.conn.getVertexTypes() + vertex_types = self._get_vertex_types() for v_type in vertex_types: - result = self.conn.delVertices(v_type) + result = self._delete_vertices(v_type) logger.debug(f"Cleared vertices of type {v_type}: {result}") logger.info(f"Cleared all data from graph '{name}'") except Exception as e2: @@ -849,17 +1650,17 @@ def execute(self, query, **kwargs): if query.strip().upper().startswith("RUN "): # Extract query name and parameters query_name = query.strip()[4:].split("(")[0].strip() - result = self.conn.runInstalledQuery(query_name, **kwargs) + result = self._run_installed_query(query_name, **kwargs) else: # Execute as raw GSQL - result = self.conn.gsql(query) + result = self._execute_gsql(query) return result except Exception as e: logger.error(f"Error executing query '{query}': {e}") raise def close(self): - 
"""Close connection - pyTigerGraph handles cleanup automatically.""" + """Close connection - no cleanup needed (using direct REST API calls).""" pass def _get_vertex_add_statement( @@ -1090,6 +1891,9 @@ def _define_schema_local(self, schema: Schema) -> None: if not graph_name: raise ValueError("Graph name (database) must be configured") + # Validate graph name + _validate_tigergraph_schema_name(graph_name, "graph") + vertex_config = schema.vertex_config edge_config = schema.edge_config @@ -1097,6 +1901,8 @@ def _define_schema_local(self, schema: Schema) -> None: # Vertices for vertex in vertex_config.vertices: + # Validate vertex name + _validate_tigergraph_schema_name(vertex.name, "vertex") stmt = self._get_vertex_add_statement(vertex, vertex_config) schema_change_stmts.append(stmt) @@ -1104,6 +1910,8 @@ def _define_schema_local(self, schema: Schema) -> None: edges_to_create = list(edge_config.edges_list(include_aux=True)) for edge in edges_to_create: edge.finish_init(vertex_config) + # Validate edge name + _validate_tigergraph_schema_name(edge.relation, "edge") stmt = self._get_edge_add_statement(edge) schema_change_stmts.append(stmt) @@ -1111,125 +1919,235 @@ def _define_schema_local(self, schema: Schema) -> None: logger.debug(f"No schema changes to apply for graph '{graph_name}'") return - job_name = f"schema_change_{graph_name}" + # Estimate the size of the GSQL command to determine if we need to split it + # Large SCHEMA_CHANGE JOBs (>30k chars) can cause parser failures with misleading errors + # like "Missing return statement" (which is actually a parser size limit issue) + # We'll split into batches based on configurable max_job_size (default: 1000) + MAX_JOB_SIZE = self.config.max_job_size + + # Calculate accurate size estimation + # Actual format: + # USE GRAPH {graph_name} + # CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {graph_name} { + # stmt1; + # stmt2; + # ... 
+ # } + # RUN SCHEMA_CHANGE JOB {job_name} + # + # For N statements: + # - Base overhead: USE GRAPH line + CREATE line + closing brace + RUN line + newlines + # - Statement overhead: first gets " " + ";" (5 chars), others get ";\n " (5 chars each) + # - Total: base + sum(len(stmt)) + 5*N + + # Use worst-case job name length (multi-batch format) for conservative estimation + worst_case_job_name = ( + f"schema_change_{graph_name}_batch_999" # Use large number for worst case + ) + base_template = ( + f"USE GRAPH {graph_name}\n" + f"CREATE SCHEMA_CHANGE JOB {worst_case_job_name} FOR GRAPH {graph_name} {{\n" + f"}}\n" + f"RUN SCHEMA_CHANGE JOB {worst_case_job_name}" + ) + base_overhead = len(base_template) + + # Each statement adds 5 characters: first gets " " (4) + ";" (1), + # subsequent get ";\n " (5) between statements, final ";" (1) is included + # For N statements: 4 (first indent) + (N-1)*5 (separators) + 1 (final semicolon) = 5*N + num_statements = len(schema_change_stmts) + total_stmt_size = sum(len(stmt) for stmt in schema_change_stmts) + estimated_size = base_overhead + total_stmt_size + 5 * num_statements + + if estimated_size <= MAX_JOB_SIZE: + # Small enough for a single job + batches = [schema_change_stmts] + logger.info( + f"Applying schema change as single job (estimated size: {estimated_size} chars)" + ) + else: + # Split into multiple batches + # Calculate how many statements per batch + # For a batch of M statements: base_overhead + sum(len(stmt)) + 5*M <= MAX_JOB_SIZE + # So: sum(len(stmt)) + 5*M <= MAX_JOB_SIZE - base_overhead + # If avg_stmt_size = sum(len(stmt)) / M, then: M * (avg_stmt_size + 5) <= MAX_JOB_SIZE - base_overhead + avg_stmt_size = ( + total_stmt_size / num_statements if num_statements > 0 else 0 + ) + available_space = MAX_JOB_SIZE - base_overhead + stmts_per_batch = max(1, int(available_space / (avg_stmt_size + 5))) - # First, try to drop the job if it exists (ignore errors if it doesn't) - try: - drop_job_cmd = f"USE GRAPH {graph_name}\nDROP JOB {job_name}" - self.conn.gsql(drop_job_cmd) - logger.debug(f"Dropped existing schema change job '{job_name}'") - except Exception as e: - err_str = str(e).lower() - # Ignore errors if job doesn't exist - if "not found" in err_str or "could not be found" in err_str: - logger.debug( - f"Schema change job '{job_name}' does not exist, skipping drop" - ) - else: - logger.debug(f"Could not drop schema change job '{job_name}': {e}") + batches = [] + for i in range(0, len(schema_change_stmts), stmts_per_batch): + batches.append(schema_change_stmts[i : i + stmts_per_batch]) - # Combine into a single SCHEMA_CHANGE job - gsql_commands = [ - f"USE GRAPH {graph_name}", - f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {graph_name} {{", - " " + ";\n ".join(schema_change_stmts) + ";", - "}", - f"RUN SCHEMA_CHANGE JOB {job_name}", - ] + logger.info( + f"Large schema detected (estimated size: {estimated_size} chars). " + f"Splitting into {len(batches)} batches of ~{stmts_per_batch} statements each." 
+ ) - full_gsql = "\n".join(gsql_commands) - logger.info(f"Applying local schema change for graph '{graph_name}'") - logger.info(f"GSQL command:\n{full_gsql}") - try: - result = self.conn.gsql(full_gsql) - logger.debug(f"Schema change result: {result}") + # Execute batches sequentially + for batch_idx, batch_stmts in enumerate(batches): + job_name = ( + f"schema_change_{graph_name}_batch_{batch_idx}" + if len(batches) > 1 + else f"schema_change_{graph_name}" + ) - # Check if result indicates an error - be more lenient with error detection - result_str = str(result) if result else "" - # Only treat as error if result explicitly contains error indicators - if ( - result - and result_str - and ( - "Encountered" in result_str - or "syntax error" in result_str.lower() - or "parse error" in result_str.lower() - ) - ): - error_msg = f"Schema change job reported a syntax/parse error: {result}" - logger.error(error_msg) - logger.error(f"GSQL command that failed: {full_gsql}") - raise RuntimeError(error_msg) + # First, try to drop the job if it exists (ignore errors if it doesn't) + try: + drop_job_cmd = f"USE GRAPH {graph_name}\nDROP JOB {job_name}" + self._execute_gsql(drop_job_cmd) + logger.debug(f"Dropped existing schema change job '{job_name}'") + except Exception as e: + err_str = str(e).lower() + # Ignore errors if job doesn't exist + if "not found" in err_str or "could not be found" in err_str: + logger.debug( + f"Schema change job '{job_name}' does not exist, skipping drop" + ) + else: + logger.debug(f"Could not drop schema change job '{job_name}': {e}") + + # Create and run SCHEMA_CHANGE job for this batch + gsql_commands = [ + f"USE GRAPH {graph_name}", + f"CREATE SCHEMA_CHANGE JOB {job_name} FOR GRAPH {graph_name} {{", + " " + ";\n ".join(batch_stmts) + ";", + "}", + f"RUN SCHEMA_CHANGE JOB {job_name}", + ] - # Verify that the schema was actually created by checking vertex and edge types - # Wait a moment for schema changes to propagate - import time + full_gsql = "\n".join(gsql_commands) + actual_size = len(full_gsql) - time.sleep(1.0) # Increased wait time + # Safety check: warn if actual size exceeds limit (indicates estimation error) + if actual_size > MAX_JOB_SIZE: + logger.warning( + f"Batch {batch_idx + 1} actual size ({actual_size} chars) exceeds limit ({MAX_JOB_SIZE} chars). " + f"This may cause parser errors. Consider reducing MAX_JOB_SIZE or improving estimation." 
+ ) - with self._ensure_graph_context(graph_name): - vertex_types = self.conn.getVertexTypes(force=True) - edge_types = self.conn.getEdgeTypes(force=True) + logger.info( + f"Applying schema change batch {batch_idx + 1}/{len(batches)} for graph '{graph_name}' " + f"({len(batch_stmts)} statements, {actual_size} chars)" + ) + if actual_size < 5000: # Only log full command if it's reasonably small + logger.debug(f"GSQL command:\n{full_gsql}") + else: + logger.debug(f"GSQL command size: {actual_size} characters") - # Use vertex_dbname instead of v.name to match what TigerGraph actually creates - # vertex_dbname returns dbname if set, otherwise None - fallback to v.name if None - expected_vertex_types = set() - for v in vertex_config.vertices: - try: - dbname = vertex_config.vertex_dbname(v.name) - # If dbname is None, use vertex name - expected_name = dbname if dbname is not None else v.name - except (KeyError, AttributeError): - # Fallback to vertex name if vertex_dbname fails - expected_name = v.name - expected_vertex_types.add(expected_name) - - expected_edge_types = { - e.relation for e in edges_to_create if e.relation - } + try: + result = self._execute_gsql(full_gsql) + logger.debug(f"Schema change batch {batch_idx + 1} result: {result}") + + # Check if result indicates success - should contain "Local schema change succeeded." near the end + result_str = str(result) if result else "" + if result_str: + # Check for success message near the end (last 500 characters to handle long outputs) + result_tail = ( + result_str[-500:] if len(result_str) > 500 else result_str + ) + if "Local schema change succeeded." not in result_tail: + error_msg = ( + f"Schema change job batch {batch_idx + 1} did not report success. " + f"Expected 'Local schema change succeeded.' near the end of the result. " + f"Result (last 500 chars): {result_tail}" + ) + logger.error(error_msg) + logger.error(f"Full result: {result_str}") + raise RuntimeError(error_msg) - # Convert to sets for case-insensitive comparison - # TigerGraph may capitalize vertex names, so compare case-insensitively - vertex_types_lower = {vt.lower() for vt in vertex_types} - expected_vertex_types_lower = { - evt.lower() for evt in expected_vertex_types - } + # Check if result indicates an error - be more lenient with error detection + # Only treat as error if result explicitly contains error indicators + if ( + result + and result_str + and ( + "Encountered" in result_str + or "syntax error" in result_str.lower() + or "parse error" in result_str.lower() + or "missing return statement" in result_str.lower() + ) + ): + # "Missing return statement" is a misleading error - it's actually a parser size limit + # SCHEMA_CHANGE JOB doesn't require RETURN statements, so this indicates parser failure + if "missing return statement" in result_str.lower(): + error_msg = ( + f"Schema change job batch {batch_idx + 1} failed with parser error. " + f"This is likely due to the GSQL command size ({actual_size} chars) exceeding " + f"TigerGraph's parser limit (~30-40K chars). The 'Missing return statement' error " + f"is misleading - SCHEMA_CHANGE JOB doesn't require RETURN statements. " + f"Original error: {result}" + ) + else: + error_msg = f"Schema change job batch {batch_idx + 1} reported an error: {result}" - missing_vertices_lower = ( - expected_vertex_types_lower - vertex_types_lower + logger.error(error_msg) + logger.error( + f"GSQL command that failed (first 1000 chars):\n{full_gsql[:1000]}..." 
+ ) + raise RuntimeError(error_msg) + except Exception as e: + logger.error( + f"Failed to execute schema change batch {batch_idx + 1}: {e}" ) - # Convert back to original case for error message - missing_vertices = { - evt - for evt in expected_vertex_types - if evt.lower() in missing_vertices_lower - } + raise - missing_edges = expected_edge_types - set(edge_types) + # Verify that the schema was actually created by checking vertex and edge types + # Wait a moment for schema changes to propagate (after all batches) + import time - if missing_vertices or missing_edges: - error_msg = ( - f"Schema change job completed but types were not created correctly. " - f"Missing vertex types: {missing_vertices}, " - f"Missing edge types: {missing_edges}. " - f"Created vertex types: {vertex_types}, " - f"Created edge types: {edge_types}. " - f"GSQL result: {result}" - ) - logger.error(error_msg) - logger.error(f"GSQL command that failed: {full_gsql}") - raise RuntimeError(error_msg) + time.sleep(1.0) # Increased wait time - logger.info( - f"Schema verified: {len(vertex_types)} vertex types, {len(edge_types)} edge types created" + with self._ensure_graph_context(graph_name): + vertex_types = self._get_vertex_types() + edge_types = self._get_edge_types() + + # Use vertex_dbname instead of v.name to match what TigerGraph actually creates + # vertex_dbname returns dbname if set, otherwise None - fallback to v.name if None + expected_vertex_types = set() + for v in vertex_config.vertices: + try: + dbname = vertex_config.vertex_dbname(v.name) + # If dbname is None, use vertex name + expected_name = dbname if dbname is not None else v.name + except (KeyError, AttributeError): + # Fallback to vertex name if vertex_dbname fails + expected_name = v.name + expected_vertex_types.add(expected_name) + + expected_edge_types = {e.relation for e in edges_to_create if e.relation} + + # Convert to sets for case-insensitive comparison + # TigerGraph may capitalize vertex names, so compare case-insensitively + vertex_types_lower = {vt.lower() for vt in vertex_types} + expected_vertex_types_lower = {evt.lower() for evt in expected_vertex_types} + + missing_vertices_lower = expected_vertex_types_lower - vertex_types_lower + # Convert back to original case for error message + missing_vertices = { + evt + for evt in expected_vertex_types + if evt.lower() in missing_vertices_lower + } + + missing_edges = expected_edge_types - set(edge_types) + + if missing_vertices or missing_edges: + error_msg = ( + f"Schema change job completed but types were not created correctly. " + f"Missing vertex types: {missing_vertices}, " + f"Missing edge types: {missing_edges}. " + f"Created vertex types: {vertex_types}, " + f"Created edge types: {edge_types}." 
) - except RuntimeError: - # Re-raise RuntimeError as-is - raise - except Exception as e: - logger.error(f"Failed to apply local schema change: {e}") - logger.error(f"GSQL command was: {full_gsql}") - raise + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info( + f"Schema verified: {len(vertex_types)} vertex types, {len(edge_types)} edge types created" + ) @_wrap_tg_exception def init_db(self, schema: Schema, clean_start: bool = False) -> None: @@ -1255,6 +2173,9 @@ def init_db(self, schema: Schema, clean_start: bool = False) -> None: self.config.database = graph_name logger.info(f"Using schema name '{graph_name}' from schema.general.name") + # Validate graph name + _validate_tigergraph_schema_name(graph_name, "graph") + try: if clean_start: try: @@ -1365,7 +2286,7 @@ def define_vertex_classes( # type: ignore[override] ] logger.info(f"Adding vertices locally to graph '{graph_name}'") - self.conn.gsql("\n".join(gsql_commands)) + self._execute_gsql("\n".join(gsql_commands)) def define_edge_classes(self, edges: list[Edge]): """Define TigerGraph edge types locally for the current graph. @@ -1400,7 +2321,7 @@ def define_edge_classes(self, edges: list[Edge]): ] logger.info(f"Adding edges locally to graph '{graph_name}'") - self.conn.gsql("\n".join(gsql_commands)) + self._execute_gsql("\n".join(gsql_commands)) def _format_vertex_fields(self, vertex: Vertex) -> str: """ @@ -1590,7 +2511,7 @@ def _add_index(self, obj_name, index: Index, is_vertex_index=True): # Step 1: Drop existing job if it exists (ignore errors) try: drop_job_cmd = f"USE GRAPH {graph_name}\nDROP JOB {job_name}" - self.conn.gsql(drop_job_cmd) + self._execute_gsql(drop_job_cmd) logger.debug(f"Dropped existing job '{job_name}'") except Exception as e: err_str = str(e).lower() @@ -1609,7 +2530,7 @@ def _add_index(self, obj_name, index: Index, is_vertex_index=True): logger.debug(f"Executing GSQL (create job): {create_job_cmd}") try: - result = self.conn.gsql(create_job_cmd) + result = self._execute_gsql(create_job_cmd) logger.debug(f"Created schema change job '{job_name}': {result}") except Exception as e: err = str(e).lower() @@ -1631,7 +2552,7 @@ def _add_index(self, obj_name, index: Index, is_vertex_index=True): logger.debug(f"Executing GSQL (run job): {run_job_cmd}") try: - result = self.conn.gsql(run_job_cmd) + result = self._execute_gsql(run_job_cmd) logger.debug( f"Ran schema change job '{job_name}', created index '{index_name}' on {obj_name}: {result}" ) @@ -1655,47 +2576,34 @@ def _add_index(self, obj_name, index: Index, is_vertex_index=True): def _parse_show_output(self, result_str: str, prefix: str) -> list[str]: """ - Generic parser for SHOW * output commands. + Parse SHOW * output to extract type names. 
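+
+ For example (hypothetical GSQL output), the line
+ "- VERTEX Author(PRIMARY_ID id STRING)" yields "Author".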
- Extracts names from lines matching the pattern: "- PREFIX name(...)" + Looks for lines matching: "- PREFIX name(" or "PREFIX name(" Args: result_str: String output from SHOW * GSQL command - prefix: The prefix to look for (e.g., "VERTEX", "GRAPH", "JOB") + prefix: The prefix to look for (e.g., "VERTEX", "EDGE") Returns: List of extracted names """ + import re + names = [] - lines = result_str.split("\n") + # Pattern: "- VERTEX name(" or "VERTEX name(" + # Match lines that contain the prefix followed by a word (the name) and then "(" + pattern = rf"(?:^|\s)-?\s*{re.escape(prefix)}\s+(\w+)\s*\(" - for line in lines: + for line in result_str.split("\n"): line = line.strip() - # Skip empty lines and headers - if not line or line.startswith("*"): + if not line: continue - # Remove leading "- " if present - if line.startswith("- "): - line = line[2:].strip() - - # Look for prefix pattern - prefix_upper = prefix.upper() - if line.upper().startswith(f"{prefix_upper} "): - # Extract name (after prefix and before opening parenthesis or whitespace) - after_prefix = line[len(prefix_upper) + 1 :].strip() - # Name is the first word (before space or parenthesis) - if "(" in after_prefix: - name = after_prefix.split("(")[0].strip() - else: - # No parenthesis, take the first word - name = ( - after_prefix.split()[0].strip() - if after_prefix.split() - else None - ) - - if name: + # Use regex to find matches + match = re.search(pattern, line, re.IGNORECASE) + if match: + name = match.group(1) + if name and name not in names: names.append(name) return names @@ -1713,38 +2621,33 @@ def _parse_show_edge_output(self, result_str: str) -> list[tuple[str, bool]]: Returns: List of tuples (edge_name, is_directed) """ + import re + edge_types = [] - lines = result_str.split("\n") + # Pattern for DIRECTED EDGE: "- DIRECTED EDGE name(" + directed_pattern = r"(?:^|\s)-?\s*DIRECTED\s+EDGE\s+(\w+)\s*\(" + # Pattern for UNDIRECTED EDGE: "- UNDIRECTED EDGE name(" + undirected_pattern = r"(?:^|\s)-?\s*UNDIRECTED\s+EDGE\s+(\w+)\s*\(" - for line in lines: + for line in result_str.split("\n"): line = line.strip() - # Skip empty lines and headers - if not line or line.startswith("*"): + if not line: continue - # Remove leading "- " if present - if line.startswith("- "): - line = line[2:].strip() - - # Look for "DIRECTED EDGE" or "UNDIRECTED EDGE" pattern - is_directed = None - prefix = None - if "DIRECTED EDGE" in line.upper(): - prefix = "DIRECTED EDGE " - is_directed = True - elif "UNDIRECTED EDGE" in line.upper(): - prefix = "UNDIRECTED EDGE " - is_directed = False - - if prefix: - idx = line.upper().find(prefix) - if idx >= 0: - after_prefix = line[idx + len(prefix) :].strip() - # Extract name before opening parenthesis - if "(" in after_prefix: - edge_name = after_prefix.split("(")[0].strip() - if edge_name: - edge_types.append((edge_name, is_directed)) + # Check for DIRECTED EDGE + match = re.search(directed_pattern, line, re.IGNORECASE) + if match: + edge_name = match.group(1) + if edge_name: + edge_types.append((edge_name, True)) + continue + + # Check for UNDIRECTED EDGE + match = re.search(undirected_pattern, line, re.IGNORECASE) + if match: + edge_name = match.group(1) + if edge_name: + edge_types.append((edge_name, False)) return edge_types @@ -1820,7 +2723,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal try: # Use GSQL to list all graphs show_graphs_cmd = "SHOW GRAPH *" - result = self.conn.gsql(show_graphs_cmd) + result = self._execute_gsql(show_graphs_cmd) result_str = 
str(result) # Parse graph names using helper method @@ -1854,7 +2757,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal try: # Use GSQL to list all global edge types (not graph-scoped) show_edges_cmd = "SHOW EDGE *" - result = self.conn.gsql(show_edges_cmd) + result = self._execute_gsql(show_edges_cmd) result_str = str(result) # Parse edge types using helper method @@ -1868,7 +2771,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal # DROP EDGE works for both directed and undirected edges drop_edge_cmd = f"DROP EDGE {e_type}" logger.debug(f"Executing: {drop_edge_cmd}") - result = self.conn.gsql(drop_edge_cmd) + result = self._execute_gsql(drop_edge_cmd) logger.info( f"Successfully dropped edge type '{e_type}': {result}" ) @@ -1894,7 +2797,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal try: # Use GSQL to list all global vertex types (not graph-scoped) show_vertices_cmd = "SHOW VERTEX *" - result = self.conn.gsql(show_vertices_cmd) + result = self._execute_gsql(show_vertices_cmd) result_str = str(result) # Parse vertex types using helper method @@ -1907,7 +2810,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal try: # Clear data first to avoid dependency issues try: - result = self.conn.delVertices(v_type) + result = self._delete_vertices(v_type) logger.debug( f"Cleared data from vertex type '{v_type}': {result}" ) @@ -1919,7 +2822,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal # Drop vertex type drop_vertex_cmd = f"DROP VERTEX {v_type}" logger.debug(f"Executing: {drop_vertex_cmd}") - result = self.conn.gsql(drop_vertex_cmd) + result = self._execute_gsql(drop_vertex_cmd) logger.info( f"Successfully dropped vertex type '{v_type}': {result}" ) @@ -1944,7 +2847,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal try: # Use GSQL to list all global jobs show_jobs_cmd = "SHOW JOB *" - result = self.conn.gsql(show_jobs_cmd) + result = self._execute_gsql(show_jobs_cmd) result_str = str(result) # Parse job names using helper method @@ -1958,7 +2861,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal # DROP JOB works for all job types drop_job_cmd = f"DROP JOB {job_name}" logger.debug(f"Executing: {drop_job_cmd}") - result = self.conn.gsql(drop_job_cmd) + result = self._execute_gsql(drop_job_cmd) logger.info( f"Successfully dropped job '{job_name}': {result}" ) @@ -1988,7 +2891,7 @@ def delete_graph_structure(self, vertex_types=(), graph_names=(), delete_all=Fal with self._ensure_graph_context(): for class_name in cnames: try: - result = self.conn.delVertices(class_name) + result = self._delete_vertices(class_name) logger.debug( f"Deleted vertices from {class_name}: {result}" ) @@ -2210,12 +3113,12 @@ def _fallback_individual_upsert(self, docs, class_name, match_keys): vertex_id = self._extract_id(doc, match_keys) if vertex_id: clean_doc = self._clean_document(doc) - # Serialize datetime objects before passing to pyTigerGraph - # pyTigerGraph's upsertVertex expects JSON-serializable data + # Serialize datetime objects before passing to REST API + # REST API expects JSON-serializable data serialized_doc = json.loads( json.dumps(clean_doc, default=_json_serializer) ) - self.conn.upsertVertex(class_name, vertex_id, serialized_doc) + self._upsert_vertex(class_name, vertex_id, serialized_doc) except Exception as e: logger.error(f"Error upserting individual 
vertex {vertex_id}: {e}") @@ -2407,17 +3310,17 @@ def _fallback_individual_edge_upsert( if source_id and target_id: clean_edge_props = self._clean_document(edge_props) - # Serialize data for pyTigerGraph + # Serialize data for REST API serialized_props = json.loads( json.dumps(clean_edge_props, default=_json_serializer) ) - self.conn.upsertEdge( + self._upsert_edge( source_class, source_id, edge_type, target_class, target_id, - attributes=serialized_props, + serialized_props, ) except Exception as e: logger.error(f"Error upserting individual edge: {e}") @@ -2795,19 +3698,19 @@ def fetch_edges( **kwargs: Any, ) -> list[dict[str, Any]]: """ - Fetch edges from TigerGraph using pyTigerGraph's getEdges method. + Fetch edges from TigerGraph using REST API. In TigerGraph, you must know at least one vertex ID before you can fetch edges. - Uses pyTigerGraph's getEdges method which handles special characters in vertex IDs. + Uses REST API which handles special characters in vertex IDs. Args: from_type: Source vertex type (required) from_id: Source vertex ID (required) edge_type: Optional edge type to filter by - to_type: Optional target vertex type to filter by (not used in pyTigerGraph) - to_id: Optional target vertex ID to filter by (not used in pyTigerGraph) - filters: Additional query filters (not supported by pyTigerGraph getEdges) - limit: Maximum number of edges to return (not supported by pyTigerGraph getEdges) + to_type: Optional target vertex type to filter by (not used in REST API) + to_id: Optional target vertex ID to filter by (not used in REST API) + filters: Additional query filters (not supported by REST API) + limit: Maximum number of edges to return (not supported by REST API) return_keys: Keys to return (projection) unset_keys: Keys to exclude (projection) **kwargs: Additional parameters @@ -2821,34 +3724,122 @@ def fetch_edges( "from_type and from_id are required for fetching edges in TigerGraph" ) - # Use pyTigerGraph's getEdges method - # Signature: getEdges(sourceVertexType, sourceVertexId, edgeType=None) + # Use REST API to get edges # Returns: list of edge dictionaries logger.debug( - f"Fetching edges using pyTigerGraph: from_type={from_type}, from_id={from_id}, edge_type={edge_type}" + f"Fetching edges using REST API: from_type={from_type}, from_id={from_id}, edge_type={edge_type}" ) - # Handle None edge_type by passing empty string (default behavior) - edge_type_str = edge_type if edge_type is not None else "" - edges = self.conn.getEdges(from_type, from_id, edge_type_str, fmt="py") + # Handle None edge_type + edge_type_str = edge_type if edge_type is not None else None + edges = self._get_edges(from_type, from_id, edge_type_str) - # Parse pyTigerGraph response format - # getEdges returns list of dicts with format like: - # [{"e_type": "...", "from": {...}, "to": {...}, "attributes": {...}}, ...] - # Type annotation: result is list[dict[str, Any]] - # getEdges can return dict, str, or DataFrame, but with fmt="py" it returns dict + # Parse REST API response format + # _get_edges() returns list of edge dicts from REST++ API + # Format: [{"e_type": "...", "from_id": "...", "to_id": "...", "attributes": {...}}, ...] 
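+ # A hypothetical raw edge (names and IDs illustrative), e.g.
+ # {"e_type": "wrote", "from_id": "a1", "to_id": "p1", "attributes": {"year": 2021}},
+ # is normalized below to
+ # {"edge_type": "wrote", "from_id": "a1", "from_type": "", "to_id": "p1", "to_type": "", "year": 2021}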
+ # The REST API returns edges in a flat format with e_type, from_id, to_id, attributes if isinstance(edges, list): - # Type narrowing: after isinstance check, we know it's a list - # Use cast to help type checker understand the elements are dicts - result = cast(list[dict[str, Any]], edges) + # Process each edge to normalize format + result = [] + for edge in edges: + if isinstance(edge, dict): + # Normalize edge format - REST API returns flat structure + normalized_edge = {} + + # Extract edge type (rename e_type to edge_type for consistency) + normalized_edge["edge_type"] = edge.get( + "e_type", edge.get("edge_type", "") + ) + + # Extract from/to IDs and types + normalized_edge["from_id"] = edge.get("from_id", "") + normalized_edge["from_type"] = edge.get("from_type", "") + normalized_edge["to_id"] = edge.get("to_id", "") + normalized_edge["to_type"] = edge.get("to_type", "") + + # Handle nested "from"/"to" objects if present (some API versions) + if "from" in edge and isinstance(edge["from"], dict): + normalized_edge["from_id"] = edge["from"].get( + "id", + edge["from"].get("v_id", normalized_edge["from_id"]), + ) + normalized_edge["from_type"] = edge["from"].get( + "type", + edge["from"].get( + "v_type", normalized_edge["from_type"] + ), + ) + + if "to" in edge and isinstance(edge["to"], dict): + normalized_edge["to_id"] = edge["to"].get( + "id", edge["to"].get("v_id", normalized_edge["to_id"]) + ) + normalized_edge["to_type"] = edge["to"].get( + "type", + edge["to"].get("v_type", normalized_edge["to_type"]), + ) + + # Extract attributes and merge into normalized edge + attributes = edge.get("attributes", {}) + if attributes: + normalized_edge.update(attributes) + else: + # If no attributes key, include all other fields as attributes + for k, v in edge.items(): + if k not in ( + "e_type", + "edge_type", + "from", + "to", + "from_id", + "to_id", + "from_type", + "to_type", + "directed", + ): + normalized_edge[k] = v + + result.append(normalized_edge) elif isinstance(edges, dict): - # If it's a single dict, wrap it in a list - result = [cast(dict[str, Any], edges)] + # Single edge dict - normalize and wrap in list + normalized_edge = {} + normalized_edge["edge_type"] = edges.get( + "e_type", edges.get("edge_type", "") + ) + normalized_edge["from_id"] = edges.get("from_id", "") + normalized_edge["to_id"] = edges.get("to_id", "") + + if "from" in edges and isinstance(edges["from"], dict): + normalized_edge["from_id"] = edges["from"].get( + "id", edges["from"].get("v_id", normalized_edge["from_id"]) + ) + if "to" in edges and isinstance(edges["to"], dict): + normalized_edge["to_id"] = edges["to"].get( + "id", edges["to"].get("v_id", normalized_edge["to_id"]) + ) + + attributes = edges.get("attributes", {}) + if attributes: + normalized_edge.update(attributes) + else: + for k, v in edges.items(): + if k not in ( + "e_type", + "edge_type", + "from", + "to", + "from_id", + "to_id", + ): + normalized_edge[k] = v + + result = [normalized_edge] else: # Fallback for unexpected types result: list[dict[str, Any]] = [] + logger.debug(f"Unexpected edges type: {type(edges)}") - # Apply limit if specified (client-side since pyTigerGraph doesn't support it) + # Apply limit if specified (client-side since REST API doesn't support it) if limit is not None and limit > 0: result = result[:limit] @@ -2869,7 +3860,7 @@ def fetch_edges( return result except Exception as e: - logger.error(f"Error fetching edges via pyTigerGraph: {e}") + logger.error(f"Error fetching edges via REST API: {e}") raise def 
_parse_restpp_response( @@ -2954,7 +3945,7 @@ def fetch_present_documents( continue try: - vertex_data = self.conn.getVerticesById(class_name, vertex_id) + vertex_data = self._get_vertices_by_id(class_name, vertex_id) if vertex_data and vertex_id in vertex_data: # Extract requested keys vertex_attrs = vertex_data[vertex_id].get("attributes", {}) @@ -2997,7 +3988,7 @@ def aggregate( try: if aggregation_function == AggregationType.COUNT and discriminant is None: # Simple vertex count - count = self.conn.getVertexCount(class_name) + count = self._get_vertex_count(class_name) return [{"_value": count}] else: # Complex aggregations require custom GSQL queries @@ -3092,7 +4083,7 @@ def fetch_indexes(self, vertex_type: str | None = None): if vertex_type: vertex_types = [vertex_type] else: - vertex_types = self.conn.getVertexTypes(force=True) + vertex_types = self._get_vertex_types() for v_type in vertex_types: try: diff --git a/graflo/db/tigergraph/reserved_words.json b/graflo/db/tigergraph/reserved_words.json new file mode 100644 index 0000000..efdb56f --- /dev/null +++ b/graflo/db/tigergraph/reserved_words.json @@ -0,0 +1,291 @@ +{ + "version": "4.2", + "description": "TigerGraph GSQL reserved words and keywords that cannot be used as identifiers (vertex types, edge types, graph names, etc.)", + "reserved_words": { + "gsql_keywords": [ + "ACCUM", + "ADD", + "ALL", + "ALLOCATE", + "ALTER", + "AND", + "ANY", + "AS", + "ASC", + "AVG", + "BAG", + "BATCH", + "BETWEEN", + "BIGINT", + "BLOB", + "BOOL", + "BOOLEAN", + "BOTH", + "BREAK", + "BY", + "CALL", + "CASCADE", + "CASE", + "CATCH", + "CHAR", + "CHARACTER", + "CHECK", + "CLOB", + "COALESCE", + "COMPRESS", + "CONST", + "CONSTRAINT", + "CONTINUE", + "COST", + "COUNT", + "CREATE", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATETIME", + "DECIMAL", + "DELETE", + "DESC", + "DISTINCT", + "DO", + "DOUBLE", + "DROP", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXECUTE", + "EXISTS", + "FALSE", + "FILTER", + "FLOAT", + "FOR", + "FOREACH", + "FROM", + "FULL", + "FUNCTION", + "GRAPH", + "GROUP", + "GSQL_SYS_TAG", + "HAVING", + "IF", + "IN", + "INNER", + "INSERT", + "INT", + "INTEGER", + "INTERSECT", + "INTO", + "IS", + "JOB", + "JOIN", + "KEY", + "LEFT", + "LIKE", + "LIMIT", + "LIST", + "LOAD", + "LOG", + "MAP", + "MATCH", + "MAX", + "MIN", + "NOT", + "NULL", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OUTER", + "PINNED", + "POST_ACCUM", + "POST-ACCUM", + "PRIMARY", + "PRIMARY_ID", + "PRINT", + "PROXY", + "QUERY", + "QUIT", + "RAISE", + "RANGE", + "REDUCE", + "REPLACE", + "RESET_COLLECTION_ACCUM", + "RETURN", + "RETURNS", + "RIGHT", + "SAMPLE", + "SECOND", + "SELECT", + "SELECTVERTEX", + "SET", + "STATIC", + "STRING", + "SUM", + "TARGET", + "TEMP_TABLE", + "THEN", + "TO", + "TO_CSV", + "TO_DATETIME", + "TRAILING", + "TRANSLATESQL", + "TRIM", + "TRUE", + "TRY", + "TUPLE", + "TYPEDEF", + "UINT", + "UINT8", + "UINT16", + "UINT32", + "UINT8_T", + "UINT32_T", + "UINT64_T", + "UNION", + "UPDATE", + "UPSERT", + "USE", + "USING", + "VALUES", + "VERTEX", + "VERSION", + "WHEN", + "WHERE", + "WHILE", + "WITH", + "_INTERNAL_ATTR_TAG" + ], + "cpp_keywords": [ + "ALIGNAS", + "ALIGNOF", + "AND", + "AND_EQ", + "ASM", + "AUTO", + "BITAND", + "BITOR", + "BOOL", + "BREAK", + "CASE", + "CATCH", + "CHAR", + "CHAR16_T", + "CHAR32_T", + "CLASS", + "COMPL", + "CONCEPT", + "CONST", + "CONSTEXPR", + "CONST_CAST", + "CONTINUE", + "DECLTYPE", + "DEFAULT", + "DELETE", + "DO", + "DOUBLE", + "DYNAMIC_CAST", + "ELSE", + "ENUM", + "EXPLICIT", + "EXPORT", + "EXTERN", 
+ "FALSE", + "FLOAT", + "FOR", + "FRIEND", + "GOTO", + "IF", + "INDEX", + "INLINE", + "INT", + "LONG", + "MUTABLE", + "NAMESPACE", + "NEW", + "NOEXCEPT", + "NOT", + "NOT_EQ", + "NULLPTR", + "OPERATOR", + "OR", + "OR_EQ", + "PRIVATE", + "PROTECTED", + "PUBLIC", + "REGISTER", + "REINTERPRET_CAST", + "RETURN", + "SHORT", + "SIGNED", + "SIZEOF", + "STATIC", + "STATIC_ASSERT", + "STATIC_CAST", + "STRUCT", + "SWITCH", + "TEMPLATE", + "THIS", + "THREAD_LOCAL", + "THROW", + "TRUE", + "TRY", + "TYPE", + "TYPEDEF", + "TYPEID", + "TYPENAME", + "UNION", + "UNSIGNED", + "USING", + "VIRTUAL", + "VOID", + "VOLATILE", + "WCHAR_T", + "WHILE", + "XOR", + "XOR_EQ" + ] + }, + "forbidden_prefixes": [ + "gsql_sys_" + ], + "invalid_characters": { + "description": "Characters that are problematic for TigerGraph API endpoints and GSQL identifiers", + "characters": [ + " ", + ".", + "-", + "@", + "#", + "$", + "%", + "^", + "&", + "*", + "(", + ")", + "[", + "]", + "{", + "}", + "|", + "\\", + "/", + "?", + "<", + ">", + ",", + ";", + ":", + "'", + "\"", + "`", + "~", + "!", + "=", + "+" + ], + "note": "TigerGraph identifiers should use alphanumeric characters and underscores only. Hyphens and dots are problematic for REST API endpoints." + } +} diff --git a/graflo/db/util.py b/graflo/db/util.py index 9bd91da..88108fc 100644 --- a/graflo/db/util.py +++ b/graflo/db/util.py @@ -1,12 +1,14 @@ """Database utilities for graph operations. This module provides utility functions for working with database operations, -including cursor handling and data serialization. +including cursor handling, data serialization, and schema management. Key Functions: - get_data_from_cursor: Retrieve data from a cursor with optional limit - serialize_value: Serialize non-serializable values (datetime, Decimal, etc.) - serialize_document: Serialize all values in a document dictionary + - load_reserved_words: Load reserved words for a database flavor + - sanitize_attribute_name: Sanitize attribute names to avoid reserved words Example: >>> # ArangoDB-specific AQL query (collection is ArangoDB terminology) @@ -15,10 +17,24 @@ >>> # Serialize datetime objects in a document >>> doc = {"id": 1, "created_at": datetime.now()} >>> serialized = serialize_document(doc) + >>> # Sanitize reserved words + >>> from graflo.onto import DBFlavor + >>> reserved = load_reserved_words(DBFlavor.TIGERGRAPH) + >>> sanitized = sanitize_attribute_name("SELECT", reserved) """ +from __future__ import annotations + +import json +import logging +from pathlib import Path + from arango.exceptions import CursorNextError +from graflo.onto import DBFlavor + +logger = logging.getLogger(__name__) + def get_data_from_cursor(cursor, limit=None): """Retrieve data from a cursor with optional limit. @@ -162,3 +178,115 @@ def json_serializer(obj): if not isinstance(obj, (list, dict)): raise TypeError(f"Type {type(obj)} not serializable") return serialized + + +def load_reserved_words(db_flavor: DBFlavor) -> set[str]: + """Load reserved words for a given database flavor. + + Args: + db_flavor: The database flavor to load reserved words for + + Returns: + Set of reserved words (uppercase) for the database flavor. + Returns empty set if no reserved words file exists or for unsupported flavors. 
+ """ + if db_flavor != DBFlavor.TIGERGRAPH: + # Currently only TigerGraph has reserved words defined + return set() + + # Load TigerGraph reserved words + json_path = Path(__file__).parent / "tigergraph" / "reserved_words.json" + try: + with open(json_path, "r") as f: + reserved_data = json.load(f) + except FileNotFoundError: + logger.warning( + f"Could not find reserved_words.json at {json_path}, " + f"no reserved word sanitization will be performed" + ) + return set() + except json.JSONDecodeError as e: + logger.warning( + f"Could not parse reserved_words.json: {e}, " + f"no reserved word sanitization will be performed" + ) + return set() + + reserved_words = set() + reserved_words.update( + reserved_data.get("reserved_words", {}).get("gsql_keywords", []) + ) + reserved_words.update( + reserved_data.get("reserved_words", {}).get("cpp_keywords", []) + ) + + # Return uppercase set for case-insensitive comparison + return {word.upper() for word in reserved_words} + + +def sanitize_attribute_name( + name: str, reserved_words: set[str], suffix: str = "_attr" +) -> str: + """Sanitize an attribute name to avoid reserved words. + + This function deterministically replaces reserved attribute names with + modified versions. The algorithm: + 1. Checks if the name (case-insensitive) is in the reserved words set + 2. If reserved, appends a suffix (default: "_attr") + 3. If the modified name is still reserved, appends a numeric suffix + incrementally until a non-reserved name is found + + The algorithm is deterministic: the same input always produces the same output. + + Args: + name: The attribute name to sanitize + reserved_words: Set of reserved words (uppercase) to avoid + suffix: Suffix to append if name is reserved (default: "_attr") + + Returns: + Sanitized attribute name that is not in the reserved words set + + Examples: + >>> reserved = {"SELECT", "FROM", "WHERE"} + >>> sanitize_attribute_name("name", reserved) + 'name' + >>> sanitize_attribute_name("SELECT", reserved) + 'SELECT_attr' + >>> sanitize_attribute_name("SELECT_attr", reserved) + 'SELECT_attr_1' + """ + if not name: + return name + + if not reserved_words: + return name + + name_upper = name.upper() + + # If name is not reserved, return as-is + if name_upper not in reserved_words: + return name + + # Name is reserved, try appending suffix + candidate = f"{name}{suffix}" + candidate_upper = candidate.upper() + + # If candidate is not reserved, use it + if candidate_upper not in reserved_words: + return candidate + + # Candidate is also reserved, append numeric suffix + counter = 1 + while True: + candidate = f"{name}{suffix}_{counter}" + candidate_upper = candidate.upper() + if candidate_upper not in reserved_words: + return candidate + counter += 1 + # Safety check to avoid infinite loop (should never happen in practice) + if counter > 1000: + logger.warning( + f"Could not find non-reserved name for '{name}' after 1000 attempts, " + f"returning '{candidate}'" + ) + return candidate diff --git a/pyproject.toml b/pyproject.toml index e362a47..402e5dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ "pydantic>=2.12.5", "pymgclient>=1.3.1", "python-arango>=8.1.2,<9", - "pytigergraph>=1.9.0", "redis>=5.0.0", "requests>=2.31.0", "sqlalchemy>=2.0.0", @@ -50,7 +49,7 @@ description = "A framework for transforming tabular (CSV, SQL) and hierarchical name = "graflo" readme = "README.md" requires-python = "~=3.10.0" -version = "1.3.12" +version = "1.3.13" [project.optional-dependencies] plot = [ diff --git 
a/test/db/connection/test_onto.py b/test/db/connection/test_onto.py index 2edac0d..9a36fad 100644 --- a/test/db/connection/test_onto.py +++ b/test/db/connection/test_onto.py @@ -152,7 +152,7 @@ class TestTigergraphConfigFromEnv: def test_from_env_default_prefix(self, monkeypatch): """Test default behavior without prefix - reads TIGERGRAPH_* variables.""" # Set environment variables with default prefix - monkeypatch.setenv("TIGERGRAPH_URI", "http://localhost:9000") + monkeypatch.setenv("TIGERGRAPH_URI", "http://localhost:14240") monkeypatch.setenv("TIGERGRAPH_USERNAME", "tigergraph_user") monkeypatch.setenv("TIGERGRAPH_PASSWORD", "tigergraph_pass") monkeypatch.setenv("TIGERGRAPH_DATABASE", "tigergraph_db") @@ -161,10 +161,11 @@ def test_from_env_default_prefix(self, monkeypatch): config = TigergraphConfig.from_env() # Verify values are read correctly - assert config.uri == "http://localhost:9000" + assert config.uri == "http://localhost:14240" assert config.username == "tigergraph_user" assert config.password == "tigergraph_pass" assert config.database == "tigergraph_db" + assert config.gs_port == 14240 def test_from_env_with_prefixes(self, monkeypatch): """Test behavior with two different prefixes - USER_ and LAKE_.""" @@ -195,3 +196,68 @@ def test_from_env_with_prefixes(self, monkeypatch): assert lake_config.username == "lake_tg" assert lake_config.password == "lake_tg_pass" assert lake_config.database == "lake_tg_db" + + def test_uri_without_scheme(self, monkeypatch): + """Test that URIs without scheme (host:port format) are handled correctly.""" + # Set URI without scheme + monkeypatch.setenv("TIGERGRAPH_URI", "localhost:14240") + monkeypatch.setenv("TIGERGRAPH_USERNAME", "testuser") + + config = TigergraphConfig.from_env() + + # Should normalize to include scheme + assert config.uri == "http://localhost:14240" + assert config.port == "14240" + assert config.hostname == "localhost" + assert config.protocol == "http" + + def test_port_conflict_warning(self, monkeypatch): + """Test that port conflicts between URI and gs_port generate a warning.""" + import warnings + + # Set both URI with port and gs_port with different value + monkeypatch.setenv("TIGERGRAPH_URI", "http://localhost:14240") + monkeypatch.setenv("TIGERGRAPH_GS_PORT", "9000") + monkeypatch.setenv("TIGERGRAPH_USERNAME", "testuser") + + # Capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + config = TigergraphConfig.from_env() + + # Should have a warning about port conflict + assert len(w) > 0 + assert any("Port conflict" in str(warning.message) for warning in w) + assert any("14240" in str(warning.message) for warning in w) + assert any("9000" in str(warning.message) for warning in w) + + # Port from URI should be preferred + assert config.uri == "http://localhost:14240" + assert config.port == "14240" + # gs_port should be updated to match URI port + assert config.gs_port == 14240 + + def test_port_no_conflict_when_matching(self, monkeypatch): + """Test that no warning is generated when URI port matches gs_port.""" + import warnings + + # Set both URI with port and gs_port with same value + monkeypatch.setenv("TIGERGRAPH_URI", "http://localhost:14240") + monkeypatch.setenv("TIGERGRAPH_GS_PORT", "14240") + monkeypatch.setenv("TIGERGRAPH_USERNAME", "testuser") + + # Capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + config = TigergraphConfig.from_env() + + # Should have no port conflict warnings + port_conflict_warnings = [ + 
warning for warning in w if "Port conflict" in str(warning.message) + ] + assert len(port_conflict_warnings) == 0 + + # Both should match + assert config.uri == "http://localhost:14240" + assert config.port == "14240" + assert config.gs_port == 14240 diff --git a/test/db/falkordbs/test_performance.py b/test/db/falkordbs/test_performance.py index c40bbc4..1a41feb 100644 --- a/test/db/falkordbs/test_performance.py +++ b/test/db/falkordbs/test_performance.py @@ -465,6 +465,7 @@ def reader(thread_id): assert len(errors) == 0 + @pytest.mark.slow def test_mixed_read_write_load(self, conn_conf, test_graph_name, clean_db): """Mixed concurrent read/write workload.""" _ = clean_db diff --git a/test/db/memgraphs/test_performance.py b/test/db/memgraphs/test_performance.py index 05b6d32..6fe09e0 100644 --- a/test/db/memgraphs/test_performance.py +++ b/test/db/memgraphs/test_performance.py @@ -468,6 +468,7 @@ def reader(thread_id): assert len(errors) == 0 + @pytest.mark.slow def test_mixed_read_write_load(self, conn_conf, test_graph_name, clean_db): """Mixed concurrent read/write workload.""" _ = clean_db diff --git a/test/db/tigergraphs/test_db_creation.py b/test/db/tigergraphs/test_db_creation.py index daf7e83..79b65f0 100644 --- a/test/db/tigergraphs/test_db_creation.py +++ b/test/db/tigergraphs/test_db_creation.py @@ -92,8 +92,8 @@ def test_schema_creation(conn_conf, test_graph_name, schema_obj): # getVertexTypes() and getEdgeTypes() require graph context via _ensure_graph_context with db_client._ensure_graph_context(test_graph_name): # Verify schema was created - vertex_types = db_client.conn.getVertexTypes(force=True) - edge_types = db_client.conn.getEdgeTypes(force=True) + vertex_types = db_client._get_vertex_types() + edge_types = db_client._get_edge_types() # Check expected types exist assert len(vertex_types) > 0, "No vertex types created" diff --git a/test/db/tigergraphs/test_db_index.py b/test/db/tigergraphs/test_db_index.py index c18552a..727c0c3 100644 --- a/test/db/tigergraphs/test_db_index.py +++ b/test/db/tigergraphs/test_db_index.py @@ -23,7 +23,7 @@ def test_create_vertex_index(conn_conf, schema_obj, test_graph_name): # Note: We use dbnames (Author, ResearchField) not vertex names (author, researchField) with ConnectionManager(connection_config=conn_conf) as db_client: # Verify vertex types exist (using dbnames) - vertex_types = db_client.conn.getVertexTypes(force=True) + vertex_types = db_client._get_vertex_types() assert "Author" in vertex_types, "Vertex type 'Author' not found" assert "ResearchField" in vertex_types, "Vertex type 'ResearchField' not found" @@ -38,21 +38,21 @@ def test_create_vertex_index(conn_conf, schema_obj, test_graph_name): # Clean up: drop job if it exists (ignore errors) try: - db_client.conn.gsql(drop_job_cmd) + db_client._execute_gsql(drop_job_cmd) except Exception: pass try: # Create the job - db_client.conn.gsql(create_job_cmd) + db_client._execute_gsql(create_job_cmd) # Run the job - db_client.conn.gsql(run_job_cmd) + db_client._execute_gsql(run_job_cmd) except Exception as e: # Clean up on failure try: - db_client.conn.gsql(drop_job_cmd) + db_client._execute_gsql(drop_job_cmd) except Exception: pass pytest.fail(f"Failed to create or run schema change job: {e}") diff --git a/test/db/tigergraphs/test_reserved_words.py b/test/db/tigergraphs/test_reserved_words.py new file mode 100644 index 0000000..12844c5 --- /dev/null +++ b/test/db/tigergraphs/test_reserved_words.py @@ -0,0 +1,446 @@ +"""Tests for reserved word sanitization in PostgreSQL schema 
inference. + +This module tests that reserved words are properly sanitized when inferring schemas +for TigerGraph, including: +- Vertex name sanitization +- Attribute name sanitization +- Edge reference updates +- Resource apply list updates + +Note: TigerGraph does NOT support quoted identifiers for reserved words (unlike PostgreSQL). +Therefore, we must sanitize reserved words by appending suffixes like "_vertex" for vertex +names and "_attr" for attribute names. This is different from PostgreSQL which allows +quoted identifiers like "SELECT" to use reserved words as column names. + +These tests create Schema objects directly without requiring PostgreSQL connections, +since the sanitization logic is independent of the database source. +""" + +import logging + +import pytest + +from graflo.architecture.edge import Edge, EdgeConfig, WeightConfig +from graflo.architecture.resource import Resource +from graflo.architecture.schema import Schema, SchemaMetadata +from graflo.architecture.vertex import Field, FieldType, Vertex, VertexConfig +from graflo.db.postgres.schema_inference import PostgresSchemaInferencer +from graflo.onto import DBFlavor + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def schema_with_reserved_words(): + """Create a Schema object with reserved words for testing.""" + # Create vertices with reserved words + version_vertex = Vertex( + name="version", # "version" is a TigerGraph reserved word + fields=[ + Field(name="id", type=FieldType.INT), + Field(name="SELECT", type=FieldType.STRING), # Reserved word + Field(name="FROM", type=FieldType.STRING), # Reserved word + Field(name="WHERE", type=FieldType.STRING), # Reserved word + Field(name="name", type=FieldType.STRING), + ], + ) + + users_vertex = Vertex( + name="users", + fields=[ + Field(name="id", type=FieldType.INT), + Field(name="name", type=FieldType.STRING), + ], + ) + + vertex_config = VertexConfig( + vertices=[version_vertex, users_vertex], db_flavor=DBFlavor.TIGERGRAPH + ) + + # Create edge with reserved word attributes + version_users_edge = Edge( + source="version", # Will be sanitized + target="users", + weights=WeightConfig( + direct=[ + Field(name="SELECT", type=FieldType.STRING), # Reserved word + Field(name="FROM", type=FieldType.STRING), # Reserved word + ] + ), + ) + + edge_config = EdgeConfig(edges=[version_users_edge]) + + # Create resource with vertex reference + version_resource = Resource( + resource_name="version", + apply=[{"vertex": "version"}], # Will be sanitized + ) + + schema = Schema( + general=SchemaMetadata(name="test_schema"), + vertex_config=vertex_config, + edge_config=edge_config, + resources=[version_resource], + ) + + return schema + + +def test_vertex_name_sanitization_for_tigergraph(schema_with_reserved_words): + """Test that vertex names with reserved words are sanitized for TigerGraph.""" + schema = schema_with_reserved_words + + # Create inferencer with TigerGraph flavor + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + + # Sanitize the schema + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Check that "version" vertex was sanitized to "version_vertex" + vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] + assert "version_vertex" in vertex_names, ( + f"Expected 'version_vertex' in vertices after sanitization, got {vertex_names}" + ) + assert "version" not in vertex_names, ( + f"Original reserved word 'version' should not be in vertices, got {vertex_names}" + ) + + # Verify the sanitized vertex 
exists and has correct fields
+    # Use a default of None so a missing vertex fails the assert below with a
+    # clear message instead of raising StopIteration out of next()
+    version_vertex = next(
+        (
+            v
+            for v in sanitized_schema.vertex_config.vertices
+            if v.name == "version_vertex"
+        ),
+        None,
+    )
+    assert version_vertex is not None, "version_vertex should exist"
+
+    # Check that attribute names were also sanitized
+    field_names = [f.name for f in version_vertex.fields]
+    assert "SELECT_attr" in field_names, (
+        f"Expected 'SELECT_attr' in fields after sanitization, got {field_names}"
+    )
+    assert "FROM_attr" in field_names, (
+        f"Expected 'FROM_attr' in fields after sanitization, got {field_names}"
+    )
+    assert "WHERE_attr" in field_names, (
+        f"Expected 'WHERE_attr' in fields after sanitization, got {field_names}"
+    )
+    assert "SELECT" not in field_names, (
+        f"Original reserved word 'SELECT' should not be in fields, got {field_names}"
+    )
+
+
+def test_edge_references_updated_after_vertex_sanitization(schema_with_reserved_words):
+    """Test that edge source/target references are updated when vertex names are sanitized."""
+    schema = schema_with_reserved_words
+
+    inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH)
+    sanitized_schema = inferencer._sanitize_schema_attributes(schema)
+
+    # Find edge connecting version_vertex to users
+    edges = list(sanitized_schema.edge_config.edges_list())
+    version_users_edge = next(
+        (
+            e
+            for e in edges
+            if "version_vertex" in (e.source, e.target)
+            and "users" in (e.source, e.target)
+        ),
+        None,
+    )
+
+    assert version_users_edge is not None, (
+        "Edge between version_vertex and users should exist"
+    )
+    assert (
+        version_users_edge.source == "version_vertex"
+        or version_users_edge.target == "version_vertex"
+    ), (
+        f"Edge should reference 'version_vertex', got source={version_users_edge.source}, "
+        f"target={version_users_edge.target}"
+    )
+    assert "version" not in (
+        version_users_edge.source,
+        version_users_edge.target,
+    ), (
+        f"Edge should not reference original 'version' name, "
+        f"got source={version_users_edge.source}, target={version_users_edge.target}"
+    )
+
+    # Check that edge weight attributes were sanitized
+    assert version_users_edge.weights is not None, "Edge should have weights"
+    assert version_users_edge.weights.direct is not None, (
+        "Edge should have direct weights"
+    )
+    weight_field_names = [f.name for f in version_users_edge.weights.direct]
+    assert "SELECT_attr" in weight_field_names, (
+        f"Expected 'SELECT_attr' in weight fields, got {weight_field_names}"
+    )
+    assert "FROM_attr" in weight_field_names, (
+        f"Expected 'FROM_attr' in weight fields, got {weight_field_names}"
+    )
+
+
+def test_resource_apply_lists_updated_after_vertex_sanitization(
+    schema_with_reserved_words,
+):
+    """Test that resource apply lists reference sanitized vertex names."""
+    schema = schema_with_reserved_words
+
+    inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH)
+    sanitized_schema = inferencer._sanitize_schema_attributes(schema)
+
+    # Find resource for version table
+    version_resource = next(
+        (r for r in sanitized_schema.resources if r.resource_name == "version"), None
+    )
+    assert version_resource is not None, "version resource should exist"
+
+    # Check that apply list references sanitized vertex name
+    apply_str = str(version_resource.apply)
+    assert "version_vertex" in apply_str, (
+        f"Resource apply list should reference 'version_vertex', got {apply_str}"
+    )
+    # Check the actual apply item
+    assert len(version_resource.apply) > 0, "Resource should have apply items"
+    apply_item = version_resource.apply[0]
+    assert isinstance(apply_item, dict), 
"Apply item should be a dict" + assert apply_item.get("vertex") == "version_vertex", ( + f"Apply item should reference 'version_vertex', got {apply_item}" + ) + + +def test_arango_no_sanitization(schema_with_reserved_words): + """Test that ArangoDB flavor does not sanitize names (no reserved words).""" + schema = schema_with_reserved_words + + # Change schema to ArangoDB flavor + schema.vertex_config.db_flavor = DBFlavor.ARANGO + schema.edge_config.db_flavor = DBFlavor.ARANGO + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.ARANGO) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Check that "version" vertex name is NOT sanitized for ArangoDB + vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] + assert "version" in vertex_names, ( + f"Expected 'version' in vertices for ArangoDB (no sanitization), got {vertex_names}" + ) + assert "version_vertex" not in vertex_names, ( + f"Should not have 'version_vertex' for ArangoDB, got {vertex_names}" + ) + + # Check that attribute names are NOT sanitized for ArangoDB + version_vertex = next( + v for v in sanitized_schema.vertex_config.vertices if v.name == "version" + ) + field_names = [f.name for f in version_vertex.fields] + assert "SELECT" in field_names, ( + f"Expected 'SELECT' in fields for ArangoDB (no sanitization), got {field_names}" + ) + assert "SELECT_attr" not in field_names, ( + f"Should not have 'SELECT_attr' for ArangoDB, got {field_names}" + ) + + +def test_multiple_reserved_words_sanitization(schema_with_reserved_words): + """Test that multiple reserved words are all sanitized correctly.""" + schema = schema_with_reserved_words + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Check vertex name sanitization + vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] + assert "version_vertex" in vertex_names, ( + f"Expected 'version_vertex' after sanitization, got {vertex_names}" + ) + + # Check attribute name sanitization + version_vertex = next( + v for v in sanitized_schema.vertex_config.vertices if v.name == "version_vertex" + ) + field_names = [f.name for f in version_vertex.fields] + + # All reserved word attributes should be sanitized + reserved_attrs = ["SELECT", "FROM", "WHERE"] + sanitized_attrs = ["SELECT_attr", "FROM_attr", "WHERE_attr"] + + for reserved, sanitized in zip(reserved_attrs, sanitized_attrs): + assert sanitized in field_names, ( + f"Expected '{sanitized}' in fields after sanitization, got {field_names}" + ) + assert reserved not in field_names, ( + f"Original reserved word '{reserved}' should not be in fields, got {field_names}" + ) + + +def test_vertex_config_internal_mappings_updated(schema_with_reserved_words): + """Test that VertexConfig internal mappings are updated after sanitization.""" + schema = schema_with_reserved_words + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Check that _vertices_map uses sanitized names + assert "version_vertex" in sanitized_schema.vertex_config._vertices_map, ( + "VertexConfig._vertices_map should contain sanitized vertex name" + ) + assert "version" not in sanitized_schema.vertex_config._vertices_map, ( + "VertexConfig._vertices_map should not contain original reserved word" + ) + + # Check that vertex_set uses sanitized names + assert "version_vertex" in sanitized_schema.vertex_config.vertex_set, ( + 
"VertexConfig.vertex_set should contain sanitized vertex name" + ) + assert "version" not in sanitized_schema.vertex_config.vertex_set, ( + "VertexConfig.vertex_set should not contain original reserved word" + ) + + # Verify we can look up the vertex by sanitized name + version_vertex = sanitized_schema.vertex_config._vertices_map["version_vertex"] + assert version_vertex is not None, ( + "Should be able to look up vertex by sanitized name" + ) + assert version_vertex.name == "version_vertex", ( + f"Vertex name should be 'version_vertex', got {version_vertex.name}" + ) + + +def test_edge_finish_init_after_sanitization(schema_with_reserved_words): + """Test that edges are properly re-initialized after vertex name sanitization.""" + schema = schema_with_reserved_words + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Find edge connecting version_vertex to users + edges = list(sanitized_schema.edge_config.edges_list()) + version_users_edge = next( + ( + e + for e in edges + if "version_vertex" in (e.source, e.target) + and "users" in (e.source, e.target) + ), + None, + ) + + assert version_users_edge is not None, "Edge should exist" + + # After finish_init (called in _sanitize_schema_attributes), _source and _target should be set correctly + assert version_users_edge._source is not None, ( + "Edge._source should be set after finish_init" + ) + assert version_users_edge._target is not None, ( + "Edge._target should be set after finish_init" + ) + + # The _source and _target should reference the sanitized vertex name (via dbname lookup) + # Since dbname defaults to name, they should be the sanitized names + assert "version_vertex" in ( + version_users_edge._source, + version_users_edge._target, + ), ( + f"Edge internal references should use sanitized vertex name, " + f"got _source={version_users_edge._source}, _target={version_users_edge._target}" + ) + + +def test_tigergraph_schema_validation_with_reserved_words(schema_with_reserved_words): + """Test that sanitized schema has no reserved words.""" + schema = schema_with_reserved_words + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Verify schema structure is valid + assert sanitized_schema is not None, "Schema should be sanitized" + assert sanitized_schema.vertex_config is not None, ( + "Schema should have vertex_config" + ) + assert sanitized_schema.edge_config is not None, "Schema should have edge_config" + + # Verify all vertex names are sanitized (no reserved words) + vertex_names = [v.name for v in sanitized_schema.vertex_config.vertices] + reserved_words = inferencer.reserved_words + for vertex_name in vertex_names: + assert vertex_name.upper() not in reserved_words, ( + f"Vertex name '{vertex_name}' should not be a reserved word" + ) + + # Verify all edge source/target names are sanitized + edges = list(sanitized_schema.edge_config.edges_list()) + for edge in edges: + assert edge.source.upper() not in reserved_words, ( + f"Edge source '{edge.source}' should not be a reserved word" + ) + assert edge.target.upper() not in reserved_words, ( + f"Edge target '{edge.target}' should not be a reserved word" + ) + + # Verify all attribute names are sanitized + for vertex in sanitized_schema.vertex_config.vertices: + for field in vertex.fields: + assert field.name.upper() not in reserved_words, ( + f"Field name '{field.name}' in vertex '{vertex.name}' 
should not be a reserved word" + ) + + # Verify edge weight names are sanitized + for edge in edges: + if edge.weights and edge.weights.direct: + for weight_field in edge.weights.direct: + assert weight_field.name.upper() not in reserved_words, ( + f"Weight field name '{weight_field.name}' in edge " + f"'{edge.source}' -> '{edge.target}' should not be a reserved word" + ) + + logger.info("Schema validation passed - all names are sanitized for TigerGraph") + + +def test_indirect_edge_by_reference_sanitization(): + """Test that indirect edge 'by' references are sanitized.""" + # Create schema with indirect edge + version_vertex = Vertex( + name="version", + fields=[Field(name="id", type=FieldType.INT)], + ) + + users_vertex = Vertex( + name="users", + fields=[Field(name="id", type=FieldType.INT)], + ) + + vertex_config = VertexConfig( + vertices=[version_vertex, users_vertex], db_flavor=DBFlavor.TIGERGRAPH + ) + + # Create indirect edge with 'by' referencing reserved word vertex + indirect_edge = Edge( + source="users", + target="users", + by="version", # Will be sanitized + ) + + edge_config = EdgeConfig(edges=[indirect_edge]) + + schema = Schema( + general=SchemaMetadata(name="test_schema"), + vertex_config=vertex_config, + edge_config=edge_config, + resources=[], + ) + + inferencer = PostgresSchemaInferencer(db_flavor=DBFlavor.TIGERGRAPH) + sanitized_schema = inferencer._sanitize_schema_attributes(schema) + + # Find the indirect edge + edges = list(sanitized_schema.edge_config.edges_list()) + indirect_edge_sanitized = edges[0] + + # Check that 'by' reference was sanitized + assert indirect_edge_sanitized.by == "version_vertex", ( + f"Indirect edge 'by' should reference 'version_vertex', got {indirect_edge_sanitized.by}" + ) + assert indirect_edge_sanitized.by != "version", ( + f"Indirect edge 'by' should not reference original 'version', got {indirect_edge_sanitized.by}" + ) diff --git a/uv.lock b/uv.lock index dad1a2e..e2641ea 100644 --- a/uv.lock +++ b/uv.lock @@ -11,20 +11,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] -[[package]] -name = "anyio" -version = "4.12.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup" }, - { name = "idna" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, -] - [[package]] name = "async-timeout" version = "5.0.1" @@ -253,7 +239,7 @@ wheels = [ [[package]] name = "graflo" -version = "1.3.12" +version = "1.3.13" source = { editable = "." 
} dependencies = [ { name = "click" }, @@ -269,7 +255,6 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pymgclient" }, { name = "python-arango" }, - { name = "pytigergraph" }, { name = "redis" }, { name = "requests" }, { name = "sqlalchemy" }, @@ -319,7 +304,6 @@ requires-dist = [ { name = "pygraphviz", marker = "extra == 'plot'", specifier = ">=1.14" }, { name = "pymgclient", specifier = ">=1.3.1" }, { name = "python-arango", specifier = ">=8.1.2,<9" }, - { name = "pytigergraph", specifier = ">=1.9.0" }, { name = "redis", specifier = ">=5.0.0" }, { name = "requests", specifier = ">=2.31.0" }, { name = "sqlalchemy", specifier = ">=2.0.0" }, @@ -375,43 +359,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/83/3b1d03d36f224edded98e9affd0467630fc09d766c0e56fb1498cbb04a9b/griffe-1.15.0-py3-none-any.whl", hash = "sha256:6f6762661949411031f5fcda9593f586e6ce8340f0ba88921a0f2ef7a81eb9a3", size = 150705, upload-time = "2025-11-10T15:03:13.549Z" }, ] -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - [[package]] name = "identify" version = "2.6.15" @@ -1050,20 +997,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = 
"sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, ] -[[package]] -name = "pytigergraph" -version = "1.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "httpx" }, - { name = "requests" }, - { name = "validators" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/23/a037ea3e1cf44bfd85a2f8811a0011e84ca6641cf724527e47f6976d8017/pytigergraph-1.9.1.tar.gz", hash = "sha256:a3a2ce43999193bb2fec5eb904ddc5221972b342160d2cdc68cb7796685a6a79", size = 221961, upload-time = "2025-11-04T17:21:07.179Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/94/154efe5df5e297b044385d17d0d41b05c1d3acb1801e20c4c7581843e223/pytigergraph-1.9.1-py3-none-any.whl", hash = "sha256:484dcc821a347b89a5104a1d7d39ea655d0e4d798c0d18a21cf476cf99715e9e", size = 295991, upload-time = "2025-11-04T17:21:05.628Z" }, -] - [[package]] name = "pytz" version = "2025.2" @@ -1313,15 +1246,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd", size = 131182, upload-time = "2025-12-11T15:56:38.584Z" }, ] -[[package]] -name = "validators" -version = "0.35.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/53/66/a435d9ae49850b2f071f7ebd8119dd4e84872b01630d6736761e6e7fd847/validators-0.35.0.tar.gz", hash = "sha256:992d6c48a4e77c81f1b4daba10d16c3a9bb0dbb79b3a19ea847ff0928e70497a", size = 73399, upload-time = "2025-05-01T05:42:06.7Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/6e/3e955517e22cbdd565f2f8b2e73d52528b14b8bcfdb04f62466b071de847/validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd", size = 44712, upload-time = "2025-05-01T05:42:04.203Z" }, -] - [[package]] name = "virtualenv" version = "20.35.4"