diff --git a/.env.example b/.env.example index c90c0f0..e8f228e 100644 --- a/.env.example +++ b/.env.example @@ -1,48 +1,40 @@ -# TableScanner Environment Variables -# Copy this file to .env and fill in your actual values - -# ============================================================================= -# AUTHENTICATION -# ============================================================================= -# KBase Service Authentication Token -# For development testing, use your personal token from KBase -KB_SERVICE_AUTH_TOKEN=YOUR_KBASE_TOKEN_HERE - -# ============================================================================= -# CACHE SETTINGS -# ============================================================================= -# Cache directory for storing downloaded SQLite databases +# TableScanner Environment Configuration +# Copy this file to .env and fill in your values + +# REQUIRED: KBase authentication token for API access +# Get your token from: https://narrative.kbase.us/#auth/account +KB_SERVICE_AUTH_TOKEN=your_token_here + +# Cache directory for downloaded SQLite databases +# Default: /tmp/tablescanner_cache CACHE_DIR=/tmp/tablescanner_cache -# Maximum age of cached files in hours (default: 24) +# Maximum age of cached files in hours before re-download +# Default: 24 CACHE_MAX_AGE_HOURS=24 -# ============================================================================= -# KBASE SERVICE URLS -# ============================================================================= -# KBase Workspace Service URL -WORKSPACE_URL=https://appdev.kbase.us/services/ws - -# Base URL for KBase services -KBASE_ENDPOINT=https://appdev.kbase.us/services +# Enable debug mode with verbose logging +# Default: false +DEBUG=false -# KBase Blobstore/Shock service URL -BLOBSTORE_URL=https://appdev.kbase.us/services/shock-api +# KBase environment (appdev, ci, prod) +# Default: appdev +KB_ENV=appdev -# ============================================================================= -# APPLICATION SETTINGS -# ============================================================================= -# Enable debug mode (true/false) -DEBUG=false +# CORS allowed origins (JSON array format) +# Use ["*"] for all origins (development only) +# For production, specify exact origins: ["https://kbase.us", "https://narrative.kbase.us"] +CORS_ORIGINS=["*"] -# ============================================================================= -# TEST DATA (AppDev) -# ============================================================================= -# Test BERDLTable object: 76990/ADP1Test -# Test pangenome: GCF_000368685.1 -# Narrative: https://appdev.kbase.us/narrative/76990 +# KBase service URLs (usually don't need to change) WORKSPACE_URL=https://kbase.us/services/ws +KBASE_ENDPOINT=https://kbase.us/services +BLOBSTORE_URL=https://kbase.us/services/shock-api + +# Timeout settings (seconds) +DOWNLOAD_TIMEOUT_SECONDS=30.0 +KBASE_API_TIMEOUT_SECONDS=10.0 # Root path for proxy deployment (e.g., "/services/berdl_table_scanner") -# Leave empty if running at root path (i.e., "/") for local dev -ROOT_PATH=/services/berdl_table_scanner +# Leave empty for standalone deployment +KB_SERVICE_ROOT_PATH= diff --git a/.gitignore b/.gitignore index 5ec0315..e86f279 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,11 @@ trash/ docs/DEMO_SCRIPT.md docs/QUICKSTART.md docs/internal/ +DATABASE_SCHEMA.md +docs/personal/ +archive/ +docs/archive +dummy.db .DS_Store .idea @@ -31,3 +36,8 @@ lib/ # Cache directory cache/ + +# Project-specific artifacts +DATABASE_SCHEMA.md 
+*.webp +*.png diff --git a/Dockerfile b/Dockerfile index 67152c8..87ac643 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,21 @@ FROM ghcr.io/astral-sh/uv:python3.13-alpine -RUN apk --no-cache add curl +RUN apk --no-cache add curl git WORKDIR /app + +# Clone KBUtilLib (required external dependency) +# This creates /app/lib/KBUtilLib/ which is referenced by app/utils/workspace.py +RUN mkdir -p lib && \ + cd lib && \ + git clone https://github.com/cshenry/KBUtilLib.git && \ + cd .. + +# Add KBUtilLib to PYTHONPATH so it can be imported +ENV PYTHONPATH=/app/lib/KBUtilLib/src:${PYTHONPATH} + +# Copy application code and dependencies COPY app ./app COPY pyproject.toml /app/pyproject.toml RUN uv sync + EXPOSE 8000 -CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/README.md b/README.md index 4fc1c96..bd4697f 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,135 @@ # TableScanner -TableScanner is a microservice for providing filtered and paginated access to tabular data stored in KBase. It uses local SQLite caching and indexing to provide fast access to large datasets without loading them entirely into memory. +TableScanner is a production-grade microservice for querying tabular data from KBase SQLite databases. It provides a comprehensive DataTables Viewer-compatible API with advanced query capabilities, type-aware filtering, and performance optimizations. -## Functionality +## Features -The service provides two methods for data access: -1. **Hierarchical REST**: Path-based endpoints for navigating objects and tables using GET requests. -2. **Flat POST**: A single endpoint (`/table-data`) that accepts a JSON payload for all query parameters. +- **Data Access**: Query SQLite databases from KBase objects and handles. +- **Local Uploads**: Upload local SQLite files (`.db`, `.sqlite`) for temporary access and testing. +- **User-Driven Auth**: Secure access where each user provides their own KBase token. +- **Type-Aware Filtering**: Automatic numeric conversion for proper filtering results. +- **Advanced Operators**: Support for `eq`, `ne`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, `not_in`, `between`, `is_null`, `is_not_null`. +- **Aggregations**: `GROUP BY` support with `count`, `sum`, `avg`, `min`, `max`, `stddev`, `variance`, `distinct_count`. +- **Table Statistics**: Rich column statistics including null counts, distinct counts, min/max/mean, and sample values. +- **Full-Text Search**: FTS5 support with automatic virtual table creation. +- **Automatic Operations**: Lifecycle management for connection pooling, query caching, and automatic disk cleanup. -## Architecture +## Quick Start -TableScanner operates as a bridge between KBase storage and client applications: -1. **Data Fetching**: Retrieves SQLite databases from the KBase Blobstore. -2. **Local Caching**: Stores databases locally to avoid repeated downloads. -3. **Indexing**: Creates indices on-the-fly for all table columns to optimize query performance. -4. **API Layer**: A FastAPI application that handles requests and executes SQL queries against the local cache. +### Production (Docker) -Technical details on race conditions and concurrency handling are available in [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md). - -## Setup - -### Production ```bash docker compose up --build -d ``` -The service will be available at `http://localhost:8000`. 
API documentation is at `/docs`. +The service will be available at `http://localhost:8000`. API documentation is available at `/docs`. ### Development + ```bash cp .env.example .env -bash scripts/dev.sh +# Edit .env and set local development parameters +./scripts/dev.sh +``` + +## Authentication + +**Each user must provide their own KBase authentication token.** The service prioritizes user-provided tokens over shared service tokens. + +- **Header (Recommended)**: `Authorization: Bearer ` +- **Cookie**: `kbase_session=` (Used by DataTables Viewer) +- **Legacy Fallback**: `KB_SERVICE_AUTH_TOKEN` in `.env` is for **local testing only**. + +## API Usage Examples + +### 1. Upload a Local Database +Upload a SQLite file to receive a temporary handle. + +```bash +curl -X POST "http://localhost:8000/upload" \ + -F "file=@/path/to/my_data.db" +# Returns: {"handle": "local:a1b2-c3d4", ...} +``` + +### 2. List Tables +Works with KBase UPA or the local handle returned above. + +```bash +curl -H "Authorization: Bearer $KB_TOKEN" \ + "http://localhost:8000/object/76990/7/2/tables" ``` -## API Usage +### 3. Get Table Statistics +Retrieve detailed column metrics and sample values. -### Path-based REST -List tables: -`GET /object/{upa}/tables` +```bash +curl -H "Authorization: Bearer $KB_TOKEN" \ + "http://localhost:8000/object/76990/7/2/tables/Genes/stats" +``` + +### 4. Advanced Query (POST) +Comprehensive filtering and pagination. + +```bash +curl -X POST -H "Authorization: Bearer $KB_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "berdl_table_id": "76990/7/2", + "table_name": "Genes", + "limit": 100, + "filters": [ + {"column": "gene_length", "operator": "gt", "value": 1000} + ] + }' \ + "http://localhost:8000/table-data" +``` + +## Performance & Optimization + +- **Gzip Compression**: Compresses large responses (>1KB) to reduce bandwidth usage. +- **High-Performance JSON**: Uses `orjson` for fast JSON serialization. +- **Parallel Metadata Fetching**: Retrieves table metadata concurrently for fast listing. +- **Metadata Caching**: Caches object types locally to minimize KBase API calls. +- **Connection Pooling**: Reuses database connections for up to 10 minutes of inactivity. +- **Automatic Cleanup**: Expired caches are purged on startup. Uploaded databases automatically expire after **1 hour**. +- **Query Caching**: 5-minute TTL, max 1000 entries per instance. +- **Atomic Renaming**: Ensures file integrity during downloads and uploads. -Query table data: -`GET /object/{upa}/tables/{table_name}/data?limit=100` +## Documentation -### Flat POST -Query table data: -`POST /table-data` +- **[API Reference](docs/API.md)** - Complete API documentation with examples +- **[Architecture Dictionary](docs/ARCHITECTURE.md)** - System design and technical overview +- **[Deployment Readiness](docs/internal/DEPLOYMENT_READINESS.md)** - Checklist for production deployment +- **[Contributing Guide](docs/CONTRIBUTING.md)** - Setup, testing, and contribution guidelines -Payload example: -```json -{ - "berdl_table_id": "76990/7/2", - "table_name": "Genes", - "limit": 100 -} +## Testing + +```bash +# Set PYTHONPATH and run all tests +PYTHONPATH=. pytest + +# Run integration tests for local upload +PYTHONPATH=. pytest tests/integration/test_local_upload.py ``` ## Project Structure -- `app/`: Application logic and routes. -- `app/utils/`: Utilities for caching, SQLite operations, and Workspace integration. -- `docs/`: Technical documentation. -- `scripts/`: Client examples and utility scripts. 
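For programmatic access, the same flat POST query shown in the README's Advanced Query example can be issued from Python. This is a minimal client sketch, assuming the service is running locally on port 8000 as in the Quick Start, that the `requests` package is available, and that a KBase token is exported as `KB_TOKEN`; the payload mirrors the curl example above.

```python
import os

import requests

# Minimal sketch of the flat POST /table-data endpoint shown above.
# Assumes a local instance on port 8000 and a KBase token in $KB_TOKEN.
BASE_URL = "http://localhost:8000"
TOKEN = os.environ["KB_TOKEN"]

payload = {
    "berdl_table_id": "76990/7/2",
    "table_name": "Genes",
    "limit": 100,
    "filters": [
        {"column": "gene_length", "operator": "gt", "value": 1000},
    ],
}

resp = requests.post(
    f"{BASE_URL}/table-data",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
print(f"{body['total_count']} rows match; {len(body['data'])} returned")
```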
+ +``` +TableScanner/ +├── app/ +│ ├── main.py # FastAPI application & Lifecycle handlers +│ ├── routes.py # API endpoints & Auth logic +│ ├── models.py # Pydantic (V2) models +│ ├── config.py # Configuration (BaseSettings) +│ ├── services/ +│ │ ├── data/ # Query & Connection pooling logic +│ │ └── db_helper.py # Secure handle resolution +│ └── utils/ # SQLite, KBase Client, and Cache utilities +├── docs/ # API and Architectural documentation +├── tests/ # Unit & Integration tests +├── scripts/ # Development helper scripts +└── static/ # Static assets for the viewer +``` ## License -MIT License. + +MIT License diff --git a/app/config.py b/app/config.py index 37fb984..85dc665 100644 --- a/app/config.py +++ b/app/config.py @@ -5,7 +5,7 @@ All KBase service URLs and authentication settings are managed here. """ -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic import Field @@ -15,13 +15,19 @@ class Settings(BaseSettings): Create a .env file based on .env.example to configure locally. """ + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=True + ) # ========================================================================== # AUTHENTICATION # ========================================================================== - KB_SERVICE_AUTH_TOKEN: str = Field( - ..., - description="KBase authentication token for API access" + KB_SERVICE_AUTH_TOKEN: str | None = Field( + default=None, + description="KBase authentication token for service-to-service API access (optional if using header/cookie auth)" ) # ========================================================================== @@ -59,14 +65,27 @@ class Settings(BaseSettings): default=False, description="Enable debug mode with verbose logging" ) + KB_ENV: str = Field( + default="appdev", + description="KBase environment (appdev, ci, prod)" + ) + CORS_ORIGINS: list[str] = Field( + default=["*"], + description="List of allowed origins for CORS. Use ['*'] for all." + ) # Root path for proxy deployment (e.g., "/services/berdl_table_scanner") ROOT_PATH: str = "" - - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - case_sensitive = True + + # Timeout settings + DOWNLOAD_TIMEOUT_SECONDS: float = Field( + default=30.0, + description="Timeout in seconds for downloading databases" + ) + KBASE_API_TIMEOUT_SECONDS: float = Field( + default=10.0, + description="Timeout in seconds for KBase API calls" + ) # Global settings instance - loaded at module import diff --git a/app/config_constants.py b/app/config_constants.py new file mode 100644 index 0000000..2f5124e --- /dev/null +++ b/app/config_constants.py @@ -0,0 +1,20 @@ +""" +Configuration constants for TableScanner. +""" + +# Default values +DEFAULT_LIMIT = 100 +MAX_LIMIT = 500000 +DEFAULT_OFFSET = 0 +DEFAULT_SORT_ORDER = "ASC" + +# Cache settings +CACHE_TTL_SECONDS = 300 # 5 minutes +CACHE_MAX_ENTRIES = 1000 +INDEX_CACHE_TTL = 3600 # 1 hour + +# Timeout settings +KBASE_API_TIMEOUT_SECONDS = 30 + +# API Version +API_VERSION = "2.0" diff --git a/app/db/__init__.py b/app/db/__init__.py new file mode 100644 index 0000000..3038256 --- /dev/null +++ b/app/db/__init__.py @@ -0,0 +1,5 @@ +""" +Database module for Config Control Plane. + +Provides SQLite-based persistent storage for configuration records. 
+""" diff --git a/app/db/schema.sql b/app/db/schema.sql new file mode 100644 index 0000000..db58d0a --- /dev/null +++ b/app/db/schema.sql @@ -0,0 +1,107 @@ +-- ============================================================================= +-- Config Control Plane Database Schema +-- ============================================================================= +-- +-- Stores configuration records with full lifecycle support: +-- - draft: Work in progress, modifiable +-- - proposed: Ready for review, read-only +-- - published: Production-ready, locked +-- - deprecated: Marked for removal +-- - archived: Historical reference +-- +-- ============================================================================= + +-- Config records with full lifecycle support +CREATE TABLE IF NOT EXISTS config_records ( + id TEXT PRIMARY KEY, + source_type TEXT NOT NULL CHECK(source_type IN ('object', 'handle', 'builtin', 'custom')), + source_ref TEXT NOT NULL, + fingerprint TEXT, + version INTEGER NOT NULL DEFAULT 1, + + -- Lifecycle + state TEXT NOT NULL DEFAULT 'draft' CHECK(state IN ('draft', 'proposed', 'published', 'deprecated', 'archived')), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_by TEXT NOT NULL, + published_at TIMESTAMP, + published_by TEXT, + + -- Content + config_json TEXT NOT NULL, -- Full DataTypeConfig JSON + extends_id TEXT REFERENCES config_records(id), + overlays_json TEXT, + + -- Metadata + object_type TEXT, + ai_provider TEXT, + confidence REAL DEFAULT 1.0, + generation_time_ms REAL, + + -- Audit + change_summary TEXT, + change_author TEXT, + + -- Unique constraint on source_ref + fingerprint + version + UNIQUE(source_ref, fingerprint, version) +); + +-- Audit log for all changes +CREATE TABLE IF NOT EXISTS config_audit_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + config_id TEXT NOT NULL REFERENCES config_records(id) ON DELETE CASCADE, + action TEXT NOT NULL, + old_state TEXT, + new_state TEXT, + changed_by TEXT NOT NULL, + changed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + diff_json TEXT, + reason TEXT +); + +-- User overrides for personalized config preferences +CREATE TABLE IF NOT EXISTS user_config_overrides ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + source_ref TEXT NOT NULL, + override_config_json TEXT NOT NULL, -- Partial or full config override + priority INTEGER DEFAULT 100, -- Lower = higher priority + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_active BOOLEAN DEFAULT 1, + UNIQUE(user_id, source_ref) +); + +-- Config version history for diff visualization +CREATE TABLE IF NOT EXISTS config_version_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + config_id TEXT NOT NULL REFERENCES config_records(id) ON DELETE CASCADE, + version INTEGER NOT NULL, + config_json TEXT NOT NULL, + snapshot_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(config_id, version) +); + +-- Config test results for validation against real data +CREATE TABLE IF NOT EXISTS config_test_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + config_id TEXT NOT NULL REFERENCES config_records(id) ON DELETE CASCADE, + test_type TEXT NOT NULL, -- 'schema', 'data', 'performance', 'integration' + test_status TEXT NOT NULL, -- 'passed', 'failed', 'warning' + test_details_json TEXT, -- Detailed test results + tested_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + tested_by TEXT, + execution_time_ms REAL +); + +-- Indexes for fast lookups +CREATE INDEX IF NOT EXISTS 
idx_config_source ON config_records(source_type, source_ref); +CREATE INDEX IF NOT EXISTS idx_config_state ON config_records(state); +CREATE INDEX IF NOT EXISTS idx_config_fingerprint ON config_records(fingerprint); +CREATE INDEX IF NOT EXISTS idx_config_object_type ON config_records(object_type); +CREATE INDEX IF NOT EXISTS idx_config_extends ON config_records(extends_id); +CREATE INDEX IF NOT EXISTS idx_audit_config_id ON config_audit_log(config_id); +CREATE INDEX IF NOT EXISTS idx_audit_changed_at ON config_audit_log(changed_at); +CREATE INDEX IF NOT EXISTS idx_user_override_user ON user_config_overrides(user_id, source_ref); +CREATE INDEX IF NOT EXISTS idx_version_history_config ON config_version_history(config_id, version); +CREATE INDEX IF NOT EXISTS idx_test_results_config ON config_test_results(config_id, test_type); \ No newline at end of file diff --git a/app/exceptions.py b/app/exceptions.py new file mode 100644 index 0000000..1b707e9 --- /dev/null +++ b/app/exceptions.py @@ -0,0 +1,29 @@ +""" +Custom exceptions for TableScanner. +""" + +class TableScannerError(Exception): + """Base exception for TableScanner.""" + pass + +class TableNotFoundError(TableScannerError): + """Raised when a requested table does not exist.""" + def __init__(self, table_name: str, available_tables: list[str] | None = None): + msg = f"Table '{table_name}' not found" + if available_tables: + msg += f". Available: {available_tables}" + super().__init__(msg) + self.table_name = table_name + +class ColumnNotFoundError(TableScannerError): + """Raised when a requested column does not exist.""" + def __init__(self, column_name: str, table_name: str): + super().__init__(f"Column '{column_name}' not found in table '{table_name}'") + +class InvalidFilterError(TableScannerError): + """Raised when a filter configuration is invalid.""" + pass + +class DatabaseAccessError(TableScannerError): + """Raised when database file cannot be accessed or opened.""" + pass diff --git a/app/main.py b/app/main.py index 8ed4284..d7a84cd 100644 --- a/app/main.py +++ b/app/main.py @@ -7,15 +7,37 @@ Run with: uv run fastapi dev app/main.py """ +import os from pathlib import Path -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from fastapi.security import HTTPBearer, APIKeyCookie +from fastapi.responses import JSONResponse, ORJSONResponse from app.routes import router from app.config import settings +from app.exceptions import TableNotFoundError, InvalidFilterError +from contextlib import asynccontextmanager +from app.utils.cache import cleanup_old_caches + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup: Clean up old caches and uploads + try: + cleanup_result = cleanup_old_caches(Path(settings.CACHE_DIR)) + # Use print or logger (logger is better if configured, but print works for startup) + import logging + logging.getLogger("uvicorn").info(f"Startup cleanup: {cleanup_result}") + except Exception as e: + import logging + logging.getLogger("uvicorn").warning(f"Startup cleanup failed: {e}") + yield + def create_app() -> FastAPI: """ Application factory function. @@ -30,23 +52,41 @@ def create_app() -> FastAPI: # Configure root_path for KBase dynamic services # KBase services are often deployed at /services/service_name # Pydantic Settings management or manual environ check can handle this. 
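Stepping back to the Config Control Plane schema added in `app/db/schema.sql` above: the sketch below shows one way the lifecycle tables could be exercised with the standard `sqlite3` module. It is illustrative only — the helpers `init_config_db` and `create_draft` are not part of the service — but the table and column names follow the schema, and it assumes execution from the repository root so `app/db/schema.sql` resolves.

```python
import json
import sqlite3
import uuid
from pathlib import Path

SCHEMA_PATH = Path("app/db/schema.sql")  # assumes the script is run from the repo root


def init_config_db(db_file: str = "config_control_plane.db") -> sqlite3.Connection:
    """Apply the schema; CREATE TABLE/INDEX IF NOT EXISTS makes this idempotent."""
    conn = sqlite3.connect(db_file)
    conn.executescript(SCHEMA_PATH.read_text())
    return conn


def create_draft(conn: sqlite3.Connection, source_ref: str, config: dict, user: str) -> str:
    """Insert a config record in the default 'draft' state and mirror it in the audit log."""
    config_id = str(uuid.uuid4())
    conn.execute(
        "INSERT INTO config_records (id, source_type, source_ref, config_json, created_by) "
        "VALUES (?, 'object', ?, ?, ?)",
        (config_id, source_ref, json.dumps(config), user),
    )
    conn.execute(
        "INSERT INTO config_audit_log (config_id, action, new_state, changed_by) "
        "VALUES (?, 'create', 'draft', ?)",
        (config_id, user),
    )
    conn.commit()
    return config_id


conn = init_config_db()
record_id = create_draft(conn, "76990/7/2", {"columns": {"gene_id": {"label": "Gene ID"}}}, "devuser")
print(record_id)
```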
- import os + # Pydantic Settings management or manual environ check can handle this. root_path = os.environ.get("KB_SERVICE_ROOT_PATH", "") description = """ ## TableScanner API - A FastAPI service for querying BERDL table data from KBase. + A FastAPI service for querying tabular data from KBase SQLite databases. + Provides a comprehensive DataTables Viewer-compatible API with advanced + query capabilities, type-aware filtering, and performance optimizations. ### Features - - List pangenomes from BERDLTables objects - - List tables within a pangenome + - List tables in KBase objects - Query table data with filtering, sorting, and pagination - - Local caching for performance + - Type-aware filtering with automatic numeric conversion + - Advanced filter operators (eq, ne, gt, gte, lt, lte, like, ilike, in, not_in, between, is_null, is_not_null) + - Aggregations with GROUP BY support + - Full-text search (FTS5) + - Column statistics and schema information + - Query result caching for performance + - Local database caching + - Connection pooling with automatic lifecycle management ### Authentication - Pass your KBase auth token in the `Authorization` header. + Authentication can be provided in three ways (in order of priority): + 1. **Authorization header**: `Authorization: Bearer ` or `Authorization: ` + 2. **kbase_session cookie**: Set the `kbase_session` cookie with your KBase session token + 3. **Service token**: Configure `KB_SERVICE_AUTH_TOKEN` environment variable (for service-to-service calls) + + **Using Swagger UI**: Click the "Authorize" button (🔒) at the top of this page to enter your authentication token. + - For **BearerAuth**: Enter your KBase token (Bearer prefix is optional) + - For **CookieAuth**: Set the `kbase_session` cookie in your browser's developer tools + + Note: Cookie authentication may have limitations in Swagger UI due to browser security restrictions. + For best results, use the Authorization header method. """ tags_metadata = [ @@ -72,6 +112,23 @@ def create_app() -> FastAPI: }, ] + # Define security schemes for Swagger UI + # These will show up in the "Authorize" button + security_schemes = { + "BearerAuth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "Token", + "description": "KBase authentication token. Enter your token (Bearer prefix optional)." + }, + "CookieAuth": { + "type": "apiKey", + "in": "cookie", + "name": "kbase_session", + "description": "KBase session cookie. Set this in your browser's developer tools." 
+ } + } + app = FastAPI( title="TableScanner", root_path=root_path, @@ -80,14 +137,19 @@ def create_app() -> FastAPI: openapi_tags=tags_metadata, docs_url="/docs", redoc_url="/redoc", + lifespan=lifespan, + default_response_class=ORJSONResponse ) + # Enable Gzip compression for responses > 1KB + app.add_middleware(GZipMiddleware, minimum_size=1000) + # Add CORS middleware to allow cross-origin requests - # This is necessary when viewer.html is opened from file:// or different origin + # Update CORS middleware to allow requests from the frontend app.add_middleware( CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, + allow_origins=settings.CORS_ORIGINS, + allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) @@ -95,8 +157,89 @@ def create_app() -> FastAPI: # Store settings in app state for access throughout the application app.state.settings = settings + # Exception Handlers + @app.exception_handler(TableNotFoundError) + async def table_not_found_handler(request: Request, exc: TableNotFoundError): + return JSONResponse( + status_code=404, + content={"detail": str(exc)}, + ) + + @app.exception_handler(InvalidFilterError) + async def invalid_filter_handler(request: Request, exc: InvalidFilterError): + return JSONResponse( + status_code=422, + content={"detail": str(exc)}, + ) + + @app.exception_handler(Exception) + async def global_exception_handler(request: Request, exc: Exception): + """ + Global exception handler to catch any unhandled exceptions. + Provides detailed error messages in debug mode. + """ + import logging + import traceback + logger = logging.getLogger(__name__) + + # Log the full exception with traceback + logger.error(f"Unhandled exception: {exc}", exc_info=True) + + # Return detailed error in debug mode, generic message otherwise + if settings.DEBUG: + detail = f"{str(exc)}\n\nTraceback:\n{traceback.format_exc()}" + else: + detail = str(exc) if str(exc) else "An internal server error occurred" + + return JSONResponse( + status_code=500, + content={"detail": detail}, + ) + # Include API routes app.include_router(router) + + # Add security schemes to OpenAPI schema after routes are included + def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + from fastapi.openapi.utils import get_openapi + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + tags=tags_metadata, + ) + # Add security schemes to enable "Authorize" button in Swagger UI + openapi_schema.setdefault("components", {}) + openapi_schema["components"]["securitySchemes"] = security_schemes + + # Mark secured endpoints so Swagger UI "Try it out" + generated curl include auth headers. + # We only apply this to endpoints that actually use KBase auth. 
+ secured_paths_prefixes = ( + "/object/", + ) + secured_exact_paths = { + "/table-data", + } + security_requirement = [{"BearerAuth": []}, {"CookieAuth": []}] + + for path, methods in (openapi_schema.get("paths") or {}).items(): + needs_security = path in secured_exact_paths or any( + path.startswith(prefix) for prefix in secured_paths_prefixes + ) + if not needs_security: + continue + for method, operation in (methods or {}).items(): + if method.lower() not in {"get", "post", "put", "patch", "delete", "options", "head"}: + continue + if isinstance(operation, dict): + operation.setdefault("security", security_requirement) + app.openapi_schema = openapi_schema + return app.openapi_schema + + app.openapi = custom_openapi # Mount static files directory for viewer.html static_dir = Path(__file__).parent.parent / "static" diff --git a/app/models.py b/app/models.py index f24fbfd..f9e68f4 100644 --- a/app/models.py +++ b/app/models.py @@ -2,6 +2,8 @@ from typing import Any, Literal from pydantic import BaseModel, Field +from app.config_constants import MAX_LIMIT + # ============================================================================= # REQUEST MODELS @@ -20,10 +22,10 @@ class TableDataRequest(BaseModel): description="BERDLTables object reference", examples=["76990/ADPITest"] ) - columns: str | None = Field( + columns: str | list[str] | None = Field( "all", - description="Comma-separated list of columns to select or 'all'", - examples=["gene_id, gene_name"] + description="Comma-separated list of columns to select or 'all', or list of strings", + examples=["gene_id, gene_name", ["gene_id", "gene_name"]] ) col_filter: dict[str, str] | None = Field( None, @@ -38,7 +40,7 @@ class TableDataRequest(BaseModel): limit: int = Field( 100, ge=1, - le=500000, + le=MAX_LIMIT, description="Maximum rows to return" ) offset: int = Field( @@ -73,6 +75,20 @@ class TableDataRequest(BaseModel): description="KBase environment" ) + # Advanced Features (System Overhaul) + filters: list[FilterRequest] | None = Field( + None, + description="Advanced filter specifications" + ) + aggregations: list[AggregationRequest] | None = Field( + None, + description="Aggregation specifications" + ) + group_by: list[str] | None = Field( + None, + description="Columns for GROUP BY clause" + ) + model_config = { "json_schema_extra": { "example": { @@ -109,6 +125,7 @@ class TableListResponse(BaseModel): """Response for listing tables in a database.""" berdl_table_id: str | None = Field(None, description="BERDLTable object reference", examples=["76990/7/2"]) handle_ref: str | None = Field(None, description="Blobstore handle reference", examples=["KBH_248028"]) + object_type: str | None = Field(None, description="KBase object type", examples=["KBaseGeneDataLakes.BERDLTables-1.0"]) tables: list[TableInfo] = Field( default_factory=list, description="List of available tables", @@ -118,50 +135,57 @@ class TableListResponse(BaseModel): ]] ) source: str | None = Field(None, description="Data source", examples=["Cache"]) - - -class PangenomeInfo(BaseModel): - """Information about a pangenome found in the SQLite file.""" - pangenome_taxonomy: str | None = Field(None, description="Taxonomy of the pangenome", examples=["Escherichia coli"]) - genome_count: int = Field(..., description="Number of genomes in the pangenome", examples=[42]) - source_berdl_id: str = Field(..., description="Source BERDL Table ID", examples=["76990/7/2"]) - user_genomes: list[str] = Field( - default_factory=list, - description="List of user-provided genome 
references", - examples=[["76990/1/1", "76990/2/1"]] + + # Viewer integration fields + config_fingerprint: str | None = Field( + None, + description="Fingerprint of cached viewer config (if exists)", + examples=["v1_auto_abc123def456"] ) - berdl_genomes: list[str] = Field( - default_factory=list, - description="List of BERDL/Datalake genome identifiers", - examples=[["GLM4:EC_G1", "GLM4:EC_G2"]] + config_url: str | None = Field( + None, + description="URL to retrieve generated viewer config", + examples=["/config/generated/v1_auto_abc123def456"] + ) + has_cached_config: bool = Field( + False, + description="Whether a viewer config is cached for this database" ) - handle_ref: str | None = Field( + + # Schema information for immediate viewer use + schemas: dict | None = Field( None, - description="Blobstore handle reference for SQLite database", - examples=["KBH_248028"] + description="Column types per table: {table_name: {column: sql_type}}" ) - - -class PangenomesResponse(BaseModel): - """Response for listing pangenomes from a BERDLTables object.""" - berdl_table_id: str | None = Field(None, description="BERDLTable object reference", examples=["76990/7/2"]) - pangenomes: list[PangenomeInfo] = Field( - default_factory=list, - description="List of available pangenomes", - examples=[[ - { - "pangenome_taxonomy": "Escherichia coli", - "genome_count": 42, - "source_berdl_id": "76990/7/2", - "handle_ref": "KBH_248028" - } - ]] + + # Fallback config availability + has_builtin_config: bool = Field( + False, + description="Whether a built-in fallback config exists for this object type" + ) + builtin_config_id: str | None = Field( + None, + description="ID of the matching built-in config" + ) + + # Database metadata + database_size_bytes: int | None = Field( + None, + description="Size of the SQLite database file in bytes" ) - pangenome_count: int = Field( - 1, - description="Total number of pangenomes", - examples=[1] + total_rows: int = Field( + 0, + description="Total rows across all tables" ) + + # Versioning for backward compatibility + api_version: str = Field( + "2.0", + description="API version for response format compatibility" + ) + + + class TableDataResponse(BaseModel): @@ -218,6 +242,45 @@ class TableDataResponse(BaseModel): None, description="Path to SQLite database" ) + object_type: str | None = Field( + None, + description="KBase object type", + examples=["KBaseGeneDataLakes.BERDLTables-1.0"] + ) + + # Enhanced Metadata (System Overhaul) + column_types: list[ColumnTypeInfo] | None = Field( + None, + description="Column type information" + ) + column_schema: list[ColumnTypeInfo] | None = Field( + None, + description="Alias for column_types (for compatibility)" + ) + query_metadata: QueryMetadata | None = Field( + None, + description="Query execution metadata" + ) + cached: bool = Field( + False, + description="Whether result was from cache" + ) + execution_time_ms: float | None = Field( + None, + description="Query execution time in milliseconds (alias for response_time_ms)" + ) + limit: int | None = Field( + None, + description="Limit applied" + ) + offset: int | None = Field( + None, + description="Offset applied" + ) + database_path: str | None = Field( + None, + description="Path to database file" + ) model_config = { "json_schema_extra": { @@ -296,4 +359,139 @@ class ServiceStatus(BaseModel): ..., description="Service status" ) - cache_dir: str = Field(..., description="Cache directory path") \ No newline at end of file + cache_dir: str = Field(..., description="Cache directory 
path") + + + + +# ============================================================================= +# DATATABLES VIEWER API MODELS +# ============================================================================= + + +class FilterRequest(BaseModel): + """Filter specification for DataTables Viewer API.""" + column: str = Field(..., description="Column name to filter") + operator: str = Field( + ..., + description="Filter operator: eq, ne, gt, gte, lt, lte, like, ilike, in, not_in, between, is_null, is_not_null" + ) + value: Any = Field(None, description="Filter value (or first value for 'between')") + value2: Any = Field(None, description="Second value for 'between' operator") + + +class AggregationRequest(BaseModel): + """Aggregation specification for DataTables Viewer API.""" + column: str = Field(..., description="Column name to aggregate") + function: str = Field( + ..., + description="Aggregation function: count, sum, avg, min, max, stddev, variance, distinct_count" + ) + alias: str | None = Field(None, description="Alias for aggregated column") + + +class TableDataQueryRequest(BaseModel): + """Enhanced table data query request for DataTables Viewer API.""" + berdl_table_id: str = Field(..., description="Database identifier (local/db_name format)") + table_name: str = Field(..., description="Table name") + limit: int = Field(100, ge=1, le=MAX_LIMIT, description="Maximum rows to return") + offset: int = Field(0, ge=0, description="Number of rows to skip") + columns: list[str] | None = Field(None, description="List of columns to select (None = all)") + sort_column: str | None = Field(None, description="Column to sort by") + sort_order: Literal["ASC", "DESC"] = Field("ASC", description="Sort direction") + search_value: str | None = Field(None, description="Global search term") + col_filter: dict[str, str] | None = Field(None, description="Simple column filters (legacy)") + filters: list[FilterRequest] | None = Field(None, description="Advanced filter specifications") + aggregations: list[AggregationRequest] | None = Field(None, description="Aggregation specifications") + group_by: list[str] | None = Field(None, description="Columns for GROUP BY clause") + + +class AggregationQueryRequest(BaseModel): + """Aggregation query request.""" + group_by: list[str] = Field(..., description="Columns for GROUP BY") + aggregations: list[AggregationRequest] = Field(..., description="Aggregation specifications") + filters: list[FilterRequest] | None = Field(None, description="Filter specifications") + limit: int = Field(100, ge=1, le=MAX_LIMIT, description="Maximum rows to return") + offset: int = Field(0, ge=0, description="Number of rows to skip") + + +class ColumnTypeInfo(BaseModel): + """Column type information.""" + name: str = Field(..., description="Column name") + type: str = Field(..., description="SQLite type (INTEGER, REAL, TEXT, etc.)") + notnull: bool = Field(False, description="Whether column is NOT NULL") + pk: bool = Field(False, description="Whether column is PRIMARY KEY") + dflt_value: Any = Field(None, description="Default value") + + +class QueryMetadata(BaseModel): + """Query execution metadata.""" + query_type: str = Field(..., description="Type of query: select, aggregate") + sql: str = Field(..., description="Executed SQL query") + filters_applied: int = Field(0, description="Number of filters applied") + has_search: bool = Field(False, description="Whether search was applied") + has_sort: bool = Field(False, description="Whether sorting was applied") + has_group_by: bool = 
Field(False, description="Whether GROUP BY was applied") + has_aggregations: bool = Field(False, description="Whether aggregations were applied") + + +class TableDataQueryResponse(BaseModel): + """Enhanced table data query response for DataTables Viewer API.""" + headers: list[str] = Field(..., description="Column names") + data: list[list[str]] = Field(..., description="Row data as list of lists") + total_count: int = Field(..., description="Total rows in table (before filtering)") + column_types: list[ColumnTypeInfo] = Field(..., description="Column type information") + query_metadata: QueryMetadata = Field(..., description="Query execution metadata") + cached: bool = Field(False, description="Whether result was from cache") + execution_time_ms: float = Field(..., description="Query execution time in milliseconds") + limit: int = Field(..., description="Limit applied") + offset: int = Field(..., description="Offset applied") + table_name: str = Field(..., description="Table name") + database_path: str = Field(..., description="Path to database file") + + +class TableSchemaInfo(BaseModel): + """Table schema information.""" + table: str = Field(..., description="Table name") + columns: list[ColumnTypeInfo] = Field(..., description="Column information") + indexes: list[dict[str, str]] = Field(default_factory=list, description="Index information") + + +class ColumnStatistic(BaseModel): + """Column statistics.""" + column: str = Field(..., description="Column name") + type: str = Field(..., description="Column type") + null_count: int = Field(0, description="Number of NULL values") + distinct_count: int = Field(0, description="Number of distinct values") + min: Any = Field(None, description="Minimum value") + max: Any = Field(None, description="Maximum value") + mean: float | None = Field(None, description="Mean value") + median: float | None = Field(None, description="Median value") + stddev: float | None = Field(None, description="Standard deviation") + sample_values: list[Any] = Field(default_factory=list, description="Sample values") + + +class TableStatisticsResponse(BaseModel): + """Table statistics response.""" + table: str = Field(..., description="Table name") + row_count: int = Field(..., description="Total row count") + columns: list[ColumnStatistic] = Field(..., description="Column statistics") + last_updated: int = Field(..., description="Last update timestamp (milliseconds since epoch)") + + +class HealthResponse(BaseModel): + """Health check response.""" + status: str = Field("ok", description="Service status") + timestamp: str = Field(..., description="ISO8601 timestamp") + mode: str = Field("cached_sqlite", description="Service mode") + data_dir: str = Field(..., description="Data directory path") + config_dir: str = Field(..., description="Config directory path") + cache: dict[str, Any] = Field(..., description="Cache information") + + +class UploadDBResponse(BaseModel): + """Response for database upload.""" + handle: str = Field(..., description="Handle for the uploaded database (e.g., local:uuid)") + filename: str = Field(..., description="Original filename") + size_bytes: int = Field(..., description="Size of the uploaded file in bytes") + message: str = Field(..., description="Status message") \ No newline at end of file diff --git a/app/routes.py b/app/routes.py index 12abb08..ad9440e 100644 --- a/app/routes.py +++ b/app/routes.py @@ -12,43 +12,58 @@ """ -import time +import asyncio import logging -from pathlib import Path -from uuid import uuid4 +import traceback +from 
datetime import datetime +from pathlib import Path as FilePath from app.utils.workspace import KBaseClient - -from fastapi import APIRouter, HTTPException, Header, Query +import shutil +from uuid import uuid4 +from fastapi import APIRouter, HTTPException, Header, Query, Cookie, Path, UploadFile, File +from app.exceptions import InvalidFilterError from app.models import ( TableDataRequest, TableDataResponse, - PangenomesResponse, - PangenomeInfo, TableListResponse, TableInfo, CacheResponse, ServiceStatus, TableSchemaResponse, + TableDataQueryRequest, + TableDataQueryResponse, + TableSchemaInfo, + TableStatisticsResponse, + AggregationQueryRequest, + HealthResponse, + FilterRequest, + AggregationRequest, + UploadDBResponse, ) from app.utils.workspace import ( - list_pangenomes_from_object, download_pangenome_db, + get_object_type, ) from app.utils.sqlite import ( list_tables, - get_table_data, get_table_columns, get_table_row_count, validate_table_exists, - ensure_indices, + get_table_statistics, ) -from app.utils.cache import ( - is_cached, - clear_cache, - list_cached_items, +from app.services.data.schema_service import get_schema_service +from app.services.data.connection_pool import get_connection_pool +from app.services.db_helper import ( + get_object_db_path, + ensure_table_accessible, ) +from app.utils.async_utils import run_sync_in_thread +from app.utils.request_utils import TableRequestProcessor from app.config import settings +from app.config_constants import MAX_LIMIT, DEFAULT_LIMIT +from app.utils.cache import load_cache_metadata, save_cache_metadata + # Configure module logger logger = logging.getLogger(__name__) @@ -61,25 +76,95 @@ # UTILITY FUNCTIONS # ============================================================================= -def get_auth_token(authorization: str | None = None) -> str: - """Extract auth token from header or settings.""" +def get_auth_token( + authorization: str | None = None, + kbase_session: str | None = None, + allow_anonymous: bool = False +) -> str: + """ + Extract auth token from header or cookie. + + **User Authentication Required**: Each user must provide their own KBase token. + The service does NOT use a shared token for production access. + + Priority: + 1. Authorization header (Bearer token or plain token) + 2. kbase_session cookie + 3. KB_SERVICE_AUTH_TOKEN from settings (LEGACY: for local testing only) + + Args: + authorization: Authorization header value + kbase_session: kbase_session cookie value + allow_anonymous: If True, returns empty string instead of raising 401 + + Returns: + Authentication token string + + Raises: + HTTPException: If no token is found and allow_anonymous is False + """ + # Priority 1: User-provided Authorization header if authorization: if authorization.startswith("Bearer "): return authorization[7:] return authorization + # Priority 2: User-provided kbase_session cookie + if kbase_session: + return kbase_session + + # Priority 3 (LEGACY/TESTING ONLY): Fall back to service token from settings + # This is kept for local development and testing purposes. + # In production deployments, users MUST provide their own token. if settings.KB_SERVICE_AUTH_TOKEN: + logger.debug("Using KB_SERVICE_AUTH_TOKEN fallback (legacy/testing mode)") return settings.KB_SERVICE_AUTH_TOKEN + # No token found + if allow_anonymous: + return "" + raise HTTPException( status_code=401, - detail="Authorization token required" + detail="Authorization required. Provide your KBase token via the Authorization header or kbase_session cookie." 
) -def get_cache_dir() -> Path: + +async def _get_table_metadata(db_path, name, schema_service): + """ + Helper to fetch metadata for a single table. + """ + try: + # Run lightweight checks in thread + columns = await run_sync_in_thread(get_table_columns, db_path, name) + row_count = await run_sync_in_thread(get_table_row_count, db_path, name) + + display_name = name.replace("_", " ").title() + + # Build schema map + try: + table_schema = await run_sync_in_thread( + schema_service.get_table_schema, db_path, name + ) + schema_map = {col["name"]: col["type"] for col in table_schema["columns"]} + except Exception: + schema_map = {col: "TEXT" for col in columns} + + return { + "name": name, + "displayName": display_name, + "row_count": row_count, + "column_count": len(columns), + "schema": schema_map + } + except Exception: + logger.warning("Error getting table info for %s", name, exc_info=True) + return {"name": name, "displayName": name, "error_fallback": True} + +def get_cache_dir() -> FilePath: """Get configured cache directory.""" - return Path(settings.CACHE_DIR) + return FilePath(settings.CACHE_DIR) # ============================================================================= @@ -97,560 +182,536 @@ async def root(): ) -# ============================================================================= -# HANDLE-BASED ENDPOINTS (Primary REST API per diagram) -# /{handle_ref}/tables - List tables -# /{handle_ref}/tables/{table}/schema - Table schema -# /{handle_ref}/tables/{table}/data - Table data with pagination -# ============================================================================= - -@router.get("/handle/{handle_ref}/tables", tags=["Handle Access"], response_model=TableListResponse) -async def list_tables_by_handle( - handle_ref: str, - kb_env: str = Query("appdev", description="KBase environment"), - authorization: str | None = Header(None) -): +@router.get("/health", response_model=HealthResponse, tags=["General"]) +async def health_check(): """ - List all tables in a SQLite database accessed via handle reference. + Health check endpoint for DataTables Viewer API. - **Example:** - ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/handle/KBH_248028/tables" - ``` + Returns service status, cache information, and connection pool stats. 
""" + + try: - token = get_auth_token(authorization) - cache_dir = get_cache_dir() - - # Download SQLite from handle - client = KBaseClient(token, kb_env, cache_dir) - - # Cache path based on handle - safe_handle = handle_ref.replace(":", "_").replace("/", "_") - db_dir = cache_dir / "handles" - db_dir.mkdir(parents=True, exist_ok=True) - db_path = db_dir / f"{safe_handle}.db" - - # Atomic download to prevent race conditions - if not db_path.exists(): - temp_path = db_path.with_suffix(f".{uuid4().hex}.tmp") - try: - client.download_blob_file(handle_ref, temp_path) - temp_path.rename(db_path) - except Exception: - temp_path.unlink(missing_ok=True) - raise - - # List tables - table_names = list_tables(db_path) - tables = [] - for name in table_names: - try: - columns = get_table_columns(db_path, name) - row_count = get_table_row_count(db_path, name) - tables.append({ - "name": name, - "row_count": row_count, - "column_count": len(columns) - }) - except Exception as e: - logger.warning("Error getting table info for %s", name, exc_info=True) - tables.append({"name": name}) - - return { - "handle_ref": handle_ref, - "tables": tables, - "db_path": str(db_path) - } - + # Get connection pool stats (non-blocking) + try: + pool = get_connection_pool() + cache_stats = pool.get_stats() + except Exception as pool_error: + logger.warning(f"Error getting pool stats: {pool_error}") + cache_stats = {"total_connections": 0, "connections": []} + + return HealthResponse( + status="ok", + timestamp=datetime.utcnow().isoformat() + "Z", + mode="cached_sqlite", + data_dir=str(settings.CACHE_DIR), + config_dir=str(FilePath(settings.CACHE_DIR) / "configs"), + cache={ + "databases_cached": cache_stats.get("total_connections", 0), + "databases": cache_stats.get("connections", []) + } + ) except Exception as e: - logger.error(f"Error listing tables from handle: {e}") + logger.error(f"Error in health check: {e}") raise HTTPException(status_code=500, detail=str(e)) -@router.get("/handle/{handle_ref}/tables/{table_name}/schema", tags=["Handle Access"], response_model=TableSchemaResponse) -async def get_table_schema_by_handle( - handle_ref: str, - table_name: str, - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) -): - """ - Get schema (columns) for a table accessed via handle reference. - - **Example:** - ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/handle/KBH_248028/tables/Genes/schema" - ``` - """ - try: - token = get_auth_token(authorization) - cache_dir = get_cache_dir() - - client = KBaseClient(token, kb_env, cache_dir) - - safe_handle = handle_ref.replace(":", "_").replace("/", "_") - db_dir = cache_dir / "handles" - db_dir.mkdir(parents=True, exist_ok=True) - db_path = db_dir / f"{safe_handle}.db" - - if not db_path.exists(): - temp_path = db_path.with_suffix(f".{uuid4().hex}.tmp") - try: - client.download_blob_file(handle_ref, temp_path) - temp_path.rename(db_path) - except Exception: - temp_path.unlink(missing_ok=True) - raise - - if not validate_table_exists(db_path, table_name): - available = list_tables(db_path) - raise HTTPException(404, f"Table '{table_name}' not found. 
Available: {available}") - - columns = get_table_columns(db_path, table_name) - row_count = get_table_row_count(db_path, table_name) - - return { - "handle_ref": handle_ref, - "table_name": table_name, - "columns": columns, - "row_count": row_count - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting schema: {e}") - raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================= +# FILE UPLOAD ENDPOINTS +# ============================================================================= -@router.get("/handle/{handle_ref}/tables/{table_name}/data", tags=["Handle Access"], response_model=TableDataResponse) -async def get_table_data_by_handle( - handle_ref: str, - table_name: str, - limit: int = Query(100, ge=1, le=500000), - offset: int = Query(0, ge=0), - sort_column: str | None = Query(None), - sort_order: str | None = Query("ASC"), - search: str | None = Query(None, description="Global search term"), - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) -): - """ - Query table data from SQLite via handle reference. +@router.post( + "/upload", + tags=["File Upload"], + response_model=UploadDBResponse, + summary="Upload a local SQLite database", + description=""" + Upload a local SQLite database file (.db or .sqlite) for temporary use. + Returns a handle that can be used inplace of a KBase workspace reference. - **Example:** - ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/handle/KBH_248028/tables/Genes/data?limit=5" - ``` + The handle format is `local:{uuid}`. """ - start_time = time.time() - +) +async def upload_database( + file: UploadFile = File(..., description="SQLite database file") +): try: - token = get_auth_token(authorization) - cache_dir = get_cache_dir() + if not file.filename.endswith(('.db', '.sqlite', '.sqlite3')): + raise HTTPException(status_code=400, detail="File must be a SQLite database (.db, .sqlite, .sqlite3)") - client = KBaseClient(token, kb_env, cache_dir) + # Validate SQLite header + # SQLite files start with "SQLite format 3\0" + header = await file.read(16) + await file.seek(0) - safe_handle = handle_ref.replace(":", "_").replace("/", "_") - db_dir = cache_dir / "handles" - db_dir.mkdir(parents=True, exist_ok=True) - db_path = db_dir / f"{safe_handle}.db" + if header != b"SQLite format 3\0": + logger.warning(f"Invalid SQLite header for upload {file.filename}: {header}") + raise HTTPException(status_code=400, detail="Invalid SQLite file format (header mismatch)") + + # Generate handle + file_uuid = str(uuid4()) + handle = f"local:{file_uuid}" - if not db_path.exists(): - temp_path = db_path.with_suffix(f".{uuid4().hex}.tmp") - try: - client.download_blob_file(handle_ref, temp_path) - temp_path.rename(db_path) - except Exception: - temp_path.unlink(missing_ok=True) - raise + # Save to uploads directory + cache_dir = get_cache_dir() + upload_dir = cache_dir / "uploads" + upload_dir.mkdir(parents=True, exist_ok=True) - if not validate_table_exists(db_path, table_name): - available = list_tables(db_path) - raise HTTPException(404, f"Table '{table_name}' not found. 
Available: {available}") + destination = upload_dir / f"{file_uuid}.db" - # Query data - headers, data, total_count, filtered_count, db_query_ms, conversion_ms = get_table_data( - sqlite_file=db_path, - table_name=table_name, - limit=limit, - offset=offset, - sort_column=sort_column, - sort_order=sort_order, - search_value=search, + try: + with destination.open("wb") as buffer: + shutil.copyfileobj(file.file, buffer) + finally: + file.file.close() + + return UploadDBResponse( + handle=handle, + filename=file.filename, + size_bytes=destination.stat().st_size, + message="Database uploaded successfully" ) - - response_time_ms = (time.time() - start_time) * 1000 - - return { - "handle_ref": handle_ref, - "table_name": table_name, - "headers": headers, - "data": data, - "row_count": len(data), - "total_count": total_count, - "filtered_count": filtered_count, - "response_time_ms": response_time_ms, - "db_query_ms": db_query_ms - } - + except HTTPException: raise except Exception as e: - logger.error(f"Error querying data: {e}") - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"Error uploading file: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") # ============================================================================= # OBJECT-BASED ENDPOINTS (via KBase workspace object reference) -# /object/{ws_ref}/pangenomes - List pangenomes from BERDLTables object -# /object/{ws_ref}/pangenomes/{pg_id}/tables - List tables for a pangenome -# /object/{ws_ref}/pangenomes/{pg_id}/tables/{table}/data - Query data +# /object/{ws_ref}/tables - List tables from KBase object +# /object/{ws_ref}/tables/{table}/data - Query data # ============================================================================= -@router.get("/object/{ws_ref:path}/pangenomes", tags=["Object Access"], response_model=PangenomesResponse) -async def list_pangenomes_by_object( - ws_ref: str, - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) -): - """ - List pangenomes from a BERDLTables/GenomeDataLakeTables object. - - **Example:** +@router.get( + "/object/{ws_ref:path}/tables", + tags=["Object Access"], + response_model=TableListResponse, + summary="List tables in a BERDLTables object", + description=""" + List all tables available in a BERDLTables object from KBase workspace. 
+ + **Example Usage:** ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/pangenomes" + # Using curl with Authorization header + curl -X GET \\ + "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables?kb_env=appdev" \\ + -H "Authorization: Bearer YOUR_KBASE_TOKEN" \\ + -H "accept: application/json" + + # Using curl with cookie + curl -X GET \\ + "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables?kb_env=appdev" \\ + -H "Cookie: kbase_session=YOUR_KBASE_TOKEN" \\ + -H "accept: application/json" ``` - """ - try: - token = get_auth_token(authorization) - berdl_table_id = ws_ref - - pangenomes = list_pangenomes_from_object( - berdl_table_id=berdl_table_id, - auth_token=token, - kb_env=kb_env - ) - - return { - "berdl_table_id": berdl_table_id, - "pangenomes": pangenomes - } - - except Exception as e: - logger.error(f"Error listing pangenomes: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/object/{ws_ref:path}/tables", tags=["Object Access"], response_model=TableListResponse) + + **Authentication:** + - Authorization header: `Authorization: Bearer YOUR_TOKEN` or `Authorization: YOUR_TOKEN` + - Cookie: `kbase_session=YOUR_TOKEN` + - Environment variable: `KB_SERVICE_AUTH_TOKEN` (for service-to-service) + """, + responses={ + 200: { + "description": "Successfully retrieved table list", + "content": { + "application/json": { + "example": { + "berdl_table_id": "76990/7/2", + "object_type": "KBaseGeneDataLakes.BERDLTables-1.0", + "tables": [ + { + "name": "Genes", + "displayName": "Genes", + "row_count": 3356, + "column_count": 18 + }, + { + "name": "Contigs", + "displayName": "Contigs", + "row_count": 42, + "column_count": 12 + } + ] + } + } + } + }, + 401: {"description": "Authentication required"}, + 404: {"description": "Object not found"}, + 500: {"description": "Internal server error"} + } +) async def list_tables_by_object( - ws_ref: str, - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) + ws_ref: str = Path(..., description="KBase workspace object reference (UPA format: workspace_id/object_id/version)", examples=["76990/7/2"]), + kb_env: str = Query("appdev", description="KBase environment", examples=["appdev"]), + authorization: str | None = Header(None, description="KBase authentication token (Bearer token or plain token)", examples=["Bearer YOUR_KBASE_TOKEN"]), + kbase_session: str | None = Cookie(None, description="KBase session cookie", examples=["YOUR_KBASE_TOKEN"]) ): - """ - List tables for a BERDLTables object. 
- **Example:** - ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables" - ``` - """ + try: - token = get_auth_token(authorization) + token = get_auth_token(authorization, kbase_session) cache_dir = get_cache_dir() berdl_table_id = ws_ref - db_path = download_pangenome_db( - berdl_table_id=berdl_table_id, - auth_token=token, - cache_dir=cache_dir, - kb_env=kb_env - ) + # Get database path (handles caching, download timeouts via helper) + db_path = await get_object_db_path(berdl_table_id, token, kb_env, cache_dir) + + # List tables (run in thread) + table_names = await run_sync_in_thread(list_tables, db_path) - table_names = list_tables(db_path) tables = [] - for name in table_names: + schemas = {} + total_rows = 0 + + # Use schema service for better column type information + schema_service = get_schema_service() + + # Process tables + # Parallelize metadata fetching + tasks = [ + _get_table_metadata(db_path, name, schema_service) + for name in table_names + ] + + results = await asyncio.gather(*tasks) + + for res in results: + if "error_fallback" in res: + tables.append({"name": res["name"], "displayName": res["displayName"]}) + continue + + tables.append({ + "name": res["name"], + "displayName": res["displayName"], + "row_count": res["row_count"], + "column_count": res["column_count"] + }) + total_rows += res["row_count"] or 0 + schemas[res["name"]] = res["schema"] + + # Get object type (with caching) + object_type = None + + # 1. Try to load from cache + try: + # db_path is typically .../cache/sanitized_upa/tables.db + # So cache_subdir is the parent directory + cache_subdir = db_path.parent + metadata = load_cache_metadata(cache_subdir) + + if metadata and "object_type" in metadata: + object_type = metadata["object_type"] + logger.debug(f"Using cached object type for {berdl_table_id}: {object_type}") + except Exception as e: + logger.warning(f"Error reading cache metadata: {e}") + + # 2. If not cached, fetch from API + if not object_type: try: - columns = get_table_columns(db_path, name) - row_count = get_table_row_count(db_path, name) - tables.append({ - "name": name, - "row_count": row_count, - "column_count": len(columns) - }) - except Exception as e: - logger.warning("Error getting table info for %s", name, exc_info=True) - tables.append({"name": name}) + # Use specific timeout for API call + object_type = await asyncio.wait_for( + run_sync_in_thread(get_object_type, berdl_table_id, token, kb_env), + timeout=settings.KBASE_API_TIMEOUT_SECONDS + ) + + # 3. Save to cache + if object_type: + try: + save_cache_metadata( + db_path.parent, + { + "berdl_table_id": berdl_table_id, + "object_type": object_type, + "last_checked": datetime.utcnow().isoformat() + } + ) + logger.info(f"Cached object type for {berdl_table_id}") + except Exception as e: + logger.warning(f"Failed to cache metadata: {e}") + + except (asyncio.TimeoutError, Exception) as e: + logger.warning(f"Could not get object type (non-critical): {e}") + object_type = None + + # Config-related fields (deprecated, kept for backward compatibility) + config_fingerprint = None + config_url = None + has_cached_config = False + has_builtin_config = False + builtin_config_id = None + + # Get database size + database_size = None + try: + database_size = db_path.stat().st_size if db_path.exists() else None + except Exception as e: + # Database size is informational; log and continue if it cannot be determined. 
+ logger.debug("Failed to get database size for %s: %s", db_path, e) + + # Format berdl_table_id for DataTables Viewer API (local/db_name format) + berdl_table_id_formatted = f"local/{berdl_table_id.replace('/', '_')}" return { - "berdl_table_id": berdl_table_id, + "berdl_table_id": berdl_table_id_formatted, + "object_type": object_type or "LocalDatabase", "tables": tables, - "source": "Cache" if (db_path.exists() and db_path.stat().st_size > 0) else "Downloaded" + "source": "Local", + "has_config": has_cached_config, + "config_source": "static" if has_cached_config else None, + "config_fingerprint": config_fingerprint, + "config_url": config_url, + "has_cached_config": has_cached_config, + "schemas": schemas, + "has_builtin_config": has_builtin_config, + "builtin_config_id": builtin_config_id, + "database_size_bytes": database_size, + "total_rows": total_rows, + "api_version": "2.0", } + except HTTPException: + # Re-raise HTTP exceptions as-is (don't convert to 500) + raise except Exception as e: - logger.error(f"Error listing tables: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/object/{ws_ref:path}/tables/{table_name}/data", tags=["Object Access"], response_model=TableDataResponse) -async def get_table_data_by_object( - ws_ref: str, - table_name: str, - limit: int = Query(100, ge=1, le=500000), - offset: int = Query(0, ge=0), - sort_column: str | None = Query(None), - sort_order: str | None = Query("ASC"), - search: str | None = Query(None), - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) -): - """ - Query table data from a BERDLTables object. - - **Example:** + # Log full traceback for debugging + logger.error(f"Error listing tables: {e}", exc_info=True) + # Provide detailed error message + # Always include the error message, add traceback in debug mode + error_detail = str(e) if str(e) else f"Error: {type(e).__name__}" + if settings.DEBUG: + error_detail += f"\n\nTraceback:\n{traceback.format_exc()}" + raise HTTPException(status_code=500, detail=error_detail) + + +@router.get( + "/object/{ws_ref:path}/tables/{table_name}/data", + tags=["Object Access"], + response_model=TableDataResponse, + summary="Query table data from a BERDLTables object", + description=""" + Query data from a specific table in a BERDLTables object with filtering, sorting, and pagination. 
+ + **Example Usage:** ```bash - curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables/Genes/data?limit=5" - ``` - """ - start_time = time.time() + # Get first 10 rows from Genes table + curl -X GET \\ + "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables/Genes/data?limit=10&kb_env=appdev" \\ + -H "Authorization: Bearer YOUR_KBASE_TOKEN" \\ + -H "accept: application/json" + # Search and sort + curl -X GET \\ + "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables/Genes/data?limit=20&offset=0&search=kinase&sort_column=gene_name&sort_order=ASC&kb_env=appdev" \\ + -H "Authorization: Bearer YOUR_KBASE_TOKEN" \\ + -H "accept: application/json" + ``` + """, + responses={ + 200: {"description": "Successfully retrieved table data"}, + 401: {"description": "Authentication required"}, + 404: {"description": "Table not found"}, + 500: {"description": "Internal server error"} + } +) +async def get_table_data_by_object( + ws_ref: str = Path(..., description="KBase workspace object reference (UPA format)", examples=["76990/7/2"]), + table_name: str = Path(..., description="Name of the table to query", examples=["Genes"]), + limit: int = Query(DEFAULT_LIMIT, ge=1, le=MAX_LIMIT, description="Maximum number of rows to return", examples=[10]), + offset: int = Query(0, ge=0, description="Number of rows to skip (for pagination)", examples=[0]), + sort_column: str | None = Query(None, description="Column name to sort by", examples=["gene_name"]), + sort_order: str | None = Query("ASC", description="Sort order: ASC or DESC", examples=["ASC"]), + search: str | None = Query(None, description="Global text search across all columns", examples=["kinase"]), + kb_env: str = Query("appdev", description="KBase environment", examples=["appdev"]), + authorization: str | None = Header(None, description="KBase authentication token", examples=["Bearer YOUR_KBASE_TOKEN"]), + kbase_session: str | None = Cookie(None, description="KBase session cookie", examples=["YOUR_KBASE_TOKEN"]) +): try: - token = get_auth_token(authorization) + token = get_auth_token(authorization, kbase_session) cache_dir = get_cache_dir() berdl_table_id = ws_ref - db_path = download_pangenome_db( - berdl_table_id=berdl_table_id, - auth_token=token, - cache_dir=cache_dir, - kb_env=kb_env - ) - - if not validate_table_exists(db_path, table_name): - available = list_tables(db_path) - raise HTTPException(404, f"Table '{table_name}' not found. 
Available: {available}") + + # Get and validate DB access + db_path = await get_object_db_path(berdl_table_id, token, kb_env, cache_dir) + await ensure_table_accessible(db_path, table_name) - headers, data, total_count, filtered_count, db_query_ms, conversion_ms = get_table_data( - sqlite_file=db_path, + result = await TableRequestProcessor.process_data_request( + db_path=db_path, table_name=table_name, limit=limit, offset=offset, sort_column=sort_column, - sort_order=sort_order, + sort_order=sort_order or "ASC", search_value=search, + handle_ref_or_id=berdl_table_id ) - - response_time_ms = (time.time() - start_time) * 1000 - - return { - "berdl_table_id": berdl_table_id, - "table_name": table_name, - "headers": headers, - "data": data, - "row_count": len(data), - "total_count": total_count, - "filtered_count": filtered_count, - "response_time_ms": response_time_ms, - "db_query_ms": db_query_ms, - "sqlite_file": str(db_path) - } + + return result except HTTPException: raise except Exception as e: - logger.error(f"Error querying data: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# LEGACY ENDPOINTS (for backwards compatibility) -# ============================================================================= - -@router.get("/pangenomes", response_model=PangenomesResponse, tags=["Legacy"]) -async def get_pangenomes( - berdl_table_id: str = Query(..., description="BERDLTables object reference"), - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) -): + logger.error(f"Error querying data: {e}", exc_info=True) + error_detail = str(e) if str(e) else f"Error: {type(e).__name__}" + if settings.DEBUG: + error_detail += f"\n\nTraceback:\n{traceback.format_exc()}" + raise HTTPException(status_code=500, detail=error_detail) + + +@router.get( + "/object/{ws_ref:path}/tables/{table_name}/stats", + tags=["Object Access"], + response_model=TableStatisticsResponse, + summary="Get column statistics for a table", + description=""" + Calculate statistics for all columns in a table (null counts, distinct counts, min/max, samples). + This operation may be slow for large tables. """ - List pangenomes from BERDLTables object. 
- - Returns: - - pangenomes: List of pangenome info - - pangenome_count: Total number of pangenomes - """ - try: - token = get_auth_token(authorization) - - # Support comma-separated list of IDs - berdl_ids = [bid.strip() for bid in berdl_table_id.split(",") if bid.strip()] - - all_pangenomes: list[dict] = [] - - for bid in berdl_ids: - try: - pangenomes = list_pangenomes_from_object(bid, token, kb_env) - # Tag each pangenome with its source ID - for pg in pangenomes: - pg["source_berdl_id"] = bid - all_pangenomes.extend(pangenomes) - except Exception as e: - logger.error(f"Error fetching pangenomes for {bid}: {e}") - # Continue fetching others even if one fails - continue - - pangenome_list = [PangenomeInfo(**pg) for pg in all_pangenomes] - - return PangenomesResponse( - pangenomes=pangenome_list, - pangenome_count=len(pangenome_list) - ) - except Exception as e: - logger.error(f"Error in get_pangenomes: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/tables", response_model=TableListResponse, tags=["Legacy"]) -async def get_tables( - berdl_table_id: str = Query(..., description="BERDLTables object reference"), - kb_env: str = Query("appdev"), - authorization: str | None = Header(None) +) +async def get_table_stats( + ws_ref: str = Path(..., description="KBase workspace object reference (UPA format)", examples=["76990/7/2"]), + table_name: str = Path(..., description="Name of the table to analyze", examples=["Genes"]), + kb_env: str = Query("appdev", description="KBase environment", examples=["appdev"]), + authorization: str | None = Header(None, description="KBase authentication token", examples=["Bearer YOUR_KBASE_TOKEN"]), + kbase_session: str | None = Cookie(None, description="KBase session cookie", examples=["YOUR_KBASE_TOKEN"]) ): - """List tables for a BERDLTable object (auto-resolves pangenome).""" try: - token = get_auth_token(authorization) + token = get_auth_token(authorization, kbase_session) cache_dir = get_cache_dir() + berdl_table_id = ws_ref - db_path = download_pangenome_db(berdl_table_id, token, cache_dir, kb_env) - table_names = list_tables(db_path) + # Get and validate DB access + db_path = await get_object_db_path(berdl_table_id, token, kb_env, cache_dir) + await ensure_table_accessible(db_path, table_name) - tables = [] - for name in table_names: - try: - columns = get_table_columns(db_path, name) - row_count = get_table_row_count(db_path, name) - tables.append(TableInfo(name=name, row_count=row_count, column_count=len(columns))) - except Exception: - tables.append(TableInfo(name=name)) + # Helper to run stats calculation in thread (CPU bound) + # from app.utils.sqlite import get_table_statistics + + stats = await run_sync_in_thread(get_table_statistics, db_path, table_name) + return stats - return TableListResponse(tables=tables) + except HTTPException: + raise except Exception as e: - logger.error(f"Error listing tables: {e}") - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"Error calculating stats: {e}", exc_info=True) + error_detail = str(e) if str(e) else f"Error: {type(e).__name__}" + if settings.DEBUG: + error_detail += f"\n\nTraceback:\n{traceback.format_exc()}" + raise HTTPException(status_code=500, detail=error_detail) -@router.post("/table-data", response_model=TableDataResponse, tags=["Legacy"]) -async def query_table_data( - request: TableDataRequest, - authorization: str | None = Header(None) -): - """ - Query table data using a JSON body. Recommended for programmatic access. 
+# ============================================================================= +# DATA ACCESS ENDPOINTS +# ============================================================================= - **Example:** +@router.post( + "/table-data", + response_model=TableDataResponse, + tags=["Data Access"], + summary="Query table data with advanced filtering (POST)", + description=""" + Query table data using a JSON request body. Recommended for complex queries with multiple filters. + + **Example Usage:** ```bash - curl -X POST -H "Authorization: $KB_TOKEN" -H "Content-Type: application/json" \ - -d '{ - "berdl_table_id": "76990/7/2", - "table_name": "Metadata_Conditions", - "limit": 5" - }' \ - "https://appdev.kbase.us/services/berdl_table_scanner/table-data" - ``` - """ - start_time = time.time() + # Simple query + curl -X POST \\ + "https://appdev.kbase.us/services/berdl_table_scanner/table-data" \\ + -H "Authorization: Bearer YOUR_KBASE_TOKEN" \\ + -H "Content-Type: application/json" \\ + -d '{ + "berdl_table_id": "76990/7/2", + "table_name": "Genes", + "limit": 10, + "offset": 0 + }' + # Query with filters + curl -X POST \\ + "https://appdev.kbase.us/services/berdl_table_scanner/table-data" \\ + -H "Authorization: Bearer YOUR_KBASE_TOKEN" \\ + -H "Content-Type: application/json" \\ + -d '{ + "berdl_table_id": "76990/7/2", + "table_name": "Genes", + "limit": 20, + "query_filters": [ + {"column": "gene_name", "operator": "like", "value": "kinase"}, + {"column": "contigs", "operator": "gt", "value": 5} + ], + "sort": [{"column": "gene_name", "direction": "asc"}] + }' + ``` + """, + responses={ + 200: {"description": "Successfully retrieved table data"}, + 401: {"description": "Authentication required"}, + 404: {"description": "Table not found"}, + 500: {"description": "Internal server error"} + } +) +async def query_table_data( + request: TableDataRequest, + authorization: str | None = Header(None, description="KBase authentication token", examples=["Bearer YOUR_KBASE_TOKEN"]), + kbase_session: str | None = Cookie(None, description="KBase session cookie", examples=["YOUR_KBASE_TOKEN"]) +): try: - token = get_auth_token(authorization) + token = get_auth_token(authorization, kbase_session) cache_dir = get_cache_dir() kb_env = getattr(request, 'kb_env', 'appdev') or 'appdev' - # Determine filters (support both query_filters and col_filter) filters = request.col_filter if request.col_filter else request.query_filters - # Download (or get cached) DB - auto-resolves ID if None - try: - db_path = download_pangenome_db( - request.berdl_table_id, token, cache_dir, kb_env - ) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) + # Get and validate DB access (uses generic helper that supports local:) + db_path = await get_object_db_path(request.berdl_table_id, token, kb_env, cache_dir) if not validate_table_exists(db_path, request.table_name): available = list_tables(db_path) raise ValueError(f"Table '{request.table_name}' not found. 
Available: {available}") - - try: - ensure_indices(db_path, request.table_name) - except: - pass - - headers, data, total_count, filtered_count, db_query_ms, conversion_ms = get_table_data( - sqlite_file=db_path, + + # Column parsing is now handled in process_data_request for both string and list formats + + effective_sort_col = request.sort_column + effective_sort_dir = request.sort_order + + if not effective_sort_col and request.order_by: + first_sort = request.order_by[0] + effective_sort_col = first_sort.get("column") + effective_sort_dir = first_sort.get("direction", "ASC").upper() + + return await TableRequestProcessor.process_data_request( + db_path=db_path, table_name=request.table_name, limit=request.limit, offset=request.offset, - sort_column=request.sort_column, - sort_order=request.sort_order, + sort_column=effective_sort_col, + sort_order=effective_sort_dir or "ASC", search_value=request.search_value, - query_filters=filters, - columns=request.columns, - order_by=request.order_by - ) - - response_time_ms = (time.time() - start_time) * 1000 - - return TableDataResponse( - headers=headers, - data=data, - row_count=len(data), - total_count=total_count, - filtered_count=filtered_count, - table_name=request.table_name, - response_time_ms=response_time_ms, - db_query_ms=db_query_ms, - conversion_ms=conversion_ms, - source="Cache" if is_cached(db_path) else "Downloaded", - cache_file=str(db_path), - sqlite_file=str(db_path) + columns=request.columns, # Now handles list or string + filters=request.filters if request.filters else filters, # Prefer advanced filters, fall back to legacy dict + aggregations=request.aggregations, + group_by=request.group_by, + handle_ref_or_id=request.berdl_table_id ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(f"Error querying table data: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -# ============================================================================= -# CACHE MANAGEMENT -# ============================================================================= - -@router.post("/clear-cache", response_model=CacheResponse, tags=["Cache Management"]) -async def clear_pangenome_cache( - berdl_table_id: str | None = Query(None) -): - """Clear cached databases.""" - try: - cache_dir = get_cache_dir() - result = clear_cache(cache_dir, berdl_table_id) - return CacheResponse(status="success", message=result.get("message", "Cache cleared")) + except HTTPException: + # Re-raise HTTP exceptions as-is (don't convert to 500) + raise + except InvalidFilterError: + # Allow invalid filter errors to be handled by main app exception handler (422) + raise except Exception as e: - return CacheResponse(status="error", message=str(e)) - + # Log full traceback for debugging + logger.error(f"Error querying data: {e}", exc_info=True) + # Provide detailed error message + # Always include the error message, add traceback in debug mode + error_detail = str(e) if str(e) else f"Error: {type(e).__name__}" + if settings.DEBUG: + error_detail += f"\n\nTraceback:\n{traceback.format_exc()}" + raise HTTPException(status_code=500, detail=error_detail) -@router.get("/cache", tags=["Cache Management"]) -async def list_cache(): - """List cached items.""" - cache_dir = get_cache_dir() - items = list_cached_items(cache_dir) - return {"cache_dir": str(cache_dir), "items": items, "total": len(items)} diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 
0000000..c05a668
--- /dev/null
+++ b/app/services/__init__.py
@@ -0,0 +1,25 @@
+"""
+TableScanner Services Package.
+
+This package contains data query and schema analysis services.
+
+Modules:
+    - connection_pool: Database connection pooling and management
+    - query_service: Enhanced query execution with type-aware filtering
+    - schema_service: Schema information retrieval
+    - statistics_service: Column statistics computation
+    - schema_analyzer: Database schema introspection and profiling
+    - fingerprint: Database fingerprinting for caching
+"""
+
+from .data.schema_analyzer import SchemaAnalyzer, ColumnProfile, TableProfile
+from .data.fingerprint import DatabaseFingerprint
+
+__all__ = [
+    # Schema analysis
+    "SchemaAnalyzer",
+    "ColumnProfile",
+    "TableProfile",
+    # Fingerprinting
+    "DatabaseFingerprint",
+]
diff --git a/app/services/data/__init__.py b/app/services/data/__init__.py
new file mode 100644
index 0000000..dd828f1
--- /dev/null
+++ b/app/services/data/__init__.py
@@ -0,0 +1,19 @@
+"""
+Data Analysis Services.
+
+Schema analysis, fingerprinting, and validation.
+"""
+
+from .schema_analyzer import SchemaAnalyzer
+from .fingerprint import DatabaseFingerprint
+from .type_inference import TypeInferenceEngine, InferredType, DataType
+from .validation import validate_config
+
+__all__ = [
+    "SchemaAnalyzer",
+    "DatabaseFingerprint",
+    "TypeInferenceEngine",
+    "InferredType",
+    "DataType",
+    "validate_config",
+]
diff --git a/app/services/data/connection_pool.py b/app/services/data/connection_pool.py
new file mode 100644
index 0000000..42c6c3c
--- /dev/null
+++ b/app/services/data/connection_pool.py
@@ -0,0 +1,290 @@
+"""
+Database Connection Pool Manager.
+
+Manages a pool of SQLite database connections with:
+- Thread-safe Queue-based pooling (one queue per database file)
+- Automatic lifecycle management (idle pools are closed after an inactivity timeout)
+- Connection reuse for performance
+- SQLite performance optimizations (WAL mode, cache size, etc.)
+- Context manager interface for safe connection handling
+"""
+
+from __future__ import annotations
+
+import sqlite3
+import logging
+import threading
+import time
+import queue
+from pathlib import Path
+from typing import Any, Generator
+from contextlib import contextmanager
+
+logger = logging.getLogger(__name__)
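+
+# Typical usage (illustrative sketch only; the database path below is an
+# example, not a required layout, and `get_connection_pool()` is defined at
+# the bottom of this module):
+#
+#     pool = get_connection_pool()
+#     db_path = Path("/tmp/tablescanner_cache/76990_7_2/tables.db")
+#     with pool.connection(db_path) as conn:
+#         row_count = conn.execute('SELECT COUNT(*) FROM "Genes"').fetchone()[0]
+#
+# The context manager blocks until an idle connection is available (or the
+# timeout expires) and always returns the connection to the per-file queue.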
+
+
+class ConnectionPool:
+    """
+    Manages a pool of SQLite database connections using thread-safe Queues.
+
+    Features:
+    - Dedicated Queue for each database file to enforce thread safety.
+    - Context manager `connection()` ensures connections are always returned.
+    - Automatic cleanup of idle pools.
+    """
+
+    # Pool timeout: 10 minutes of inactivity (reduced for local DBs)
+    POOL_TIMEOUT_SECONDS = 10 * 60
+
+    # Cleanup interval
+    CLEANUP_INTERVAL_SECONDS = 2 * 60
+
+    # Maximum connections per database file
+    MAX_CONNECTIONS = 8
+
+    def __init__(self) -> None:
+        """Initialize the connection pool."""
+        # Key: str(db_path), Value: (queue.Queue, last_access_time)
+        self._pools: dict[str, tuple[queue.Queue, float]] = {}
+        self._lock = threading.RLock()
+        self._last_cleanup = time.time()
+
+        logger.info("Initialized SQLite connection pool (Queue-based)")
+
+    @contextmanager
+    def connection(self, db_path: Path, timeout: float = 10.0) -> Generator[sqlite3.Connection, None, None]:
+        """
+        Context manager to acquire a database connection.
+
+        Blocks until a connection is available or the timeout expires.
+        Automatically returns the connection to the pool when done.
+
+        Args:
+            db_path: Path to the SQLite database
+            timeout: Max time to wait for a connection in seconds
+
+        Yields:
+            sqlite3.Connection: Active database connection
+
+        Raises:
+            TimeoutError: If no connection becomes available within the timeout
+            sqlite3.Error: If a connection cannot be created
+        """
+        db_key = str(db_path.absolute())
+
+        # 1. Get or create the pool queue for this DB
+        pool_queue = self._get_or_create_pool(db_key)
+
+        conn = None
+        try:
+            # 2. Try to get a connection from the queue
+            try:
+                conn = pool_queue.get(block=True, timeout=timeout)
+
+                # Connections can go stale (e.g. if the cached database file was
+                # replaced on disk), so run a lightweight liveness check and
+                # swap in a fresh connection if it fails.
+                try:
+                    conn.execute("SELECT 1")
+                except sqlite3.Error:
+                    # Connection is bad: close it and create a new one
+                    try:
+                        conn.close()
+                    except Exception:
+                        # Best-effort close; log at debug and continue with a fresh connection.
+                        logger.debug("Failed to close bad SQLite connection for %s", db_key, exc_info=True)
+                    conn = self._create_new_connection(db_key)
+
+            except queue.Empty:
+                # The queue is pre-filled with MAX_CONNECTIONS connections in
+                # _get_or_create_pool(), so an empty queue after the timeout
+                # means every connection is currently checked out.
+                # (A lazier, semaphore-bounded alternative is sketched in the
+                # design note after this method.)
+                raise TimeoutError(f"Timeout waiting for database connection: {db_path}")
+
+            yield conn
+
+        finally:
+            # 3. Return connection to pool
+            if conn:
+                # Rollback uncommitted transaction to reset state
+                try:
+                    conn.rollback()
+                except Exception:
+                    # If rollback fails, the connection may be in a bad state; it will
+                    # still be returned to the pool but future health checks will replace it.
+                    logger.debug("Failed to rollback SQLite connection for %s", db_key, exc_info=True)
+
+                # Put back in queue
+                # Note: We must update the last access time for the POOL, not the connection
+                self._update_pool_access(db_key)
+                pool_queue.put(conn)
+
+            # 4. Trigger cleanup periodically
+            self._maybe_cleanup()
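+
+    # Design note: the pre-fill strategy keeps sizing trivial, since each queue
+    # owns exactly MAX_CONNECTIONS connections, whether idle or checked out.
+    # A lazier alternative (a sketch only, *not* what this class does) would
+    # create connections on demand and bound the total with a hypothetical
+    # semaphore attribute such as
+    # `self._capacity = threading.BoundedSemaphore(self.MAX_CONNECTIONS)`:
+    #
+    #     if not self._capacity.acquire(timeout=timeout):
+    #         raise TimeoutError(f"Timeout waiting for database connection: {db_path}")
+    #     try:
+    #         conn = pool_queue.get_nowait()
+    #     except queue.Empty:
+    #         conn = self._create_new_connection(db_key)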
+
+    def _get_or_create_pool(self, db_key: str) -> queue.Queue:
+        """Get existing pool or create a new one with connections."""
+        with self._lock:
+            if db_key in self._pools:
+                q, _ = self._pools[db_key]
+                self._pools[db_key] = (q, time.time())  # Update access
+                return q
+
+            # Create new pool
+            q = queue.Queue(maxsize=self.MAX_CONNECTIONS)
+
+            # Pre-fill the queue with MAX_CONNECTIONS connections. Creating them
+            # here inside the lock is acceptable: opening a handful of SQLite
+            # connections is fast.
+            try:
+                for _ in range(self.MAX_CONNECTIONS):
+                    conn = self._create_new_connection(db_key)
+                    q.put(conn)
+            except Exception as e:
+                logger.error(f"Error filling connection pool for {db_key}: {e}")
+                # Close any connections that were created before the failure
+                while not q.empty():
+                    try:
+                        q.get_nowait().close()
+                    except Exception:
+                        logger.debug("Failed to close SQLite connection during pool recovery.", exc_info=True)
+                raise
+
+            self._pools[db_key] = (q, time.time())
+            return q
+
+    def _create_new_connection(self, db_path_str: str) -> sqlite3.Connection:
+        """Create and configure a single SQLite connection."""
+        conn = sqlite3.connect(db_path_str, check_same_thread=False)
+        conn.row_factory = sqlite3.Row
+
+        # Performance optimizations
+        try:
+            conn.execute("PRAGMA journal_mode=WAL")
+            conn.execute("PRAGMA synchronous=NORMAL")
+            conn.execute("PRAGMA cache_size=-64000")  # 64MB
+            conn.execute("PRAGMA temp_store=MEMORY")
+            conn.execute("PRAGMA mmap_size=268435456")  # 256MB
+        except sqlite3.Error as e:
+            logger.warning(f"Failed to apply optimizations: {e}")
+
+        return conn
+
+    def _update_pool_access(self, db_key: str):
+        """Update last access timestamp for a pool."""
+        with self._lock:
+            if db_key in self._pools:
+                q, _ = self._pools[db_key]
+                self._pools[db_key] = (q, time.time())
+
+    def _maybe_cleanup(self) -> None:
+        """Run cleanup if enough time has passed."""
+        now = time.time()
+        # Non-blocking check
+        if now - self._last_cleanup < self.CLEANUP_INTERVAL_SECONDS:
+            return
+
+        with self._lock:
+            # Double check inside lock
+            if now - self._last_cleanup < self.CLEANUP_INTERVAL_SECONDS:
+                return
+            self._last_cleanup = now
+            self.cleanup_expired()
+
+    def cleanup_expired(self) -> None:
+        """Close pools that haven't been accessed recently."""
+        now = time.time()
+        expired_keys = []
+
+        with self._lock:
+            for db_key, (q, last_access) in self._pools.items():
+                if now - last_access > self.POOL_TIMEOUT_SECONDS:
+                    expired_keys.append(db_key)
+
+            for key in expired_keys:
+                q, _ = self._pools.pop(key)
+                self._close_pool_queue(q)
+                logger.info(f"Cleaned up expired pool for: {key}")
+
+    def _close_pool_queue(self, q: queue.Queue):
+        """Close all connections in a queue."""
+        while not q.empty():
+            try:
+                conn = q.get_nowait()
+                conn.close()
+            except Exception:
+                # Best-effort close; swallow errors but record at debug.
+                logger.debug("Failed to close SQLite connection during pool cleanup.", exc_info=True)
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get pool statistics."""
+        with self._lock:
+            stats = []
+            for db_key, (q, last_access) in self._pools.items():
+                stats.append({
+                    "db_path": db_key,
+                    "available_connections": q.qsize(),
+                    "last_access_ago": time.time() - last_access
+                })
+            return {
+                "total_pools": len(self._pools),
+                "pools": stats
+            }
+
+    # Helper for legacy or non-context usage (Deprecated)
+    def get_connection(self, db_path: Path) -> sqlite3.Connection:
+        """
+        DEPRECATED: Use `with pool.connection(path) as conn:` instead.
+        This method will raise an error to enforce refactoring.
+ """ + raise NotImplementedError("get_connection() is deprecated. Use 'with pool.connection(db_path) as conn:'") + +# Global instances +_global_pool: ConnectionPool | None = None +_pool_lock = threading.Lock() + +def get_connection_pool() -> ConnectionPool: + global _global_pool + if _global_pool is None: + with _pool_lock: + if _global_pool is None: + _global_pool = ConnectionPool() + return _global_pool + diff --git a/app/services/data/fingerprint.py b/app/services/data/fingerprint.py new file mode 100644 index 0000000..04051ed --- /dev/null +++ b/app/services/data/fingerprint.py @@ -0,0 +1,231 @@ +""" +Database Fingerprinting. + +Creates unique fingerprints from database schema structure for cache +invalidation. Fingerprints are based on schema characteristics, not data, +to enable efficient caching of generated configs. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from .schema_analyzer import SchemaAnalyzer, TableProfile + +logger = logging.getLogger(__name__) + + +class DatabaseFingerprint: + """ + Creates unique fingerprints from database schema structure. + + The fingerprint is based on: + - Table names (sorted) + - Column names and types for each table + - Row counts (optional, for change detection) + + This allows caching generated configs and detecting when + a database schema has changed. + """ + + def __init__(self, config_dir: str | Path | None = None) -> None: + """ + Initialize fingerprinting service. + + Args: + config_dir: Directory for storing cached configs + """ + default_dir = os.getenv("GENERATED_CONFIG_DIR", "/tmp/tablescanner_configs") + self.config_dir = Path(config_dir or default_dir) + self.config_dir.mkdir(parents=True, exist_ok=True) + + def compute(self, db_path: Path, include_row_counts: bool = False) -> str: + """ + Compute fingerprint for a database. + + Args: + db_path: Path to the SQLite database + include_row_counts: Whether to include row counts in fingerprint + (makes fingerprint change when data changes) + + Returns: + SHA256 hex string (first 16 characters) + """ + analyzer = SchemaAnalyzer(sample_size=0) # No samples needed + profiles = analyzer.analyze_database(db_path) + + return self.compute_from_profiles(profiles, include_row_counts) + + def compute_from_profiles( + self, + profiles: list[TableProfile], + include_row_counts: bool = False + ) -> str: + """ + Compute fingerprint from table profiles. 
+ + Args: + profiles: List of TableProfile objects + include_row_counts: Whether to include row counts + + Returns: + SHA256 hex string (first 16 characters) + """ + # Build deterministic schema representation + schema_data: list[dict[str, Any]] = [] + + for table in sorted(profiles, key=lambda t: t.name): + table_data: dict[str, Any] = { + "name": table.name, + "columns": [ + {"name": col.name, "type": col.sqlite_type} + for col in sorted(table.columns, key=lambda c: c.name) + ], + } + if include_row_counts: + table_data["row_count"] = table.row_count + + schema_data.append(table_data) + + # Create deterministic JSON string + schema_json = json.dumps(schema_data, sort_keys=True, separators=(",", ":")) + + # Compute SHA256 hash + hash_bytes = hashlib.sha256(schema_json.encode()).hexdigest() + + # Return first 16 characters for reasonable uniqueness + readability + return hash_bytes[:16] + + def compute_for_handle(self, handle_ref: str, db_path: Path) -> str: + """ + Compute fingerprint incorporating handle reference. + + This creates a unique ID that includes both the source + handle and the schema structure. + + Args: + handle_ref: The KBase handle reference + db_path: Path to the SQLite database + + Returns: + Combined fingerprint string + """ + schema_fp = self.compute(db_path) + # Sanitize handle ref for use in filenames + safe_handle = handle_ref.replace("/", "_").replace(":", "_") + return f"{safe_handle}_{schema_fp}" + + # ─── Cache Management ─────────────────────────────────────────────────── + + def is_cached(self, fingerprint: str) -> bool: + """Check if a config is cached for this fingerprint.""" + config_path = self._get_cache_path(fingerprint) + return config_path.exists() + + def get_cached_config(self, fingerprint: str) -> dict | None: + """ + Retrieve cached config for a fingerprint. + + Args: + fingerprint: Database fingerprint + + Returns: + Cached config dict or None if not found + """ + config_path = self._get_cache_path(fingerprint) + + if not config_path.exists(): + return None + + try: + with open(config_path, "r") as f: + return json.load(f) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to load cached config {fingerprint}: {e}") + return None + + def cache_config(self, fingerprint: str, config: dict) -> Path: + """ + Cache a generated config. + + Args: + fingerprint: Database fingerprint + config: Generated config to cache + + Returns: + Path to the cached config file + """ + config_path = self._get_cache_path(fingerprint) + + # Add metadata + config_with_meta = { + "_fingerprint": fingerprint, + "_cached_at": self._get_timestamp(), + **config, + } + + with open(config_path, "w") as f: + json.dump(config_with_meta, f, indent=2) + + logger.info(f"Cached config to {config_path}") + return config_path + + def clear_cache(self, fingerprint: str | None = None) -> int: + """ + Clear cached configs. 
+ + Args: + fingerprint: Specific fingerprint to clear, or None for all + + Returns: + Number of configs cleared + """ + if fingerprint: + config_path = self._get_cache_path(fingerprint) + if config_path.exists(): + config_path.unlink() + return 1 + return 0 + + # Clear all + count = 0 + for config_file in self.config_dir.glob("*.json"): + config_file.unlink() + count += 1 + return count + + def list_cached(self) -> list[dict[str, Any]]: + """List all cached configs with metadata.""" + cached: list[dict[str, Any]] = [] + + for config_file in self.config_dir.glob("*.json"): + try: + with open(config_file, "r") as f: + config = json.load(f) + cached.append({ + "fingerprint": config.get("_fingerprint", config_file.stem), + "cached_at": config.get("_cached_at"), + "id": config.get("id"), + "name": config.get("name"), + "path": str(config_file), + }) + except (json.JSONDecodeError, OSError): + continue + + return cached + + # ─── Private Methods ──────────────────────────────────────────────────── + + def _get_cache_path(self, fingerprint: str) -> Path: + """Get cache file path for a fingerprint.""" + return self.config_dir / f"{fingerprint}.json" + + def _get_timestamp(self) -> str: + """Get current ISO timestamp.""" + return datetime.now(timezone.utc).isoformat() diff --git a/app/services/data/query_service.py b/app/services/data/query_service.py new file mode 100644 index 0000000..ec2c01b --- /dev/null +++ b/app/services/data/query_service.py @@ -0,0 +1,647 @@ +""" +Enhanced Query Service for DataTables Viewer API. + +Provides comprehensive query execution with: +- Type-aware filtering with proper numeric conversion +- Advanced filter operators (eq, ne, gt, gte, lt, lte, like, ilike, in, not_in, between, is_null, is_not_null) +- Aggregations with GROUP BY +- Full-text search (FTS5) +- Automatic indexing +- Query result caching +- Comprehensive metadata in responses +""" + +from __future__ import annotations + +import sqlite3 +import logging +import time +import hashlib +import json +import threading +from pathlib import Path +from typing import Any +from collections import OrderedDict +from dataclasses import dataclass + +from app.services.data.connection_pool import get_connection_pool +from app.config_constants import ( + CACHE_MAX_ENTRIES, + INDEX_CACHE_TTL, +) +from app.exceptions import ( + TableNotFoundError, + InvalidFilterError, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class FilterSpec: + """Filter specification for query building.""" + + column: str + operator: str + value: Any = None + value2: Any = None # For 'between' operator + + +@dataclass +class AggregationSpec: + """Aggregation specification for query building.""" + + column: str + function: str # count, sum, avg, min, max, stddev, variance, distinct_count + alias: str | None = None + + +@dataclass +class ColumnType: + """Column type information from schema.""" + + name: str + type: str # INTEGER, REAL, TEXT, etc. + notnull: bool = False + pk: bool = False + dflt_value: Any = None + + +class QueryCache: + """ + Query result cache with 5-minute TTL and LRU eviction. + + Cache key format: {dbPath}:{tableName}:{JSON.stringify(queryParams)} + Invalidates when table modification time changes. + """ + + def __init__(self) -> None: + """Initialize the query cache.""" + self._cache: OrderedDict[str, tuple[Any, float]] = OrderedDict() + self._lock = threading.Lock() + + def get(self, cache_key: str, table_mtime: float) -> Any | None: + """ + Get cached query result. 
+ + Args: + cache_key: Cache key for the query + table_mtime: Table file modification time + + Returns: + Cached result if valid, None otherwise + """ + with self._lock: + if cache_key not in self._cache: + return None + + result, cached_mtime = self._cache[cache_key] + + # Check if table has been modified + if cached_mtime != table_mtime: + del self._cache[cache_key] + return None + + # Move to end (LRU) + self._cache.move_to_end(cache_key) + return result + + def set(self, cache_key: str, result: Any, table_mtime: float) -> None: + """ + Store query result in cache. + + Args: + cache_key: Cache key for the query + result: Query result to cache + table_mtime: Table file modification time + """ + with self._lock: + # Evict oldest if at capacity + if len(self._cache) >= CACHE_MAX_ENTRIES: + self._cache.popitem(last=False) + + self._cache[cache_key] = (result, table_mtime) + # Move to end (LRU) + self._cache.move_to_end(cache_key) + + def clear(self) -> None: + """Clear all cached results.""" + with self._lock: + self._cache.clear() + + +# Global query cache instance +_query_cache: QueryCache | None = None +_cache_lock = threading.Lock() + + +def get_query_cache() -> QueryCache: + """Get the global query cache instance.""" + global _query_cache + + if _query_cache is None: + with _cache_lock: + if _query_cache is None: + _query_cache = QueryCache() + + return _query_cache + + +class QueryService: + """ + Enhanced query service for DataTables Viewer API. + + Provides comprehensive query execution with type-aware filtering, + aggregations, full-text search, and result caching. + """ + + def __init__(self) -> None: + """Initialize the query service.""" + self.pool = get_connection_pool() + self.cache = get_query_cache() + # In-memory cache for index existence to avoid frequent sqlite_master queries + # Key: {db_path}:{table_name}:{column_name}, Value: timestamp + self._index_cache: dict[str, float] = {} + self._index_lock = threading.Lock() + + def get_column_types(self, db_path: Path, table_name: str) -> list[ColumnType]: + """ + Get column type information from table schema. + """ + try: + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + + # Validate table existence and get validated table name + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + result = cursor.fetchone() + if not result: + raise TableNotFoundError(table_name) + + # Use validated table name from sqlite_master to prevent SQL injection + validated_table_name = result[0] + cursor.execute(f"PRAGMA table_info(\"{validated_table_name}\")") + rows = cursor.fetchall() + + column_types = [] + for row in rows: + # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk + column_types.append(ColumnType( + name=row[1], + type=row[2] or "TEXT", # Default to TEXT if type is NULL + notnull=bool(row[3]), + pk=bool(row[5]), + dflt_value=row[4] + )) + + return column_types + + except sqlite3.Error as e: + logger.error(f"Error getting column types: {e}") + raise + + def is_numeric_column(self, column_type: str) -> bool: + """Check if a column type is numeric.""" + if not column_type: + return False + type_upper = column_type.upper() + return any(numeric_type in type_upper for numeric_type in ["INT", "REAL", "NUMERIC"]) + + def convert_numeric_value(self, value: Any, column_type: str) -> float | int: + """ + Convert a value to numeric type based on column type. 
+ + Raises: + ValueError: If value cannot be converted to the target numeric type + """ + if value is None: + return 0 + + type_upper = column_type.upper() + + # Strict validation: prevent text->0 coercion + try: + if "INT" in type_upper: + return int(float(str(value))) + else: + return float(str(value)) + except (ValueError, TypeError): + # Re-raise with clear message instead of returning 0 + raise ValueError(f"Invalid numeric value '{value}' for column type '{column_type}'") + + def ensure_index(self, db_path: Path, table_name: str, column: str) -> None: + """Ensure an index exists on a column. Optimized with in-memory cache.""" + cache_key = f"{db_path}:{table_name}:{column}" + + with self._index_lock: + # Check cache with TTL + if cache_key in self._index_cache: + if time.time() - self._index_cache[cache_key] < INDEX_CACHE_TTL: + return + + try: + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + index_name = f"idx_{table_name}_{column}".replace(" ", "_").replace("-", "_") + safe_table = f'"{table_name}"' + safe_column = f'"{column}"' + + cursor.execute( + f'CREATE INDEX IF NOT EXISTS "{index_name}" ON {safe_table}({safe_column})' + ) + conn.commit() + + with self._index_lock: + self._index_cache[cache_key] = time.time() + + except sqlite3.Error as e: + logger.warning(f"Error creating index on {table_name}.{column}: {e}") + + def ensure_fts5_table(self, db_path: Path, table_name: str, text_columns: list[str]) -> bool: + """ + Ensure FTS5 virtual table exists for full-text search. + + Safety: Skips creation if table is too large (>100k rows) to prevent + blocking the request thread for too long. + """ + if not text_columns: + return False + + try: + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + + fts5_table_name = f"{table_name}_fts5" + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (fts5_table_name,)) + if cursor.fetchone(): + return True + + # Check capabilities + cursor.execute("PRAGMA compile_options") + if "ENABLE_FTS5" not in [row[0] for row in cursor.fetchall()]: + return False + + # SAFETY CHECK: Row count limit + # Creating FTS5 index copies all data. For large tables, this is a heavy operation. + cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"') + row_count = cursor.fetchone()[0] + if row_count > 100000: + logger.warning(f"Skipping FTS5 creation for large table '{table_name}' ({row_count} rows)") + return False + + safe_columns = ", ".join(f'"{col}"' for col in text_columns) + cursor.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS "{fts5_table_name}" + USING fts5({safe_columns}, content="{table_name}", content_rowid="rowid") + """) + + # Populate + # If table has integer PK, use it as rowid implicitly + cursor.execute(f""" + INSERT INTO "{fts5_table_name}"(rowid, {safe_columns}) + SELECT rowid, {safe_columns} FROM "{table_name}" + """) + + conn.commit() + return True + except sqlite3.Error: + return False + + def _build_select_clause( + self, + columns: list[str] | None, + aggregations: list[AggregationSpec] | None, + group_by: list[str] | None, + column_types: dict[str, ColumnType] + ) -> tuple[str, list[str]]: + """ + Build SELECT clause and return logic for headers. 
+ + Returns: + Tuple of (select_sql, headers_list) + """ + select_parts = [] + headers = [] + + if aggregations: + # GROUP BY columns in SELECT + if group_by: + for col in group_by: + if col in column_types: + select_parts.append(f'"{col}"') + headers.append(col) + + # Aggregation columns + for agg in aggregations: + if agg.column != "*" and agg.column not in column_types: + continue + + safe_col = f'"{agg.column}"' if agg.column != "*" else "*" + + if agg.function == "count": + expr = f"COUNT({safe_col})" + elif agg.function == "distinct_count": + expr = f"COUNT(DISTINCT {safe_col})" + elif agg.function in ["sum", "avg", "min", "max"]: + expr = f"{agg.function.upper()}({safe_col})" + else: + continue + + alias = agg.alias or f"{agg.function}_{agg.column}" + # Sanitize alias to prevent injection/bad chars + alias = alias.replace('"', '').replace("'", "") + safe_alias = alias + select_parts.append(f'{expr} AS "{safe_alias}"') + headers.append(safe_alias) + + if not select_parts: + select_parts = ["*"] + else: + # Regular columns + if columns: + valid_cols = [] + for col in columns: + if col in column_types: + valid_cols.append(f'"{col}"') + headers.append(col) + if valid_cols: + select_parts = valid_cols + else: + select_parts = ["*"] + else: + select_parts = ["*"] + headers = list(column_types.keys()) + + return ", ".join(select_parts), headers + + def _build_where_clause( + self, + db_path: Path, + table_name: str, + filters: list[FilterSpec] | None, + search_value: str | None, + column_types_list: list[ColumnType], + column_types_map: dict[str, ColumnType], + params: list[Any] + ) -> str: + """Build WHERE clause including global search and field filters.""" + where_conditions = [] + + # Global Search + if search_value: + text_columns = [ + col.name for col in column_types_list + if not self.is_numeric_column(col.type) + ] + + # Note: ensures FTS5 table is ready. This might skip if table is large. + if text_columns and self.ensure_fts5_table(db_path, table_name, text_columns): + fts5_table = f"{table_name}_fts5" + where_conditions.append( + f'rowid IN (SELECT rowid FROM "{fts5_table}" WHERE "{fts5_table}" MATCH ?)' + ) + params.append(search_value) + elif text_columns: + search_conditions = [] + for col in text_columns: + search_conditions.append(f'"{col}" LIKE ?') + params.append(f"%{search_value}%") + if search_conditions: + where_conditions.append(f"({' OR '.join(search_conditions)})") + + # Filters + if filters: + for filter_spec in filters: + condition = self._build_single_filter(filter_spec, column_types_map, params) + if condition: + where_conditions.append(condition) + + return f" WHERE {' AND '.join(where_conditions)}" if where_conditions else "" + + def _build_single_filter( + self, + filter_spec: FilterSpec, + column_types: dict[str, ColumnType], + params: list[Any] + ) -> str: + """ + Build SQL condition for a single filter. + + Raises: + InvalidFilterError: If filter parameters are unsafe (e.g. 
too many IN values) + """ + column = filter_spec.column + operator = filter_spec.operator.lower() + value = filter_spec.value + + if column not in column_types: + logger.warning(f"Column '{column}' not found, skipping filter") + return "" + + col_type = column_types[column] + is_numeric = self.is_numeric_column(col_type.type) + safe_column = f'"{column}"' + + if operator == "is_null": + return f"{safe_column} IS NULL" + if operator == "is_not_null": + return f"{safe_column} IS NOT NULL" + + if value is None: + return "" + + # Check variable limits for array operators + if operator in ["in", "not_in"] and isinstance(value, list): + if len(value) > 900: + raise InvalidFilterError(f"Too many values for IN operator: {len(value)}. Max is 900.") + + # Numeric handling + if is_numeric and operator in ["eq", "ne", "gt", "gte", "lt", "lte", "between", "in", "not_in"]: + if operator == "between": + if filter_spec.value2 is None: return "" + params.append(self.convert_numeric_value(value, col_type.type)) + params.append(self.convert_numeric_value(filter_spec.value2, col_type.type)) + return f"{safe_column} BETWEEN ? AND ?" + elif operator in ["in", "not_in"]: + if not isinstance(value, list): return "" + vals = [self.convert_numeric_value(v, col_type.type) for v in value] + placeholders = ",".join(["?"] * len(vals)) + params.extend(vals) + op = "IN" if operator == "in" else "NOT IN" + return f"{safe_column} {op} ({placeholders})" + else: + params.append(self.convert_numeric_value(value, col_type.type)) + else: + # Text handling + if operator in ["like", "ilike"]: + params.append(f"%{value}%") + elif operator in ["in", "not_in"]: + if not isinstance(value, list): return "" + placeholders = ",".join(["?"] * len(value)) + params.extend(value) + op = "IN" if operator == "in" else "NOT IN" + return f"{safe_column} {op} ({placeholders})" + else: + params.append(value) + + operator_map = { + "eq": "=", "ne": "!=", "gt": ">", "gte": ">=", + "lt": "<", "lte": "<=", "like": "LIKE", "ilike": "LIKE" + } + + sql_op = operator_map.get(operator) + return f"{safe_column} {sql_op} ?" if sql_op else "" + + def execute_query( + self, + db_path: Path, + table_name: str, + limit: int = 100, + offset: int = 0, + columns: list[str] | None = None, + sort_column: str | None = None, + sort_order: str = "ASC", + search_value: str | None = None, + filters: list[FilterSpec] | None = None, + aggregations: list[AggregationSpec] | None = None, + group_by: list[str] | None = None, + use_cache: bool = True + ) -> dict[str, Any]: + """Execute a comprehensive query with all features.""" + try: + table_mtime = db_path.stat().st_mtime + except OSError: + table_mtime = 0.0 + + # 1. Cache Check + cache_key = self._build_cache_key( + db_path, table_name, limit, offset, columns, sort_column, + sort_order, search_value, filters, aggregations, group_by + ) + + if use_cache: + cached = self.cache.get(cache_key, table_mtime) + if cached: + cached["cached"] = True + return cached + + # 2. Schema & Validation + # This calls get_column_types internally which uses the pool correctly now + column_types_list = self.get_column_types(db_path, table_name) + column_types_map = {col.name: col for col in column_types_list} + + # 3. Indices + if filters: + for f in filters: + if f.column in column_types_map: + self.ensure_index(db_path, table_name, f.column) + if sort_column and sort_column in column_types_map: + self.ensure_index(db_path, table_name, sort_column) + + # 4. 
Query Construction + select_clause, headers = self._build_select_clause(columns, aggregations, group_by, column_types_map) + + where_params: list[Any] = [] + where_clause = self._build_where_clause( + db_path, table_name, filters, search_value, + column_types_list, column_types_map, where_params + ) + + group_by_clause = "" + if group_by: + valid_groups = [f'"{col}"' for col in group_by if col in column_types_map] + if valid_groups: + group_by_clause = " GROUP BY " + ", ".join(valid_groups) + + order_by_clause = "" + if sort_column and sort_column in column_types_map: + direction = "DESC" if sort_order.upper() == "DESC" else "ASC" + order_by_clause = f' ORDER BY "{sort_column}" {direction}' + elif not aggregations and column_types_list: + order_by_clause = f' ORDER BY "{column_types_list[0].name}" ASC' + + limit_clause = f" LIMIT {int(limit)}" + offset_clause = f" OFFSET {int(offset)}" if offset > 0 else "" + + # 5. Execution - Use the connection context manager + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + + # Count Query + count_query = f'SELECT COUNT(*) FROM "{table_name}"{where_clause}' + cursor.execute(count_query, where_params) + total_count = cursor.fetchone()[0] + + # Data Query + query = f'SELECT {select_clause} FROM "{table_name}"{where_clause}{group_by_clause}{order_by_clause}{limit_clause}{offset_clause}' + + start_time = time.time() + cursor.execute(query, where_params) + rows = cursor.fetchall() + execution_time_ms = (time.time() - start_time) * 1000 + + # 6. Formatting + data = [[str(val) if val is not None else "" for val in row] for row in rows] + + response_column_types = [] + for col_name in headers: + if col_name in column_types_map: + ct = column_types_map[col_name] + response_column_types.append({ + "name": ct.name, "type": ct.type, + "notnull": ct.notnull, "pk": ct.pk, "dflt_value": ct.dflt_value + }) + else: + response_column_types.append({ + "name": col_name, "type": "REAL", "notnull": False, + "pk": False, "dflt_value": None + }) + + result = { + "headers": headers, + "data": data, + "total_count": total_count, + "column_types": response_column_types, + "query_metadata": { + "query_type": "aggregate" if aggregations else "select", + "sql": query, + "filters_applied": len(filters) if filters else 0, + "has_search": bool(search_value) + }, + "cached": False, + "execution_time_ms": execution_time_ms, + "limit": limit, + "offset": offset, + "table_name": table_name, + "database_path": str(db_path) + } + + if use_cache: + self.cache.set(cache_key, result, table_mtime) + + return result + + def _build_cache_key(self, db_path, table_name, limit, offset, columns, sort_column, + sort_order, search_value, filters, aggregations, group_by) -> str: + """Build precise cache key.""" + params = { + "db": str(db_path), "tbl": table_name, "l": limit, "o": offset, + "cols": columns, "sc": sort_column, "so": sort_order, "q": search_value, + "f": [(f.column, f.operator, f.value, f.value2) for f in (filters or [])], + "a": [(a.column, a.function, a.alias) for a in (aggregations or [])], + "gb": group_by + } + return hashlib.md5(json.dumps(params, sort_keys=True, default=str).encode()).hexdigest() + +_query_service: QueryService | None = None +_service_lock = threading.Lock() + +def get_query_service() -> QueryService: + """Get the global query service instance.""" + global _query_service + if _query_service is None: + with _service_lock: + if _query_service is None: + _query_service = QueryService() + return _query_service + diff --git 
a/app/services/data/schema_analyzer.py b/app/services/data/schema_analyzer.py new file mode 100644 index 0000000..2cff622 --- /dev/null +++ b/app/services/data/schema_analyzer.py @@ -0,0 +1,400 @@ +""" +Schema Analyzer. + +Comprehensive database schema introspection with sample value analysis. +Profiles tables and columns to provide input for type inference and AI analysis. +""" + +from __future__ import annotations + +import logging +import sqlite3 +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ColumnProfile: + """Detailed profile of a database column.""" + + name: str + sqlite_type: str # INTEGER, TEXT, REAL, BLOB, NULL + sample_values: list[Any] = field(default_factory=list) + null_count: int = 0 + total_count: int = 0 + unique_count: int = 0 + avg_length: float = 0.0 # For TEXT columns + min_value: Any = None # For numeric columns + max_value: Any = None + detected_patterns: list[str] = field(default_factory=list) + + @property + def null_ratio(self) -> float: + """Percentage of NULL values.""" + return self.null_count / self.total_count if self.total_count > 0 else 0.0 + + @property + def unique_ratio(self) -> float: + """Cardinality indicator (unique / total).""" + return self.unique_count / self.total_count if self.total_count > 0 else 0.0 + + @property + def is_likely_id(self) -> bool: + """Check if column is likely an identifier.""" + # High cardinality + low nulls + ID-like name pattern + return ( + self.unique_ratio > 0.9 and + self.null_ratio < 0.01 and + any(p in self.name.lower() for p in ["id", "key", "ref"]) + ) + + +@dataclass +class TableProfile: + """Complete profile of a database table.""" + + name: str + row_count: int = 0 + columns: list[ColumnProfile] = field(default_factory=list) + primary_key: str | None = None + foreign_keys: list[str] = field(default_factory=list) + + @property + def column_count(self) -> int: + return len(self.columns) + + def get_column(self, name: str) -> ColumnProfile | None: + """Get a column profile by name.""" + for col in self.columns: + if col.name == name: + return col + return None + + +class SchemaAnalyzer: + """ + Database schema introspection and profiling. + + Analyzes SQLite databases to extract: + - Table metadata (row counts, column counts) + - Column details (types, nullability, cardinality) + - Sample values for type inference + - Statistical summaries + """ + + def __init__(self, sample_size: int = 10) -> None: + """ + Initialize the schema analyzer. + + Args: + sample_size: Number of sample values to collect per column + """ + self.sample_size = sample_size + + def analyze_database(self, db_path: Path) -> list[TableProfile]: + """ + Analyze all tables in a SQLite database. 
+ + Args: + db_path: Path to the SQLite database file + + Returns: + List of TableProfile objects for each table + """ + profiles: list[TableProfile] = [] + + try: + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Get list of user tables + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' + AND name NOT LIKE 'sqlite_%' + ORDER BY name + """) + tables = [row[0] for row in cursor.fetchall()] + + for table_name in tables: + try: + profile = self._analyze_table(cursor, table_name) + profiles.append(profile) + except Exception as e: + logger.warning(f"Error analyzing table {table_name}: {e}") + + conn.close() + + except sqlite3.Error as e: + logger.error(f"Error opening database {db_path}: {e}") + raise + + logger.info(f"Analyzed {len(profiles)} tables from {db_path}") + return profiles + + def analyze_table(self, db_path: Path, table_name: str) -> TableProfile: + """ + Analyze a single table in a SQLite database. + + Args: + db_path: Path to the SQLite database file + table_name: Name of the table to analyze + + Returns: + TableProfile for the specified table + """ + try: + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Validate table existence and get validated table name + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + result = cursor.fetchone() + if not result: + raise ValueError(f"Invalid table name: {table_name}") + + # Use validated table name from sqlite_master to prevent SQL injection + validated_table_name = result[0] + profile = self._analyze_table(cursor, validated_table_name) + + conn.close() + return profile + + except sqlite3.Error as e: + logger.error(f"Error analyzing table {table_name}: {e}") + raise + + def get_sample_values( + self, + db_path: Path, + table_name: str, + column_name: str, + n: int | None = None + ) -> list[Any]: + """ + Get sample values from a specific column. + + Args: + db_path: Path to the SQLite database file + table_name: Name of the table + column_name: Name of the column + n: Number of samples (defaults to self.sample_size) + + Returns: + List of sample values (distinct, non-null when possible) + """ + if n is None: + n = self.sample_size + + try: + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Validate table exists + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + if not cursor.fetchone(): + raise ValueError(f"Table not found: {table_name}") + + # Get distinct non-null samples first + safe_col = column_name.replace('"', '""') + cursor.execute(f''' + SELECT DISTINCT "{safe_col}" + FROM "{table_name}" + WHERE "{safe_col}" IS NOT NULL + LIMIT ? + ''', (n,)) + + samples = [row[0] for row in cursor.fetchall()] + conn.close() + + return samples + + except sqlite3.Error as e: + logger.error(f"Error getting samples from {table_name}.{column_name}: {e}") + raise + + # ─── Private Methods ──────────────────────────────────────────────────── + + def _analyze_table(self, cursor: sqlite3.Cursor, table_name: str) -> TableProfile: + """ + Analyze a single table using an open cursor. + + Note: table_name should already be validated from sqlite_master. 
+ """ + + profile = TableProfile(name=table_name) + + # table_name is already validated, safe to use in queries + # Get row count + cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"') + profile.row_count = cursor.fetchone()[0] + + # Get column info + cursor.execute(f'PRAGMA table_info("{table_name}")') + columns_info = cursor.fetchall() + + # Get primary key + for col_info in columns_info: + if col_info[5] == 1: # pk column in PRAGMA result + profile.primary_key = col_info[1] + break + + # Analyze each column + for col_info in columns_info: + col_name = col_info[1] + col_type = col_info[2] or "TEXT" + + col_profile = self._analyze_column( + cursor, table_name, col_name, col_type, profile.row_count + ) + profile.columns.append(col_profile) + + return profile + + def _analyze_column( + self, + cursor: sqlite3.Cursor, + table_name: str, + col_name: str, + col_type: str, + row_count: int + ) -> ColumnProfile: + """Analyze a single column.""" + + safe_col = col_name.replace('"', '""') + safe_table = table_name.replace('"', '""') + + profile = ColumnProfile( + name=col_name, + sqlite_type=col_type.upper(), + total_count=row_count, + ) + + if row_count == 0: + return profile + + # Get null count + cursor.execute(f''' + SELECT COUNT(*) FROM "{safe_table}" WHERE "{safe_col}" IS NULL + ''') + profile.null_count = cursor.fetchone()[0] + + # Get unique count (limit to avoid performance issues on large tables) + try: + cursor.execute(f''' + SELECT COUNT(DISTINCT "{safe_col}") FROM "{safe_table}" + ''') + profile.unique_count = cursor.fetchone()[0] + except sqlite3.Error: + profile.unique_count = 0 + + # Get sample values (distinct, non-null) + cursor.execute(f''' + SELECT DISTINCT "{safe_col}" + FROM "{safe_table}" + WHERE "{safe_col}" IS NOT NULL + LIMIT {self.sample_size} + ''') + profile.sample_values = [row[0] for row in cursor.fetchall()] + + # Get statistics for numeric columns + if col_type.upper() in ("INTEGER", "REAL", "NUMERIC"): + try: + cursor.execute(f''' + SELECT MIN("{safe_col}"), MAX("{safe_col}"), AVG(LENGTH(CAST("{safe_col}" AS TEXT))) + FROM "{safe_table}" + WHERE "{safe_col}" IS NOT NULL + ''') + result = cursor.fetchone() + if result: + profile.min_value = result[0] + profile.max_value = result[1] + profile.avg_length = result[2] or 0.0 + except sqlite3.Error as e: + # Per-column numeric statistics are best-effort; log at debug and continue. + logger.debug( + "Error computing numeric statistics for %s.%s: %s", + table_name, + col_name, + e, + ) + + # Get average length for text columns + elif col_type.upper() in ("TEXT", "VARCHAR", "CHAR", ""): + try: + cursor.execute(f''' + SELECT AVG(LENGTH("{safe_col}")) + FROM "{safe_table}" + WHERE "{safe_col}" IS NOT NULL + ''') + result = cursor.fetchone() + if result and result[0]: + profile.avg_length = float(result[0]) + except sqlite3.Error as e: + # Per-column text statistics are best-effort; log at debug and continue. 
+ logger.debug( + "Error computing text statistics for %s.%s: %s", + table_name, + col_name, + e, + ) + + # Detect patterns in sample values + profile.detected_patterns = self._detect_patterns(profile.sample_values) + + return profile + + def _detect_patterns(self, values: list[Any]) -> list[str]: + """Detect common patterns in sample values.""" + patterns: list[str] = [] + + if not values: + return patterns + + str_values = [str(v) for v in values if v is not None] + if not str_values: + return patterns + + # Check for URL pattern + if all(v.startswith(("http://", "https://")) for v in str_values): + patterns.append("url") + + # Check for email pattern + if all("@" in v and "." in v for v in str_values): + patterns.append("email") + + # Check for GO term pattern + if all(v.startswith("GO:") for v in str_values): + patterns.append("go_term") + + + # Check for ISO date pattern + date_pattern = re.compile(r"^\d{4}-\d{2}-\d{2}") + if all(date_pattern.match(v) for v in str_values): + patterns.append("iso_date") + + # Check for UUID pattern + uuid_pattern = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", + re.IGNORECASE + ) + if all(uuid_pattern.match(v) for v in str_values): + patterns.append("uuid") + + # Check for sequence pattern (DNA/RNA/Protein) + seq_pattern = re.compile(r"^[ATCGUN]+$", re.IGNORECASE) + protein_pattern = re.compile(r"^[ACDEFGHIKLMNPQRSTVWY]+$", re.IGNORECASE) + if all(len(v) > 20 for v in str_values): + if all(seq_pattern.match(v) for v in str_values): + patterns.append("nucleotide_sequence") + elif all(protein_pattern.match(v) for v in str_values): + patterns.append("protein_sequence") + + return patterns diff --git a/app/services/data/schema_service.py b/app/services/data/schema_service.py new file mode 100644 index 0000000..616dab6 --- /dev/null +++ b/app/services/data/schema_service.py @@ -0,0 +1,157 @@ +""" +Schema Information Service. + +Provides table and column schema information including: +- Column names, types, constraints (NOT NULL, PRIMARY KEY) +- Default values +- Indexes +""" + +from __future__ import annotations + +import sqlite3 +import logging +import threading +from pathlib import Path +from typing import Any + +from app.services.data.connection_pool import get_connection_pool +from app.services.data.query_service import QueryService +from app.utils.sqlite import list_tables + +logger = logging.getLogger(__name__) + + +class SchemaService: + """ + Service for retrieving database schema information. + """ + + def __init__(self) -> None: + """Initialize the schema service.""" + self.pool = get_connection_pool() + self.query_service = QueryService() + + def get_table_schema( + self, + db_path: Path, + table_name: str + ) -> dict[str, Any]: + """ + Get schema information for a single table. 
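+ + Example (illustrative; the database path and table name are placeholders): + + service = get_schema_service() + schema = service.get_table_schema(Path("/tmp/tablescanner_cache/genes.db"), "genes") + print([col["name"] for col in schema["columns"]])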
+ + Args: + db_path: Path to SQLite database + table_name: Name of the table + + Returns: + Dictionary with table schema information + """ + # Get column schema using query service (which handles its own connection) + column_types = self.query_service.get_column_types(db_path, table_name) + + columns = [] + for col_type in column_types: + columns.append({ + "name": col_type.name, + "type": col_type.type, + "notnull": col_type.notnull, + "pk": col_type.pk, + "dflt_value": col_type.dflt_value + }) + + # Get indexes using direct connection + indexes = [] + try: + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + indexes = self._get_table_indexes(cursor, table_name) + except sqlite3.Error as e: + logger.warning(f"Error getting indexes for {table_name}: {e}") + # We continue with empty indexes rather than failing the whole schema request + + return { + "table": table_name, + "columns": columns, + "indexes": indexes + } + + def get_all_tables_schema( + self, + db_path: Path + ) -> dict[str, Any]: + """ + Get schema information for all tables in the database. + + Args: + db_path: Path to SQLite database + + Returns: + Dictionary mapping table names to schema information + """ + + table_names = list_tables(db_path) + schemas = {} + + for table_name in table_names: + try: + schemas[table_name] = self.get_table_schema(db_path, table_name) + except Exception as e: + logger.warning(f"Error getting schema for {table_name}: {e}") + + return schemas + + def _get_table_indexes( + self, + cursor: sqlite3.Cursor, + table_name: str + ) -> list[dict[str, str]]: + """ + Get all indexes for a table. + + Args: + cursor: Database cursor + table_name: Name of the table + + Returns: + List of index information dictionaries + """ + indexes = [] + + try: + # Get indexes for this table + cursor.execute(""" + SELECT name, sql + FROM sqlite_master + WHERE type='index' + AND tbl_name=? + AND name NOT LIKE 'sqlite_%' + """, (table_name,)) + + for row in cursor.fetchall(): + indexes.append({ + "name": row[0], + "sql": row[1] or "" + }) + + except sqlite3.Error as e: + logger.warning(f"Error getting indexes for {table_name}: {e}") + + return indexes + + +# Global schema service instance +_schema_service: SchemaService | None = None +_schema_service_lock = threading.Lock() + + +def get_schema_service() -> SchemaService: + """Get the global schema service instance.""" + global _schema_service + + if _schema_service is None: + with _schema_service_lock: + if _schema_service is None: + _schema_service = SchemaService() + + return _schema_service diff --git a/app/services/data/statistics_service.py b/app/services/data/statistics_service.py new file mode 100644 index 0000000..683bc95 --- /dev/null +++ b/app/services/data/statistics_service.py @@ -0,0 +1,336 @@ +""" +Column Statistics Service. 
+ +Pre-computes and caches column statistics including: +- null_count, distinct_count, min, max, mean, median, stddev +- Sample values for data exploration +""" + +from __future__ import annotations + +import sqlite3 +import logging +import time +import threading +import math +from pathlib import Path +from typing import Any +from dataclasses import dataclass + +from app.services.data.connection_pool import get_connection_pool +from app.services.data.query_service import QueryService + +logger = logging.getLogger(__name__) + + +@dataclass +class ColumnStatistics: + """Statistics for a single column.""" + + column: str + type: str + null_count: int = 0 + distinct_count: int = 0 + min: Any = None + max: Any = None + mean: float | None = None + median: float | None = None + stddev: float | None = None + sample_values: list[Any] = None + + def __post_init__(self): + """Initialize sample_values if None.""" + if self.sample_values is None: + self.sample_values = [] + + +class StatisticsCache: + """ + Cache for pre-computed column statistics. + + Invalidates when table modification time changes. + """ + + def __init__(self) -> None: + """Initialize the statistics cache.""" + self._cache: dict[str, tuple[dict[str, Any], float]] = {} + self._lock = threading.Lock() + + def get(self, cache_key: str, table_mtime: float) -> dict[str, Any] | None: + """ + Get cached statistics. + + Args: + cache_key: Cache key (db_path:table_name) + table_mtime: Table file modification time + + Returns: + Cached statistics if valid, None otherwise + """ + with self._lock: + if cache_key not in self._cache: + return None + + stats, cached_mtime = self._cache[cache_key] + + # Check if table has been modified + if cached_mtime != table_mtime: + del self._cache[cache_key] + return None + + return stats + + def set(self, cache_key: str, stats: dict[str, Any], table_mtime: float) -> None: + """ + Store statistics in cache. + + Args: + cache_key: Cache key (db_path:table_name) + stats: Statistics dictionary + table_mtime: Table file modification time + """ + with self._lock: + self._cache[cache_key] = (stats, table_mtime) + + def clear(self) -> None: + """Clear all cached statistics.""" + with self._lock: + self._cache.clear() + + +# Global statistics cache instance +_stats_cache: StatisticsCache | None = None +_stats_cache_lock = threading.Lock() + + +def get_statistics_cache() -> StatisticsCache: + """Get the global statistics cache instance.""" + global _stats_cache + + if _stats_cache is None: + with _stats_cache_lock: + if _stats_cache is None: + _stats_cache = StatisticsCache() + + return _stats_cache + + +class StatisticsService: + """ + Service for computing and caching column statistics. + """ + + def __init__(self) -> None: + """Initialize the statistics service.""" + self.pool = get_connection_pool() + self.query_service = QueryService() + self.cache = get_statistics_cache() + + def get_table_statistics( + self, + db_path: Path, + table_name: str, + use_cache: bool = True + ) -> dict[str, Any]: + """ + Get comprehensive statistics for all columns in a table. 
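+ + Example (illustrative; the database path and table name are placeholders): + + stats = get_statistics_service().get_table_statistics(Path("genes.db"), "genes") + for col in stats["columns"]: + print(col["column"], col["null_count"], col["distinct_count"])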
+ + Args: + db_path: Path to SQLite database + table_name: Name of the table + use_cache: Whether to use cached statistics + + Returns: + Dictionary with table and column statistics + """ + # Get table modification time for cache invalidation + try: + table_mtime = db_path.stat().st_mtime + except OSError: + table_mtime = 0.0 + + cache_key = f"{db_path.absolute()}:{table_name}" + + # Check cache + if use_cache: + cached_stats = self.cache.get(cache_key, table_mtime) + if cached_stats is not None: + logger.debug(f"Cache hit for statistics: {table_name}") + return cached_stats + + # Execute stats computation + try: + with self.pool.connection(db_path) as conn: + cursor = conn.cursor() + + # Get row count + cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"') + row_count = cursor.fetchone()[0] + + # Get column types via QueryService. It acquires its own short-lived + # pooled connection, which is slightly redundant with the one held here, + # but it keeps column-type detection in a single place. + column_types = self.query_service.get_column_types(db_path, table_name) + + # Compute statistics for each column + column_stats_list = [] + + for col_type in column_types: + stats = self._compute_column_statistics( + cursor, table_name, col_type, row_count + ) + column_stats_list.append(stats) + + # Build response + result = { + "table": table_name, + "row_count": row_count, + "columns": [ + { + "column": stats.column, + "type": stats.type, + "null_count": stats.null_count, + "distinct_count": stats.distinct_count, + "min": stats.min, + "max": stats.max, + "mean": stats.mean, + "median": stats.median, + "stddev": stats.stddev, + "sample_values": stats.sample_values + } + for stats in column_stats_list + ], + "last_updated": int(time.time() * 1000) # Milliseconds since epoch + } + + # Cache result + if use_cache: + self.cache.set(cache_key, result, table_mtime) + + return result + + except sqlite3.Error as e: + logger.error(f"Error computing statistics for {table_name}: {e}") + raise + + def _compute_column_statistics( + self, + cursor: sqlite3.Cursor, + table_name: str, + col_type: Any, # ColumnType from query_service + row_count: int + ) -> ColumnStatistics: + """ + Compute statistics for a single column.
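+ + Note: the median is approximated by ordering the non-null values and taking the entry at offset row_count // 2, and stddev is the square root of the average squared deviation from the mean, so both are estimates rather than exact streaming statistics.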
+ + Args: + cursor: Database cursor + table_name: Name of the table + col_type: ColumnType object + row_count: Total row count + + Returns: + ColumnStatistics object + """ + column = col_type.name + sql_type = col_type.type + is_numeric = self.query_service.is_numeric_column(sql_type) + + safe_column = f'"{column}"' + + stats = ColumnStatistics(column=column, type=sql_type) + + try: + # Null count + cursor.execute(f'SELECT COUNT(*) FROM "{table_name}" WHERE {safe_column} IS NULL') + stats.null_count = cursor.fetchone()[0] + + # Distinct count + cursor.execute(f'SELECT COUNT(DISTINCT {safe_column}) FROM "{table_name}"') + stats.distinct_count = cursor.fetchone()[0] + + if is_numeric: + # Numeric statistics + try: + # Min, max, mean + cursor.execute(f''' + SELECT + MIN({safe_column}), + MAX({safe_column}), + AVG({safe_column}) + FROM "{table_name}" + WHERE {safe_column} IS NOT NULL + ''') + row = cursor.fetchone() + if row and row[0] is not None: + stats.min = float(row[0]) if "REAL" in sql_type.upper() else int(row[0]) + stats.max = float(row[1]) if "REAL" in sql_type.upper() else int(row[1]) + stats.mean = float(row[2]) if row[2] is not None else None + + # Median (approximate using ORDER BY and LIMIT) + if row_count > 0: + cursor.execute(f''' + SELECT {safe_column} + FROM "{table_name}" + WHERE {safe_column} IS NOT NULL + ORDER BY {safe_column} + LIMIT 1 OFFSET ? + ''', (row_count // 2,)) + median_row = cursor.fetchone() + if median_row and median_row[0] is not None: + stats.median = float(median_row[0]) if "REAL" in sql_type.upper() else int(median_row[0]) + + # Standard deviation (approximate) + if stats.mean is not None: + cursor.execute(f''' + SELECT AVG(({safe_column} - ?) * ({safe_column} - ?)) + FROM "{table_name}" + WHERE {safe_column} IS NOT NULL + ''', (stats.mean, stats.mean)) + variance_row = cursor.fetchone() + if variance_row and variance_row[0] is not None: + variance = float(variance_row[0]) + stats.stddev = math.sqrt(variance) if variance >= 0 else None + + except sqlite3.Error as e: + logger.warning(f"Error computing numeric statistics for {column}: {e}") + + # Sample values (always compute) + try: + cursor.execute(f''' + SELECT DISTINCT {safe_column} + FROM "{table_name}" + WHERE {safe_column} IS NOT NULL + LIMIT 5 + ''') + sample_rows = cursor.fetchall() + stats.sample_values = [row[0] for row in sample_rows if row[0] is not None] + + except sqlite3.Error as e: + logger.warning(f"Error getting sample values for {column}: {e}") + + except sqlite3.Error as e: + logger.warning(f"Error computing statistics for {column}: {e}") + + return stats + + +# Global statistics service instance +_stats_service: StatisticsService | None = None +_stats_service_lock = threading.Lock() + + +def get_statistics_service() -> StatisticsService: + """Get the global statistics service instance.""" + global _stats_service + + if _stats_service is None: + with _stats_service_lock: + if _stats_service is None: + _stats_service = StatisticsService() + + return _stats_service diff --git a/app/services/data/type_inference.py b/app/services/data/type_inference.py new file mode 100644 index 0000000..bdd97e2 --- /dev/null +++ b/app/services/data/type_inference.py @@ -0,0 +1,550 @@ +""" +Type Inference Engine. + +Rule-based pattern detection for inferring column data types and rendering +configurations. This module provides fast, deterministic type inference +without requiring AI, and serves as the foundation for hybrid inference. + +Works independently of AI providers and can serve as a fallback. 
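+ +Example (illustrative): + + engine = TypeInferenceEngine() + inferred = engine.infer("P_Value", sample_values=[0.0012, 0.04], sqlite_type="REAL") + # inferred.data_type is DataType.FLOAT with a scientific-notation number transform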
+""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, Literal +from enum import Enum + + +class DataType(str, Enum): + """Column data types matching DataTables_Viewer ColumnDataType.""" + STRING = "string" + NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" + BOOLEAN = "boolean" + DATE = "date" + DATETIME = "datetime" + TIMESTAMP = "timestamp" + JSON = "json" + ARRAY = "array" + SEQUENCE = "sequence" + ID = "id" + URL = "url" + EMAIL = "email" + ONTOLOGY = "ontology" + PERCENTAGE = "percentage" + FILESIZE = "filesize" + DURATION = "duration" + CURRENCY = "currency" + COLOR = "color" + IMAGE = "image" + CUSTOM = "custom" + + +@dataclass +class TransformConfig: + """Transform configuration for cell rendering.""" + type: str + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class InferredType: + """Result of type inference for a column.""" + data_type: DataType + display_name: str + categories: list[str] + transform: TransformConfig | None = None + width: str = "auto" + pin: Literal["left", "right"] | None = None + sortable: bool = True + filterable: bool = True + copyable: bool = False + confidence: float = 1.0 + source: Literal["rules", "ai", "hybrid"] = "rules" + + +# ============================================================================= +# PATTERN DEFINITIONS +# ============================================================================= + +# Column name patterns mapped to inference results +NAME_PATTERNS: list[tuple[re.Pattern, dict[str, Any]]] = [ + # IDs - typically pinned left + (re.compile(r"^(ID|id)$"), { + "data_type": DataType.ID, + "categories": ["core"], + "pin": "left", + "copyable": True, + "width": "100px", + }), + (re.compile(r".*_ID$|.*_id$|.*Id$"), { + "data_type": DataType.ID, + "categories": ["core"], + "copyable": True, + "width": "120px", + }), + (re.compile(r"^Database_ID$|^database_id$"), { + "data_type": DataType.ID, + "categories": ["core"], + "copyable": True, + "width": "130px", + }), + + # UniRef IDs - need chain transformer to strip prefix + (re.compile(r"^uniref_\d+$|^UniRef_\d+$|^uniref\d+$"), { + "data_type": DataType.ID, + "categories": ["external"], + "copyable": True, + "width": "140px", + "transform": TransformConfig( + type="chain", + options={ + "transforms": [ + {"type": "replace", "options": {"find": "UniRef:", "replace": ""}}, + {"type": "link", "options": { + "urlTemplate": "https://www.uniprot.org/uniref/{value}", + "target": "_blank", + "icon": "bi-link-45deg" + }} + ] + } + ), + }), + + # External database references with link transforms + (re.compile(r"^Uniprot.*|^uniprot.*|.*UniProt.*"), { + "data_type": DataType.ID, + "categories": ["external"], + "width": "100px", + "transform": TransformConfig( + type="link", + options={ + "urlTemplate": "https://www.uniprot.org/uniprotkb/{value}", + "target": "_blank", + "icon": "bi-link-45deg" + } + ), + }), + (re.compile(r"^KEGG.*|^kegg.*"), { + "data_type": DataType.ID, + "categories": ["external"], + "width": "90px", + "transform": TransformConfig( + type="link", + options={ + "urlTemplate": "https://www.genome.jp/entry/{value}", + "target": "_blank" + } + ), + }), + (re.compile(r"^GO_.*|^go_.*"), { + "data_type": DataType.ONTOLOGY, + "categories": ["functional"], + "width": "180px", + "transform": TransformConfig( + type="ontology", + options={ + "prefix": "GO", + "urlTemplate": "https://amigo.geneontology.org/amigo/term/{value}", + "style": "badge" + } + ), + }), + + # Pfam domain IDs + 
(re.compile(r"^pfam.*|^Pfam.*|^PF\d+"), { + "data_type": DataType.ID, + "categories": ["ontology"], + "width": "100px", + "transform": TransformConfig( + type="chain", + options={ + "transforms": [ + {"type": "replace", "options": {"find": "pfam:", "replace": ""}}, + {"type": "link", "options": { + "urlTemplate": "https://www.ebi.ac.uk/interpro/entry/pfam/{value}", + "target": "_blank", + "icon": "bi-link-45deg" + }} + ] + } + ), + }), + + # NCBI protein IDs (RefSeq) + (re.compile(r"^ncbi.*|.*_ncbi.*|^NP_.*|^WP_.*|^XP_.*"), { + "data_type": DataType.ID, + "categories": ["external"], + "copyable": True, + "width": "120px", + "transform": TransformConfig( + type="link", + options={ + "urlTemplate": "https://www.ncbi.nlm.nih.gov/protein/{value}", + "target": "_blank", + "icon": "bi-link-45deg" + } + ), + }), + + # Strand indicator (+/-) + (re.compile(r"^strand$|^Strand$|.*_strand$"), { + "data_type": DataType.STRING, + "categories": ["core"], + "width": "80px", + "transform": TransformConfig( + type="badge", + options={ + "colorMap": { + "+": {"color": "#22c55e", "bgColor": "#dcfce7"}, + "-": {"color": "#ef4444", "bgColor": "#fee2e2"}, + ".": {"color": "#94a3b8", "bgColor": "#f1f5f9"} + } + } + ), + }), + + # Sequences + (re.compile(r".*Sequence.*|.*_seq$|.*_Seq$"), { + "data_type": DataType.SEQUENCE, + "categories": ["sequence"], + "sortable": False, + "filterable": False, + "copyable": True, + "width": "150px", + "transform": TransformConfig( + type="sequence", + options={"maxLength": 20, "showCopyButton": True} + ), + }), + + # Function/product descriptions + (re.compile(r".*function.*|.*Function.*|.*product.*|.*Product.*"), { + "data_type": DataType.STRING, + "categories": ["functional"], + "width": "300px", + }), + + # Statistical measures with special formatting + (re.compile(r"^Log2FC$|.*log2.*fold.*|.*Log2.*Fold.*"), { + "data_type": DataType.FLOAT, + "categories": ["expression"], + "width": "130px", + "transform": TransformConfig( + type="heatmap", + options={ + "min": -4, "max": 4, + "colorScale": "diverging", + "showValue": True, + "decimals": 2 + } + ), + }), + (re.compile(r"^P[_-]?[Vv]alue$|^pvalue$|^p_value$"), { + "data_type": DataType.FLOAT, + "categories": ["statistics"], + "width": "100px", + "transform": TransformConfig( + type="number", + options={"notation": "scientific", "decimals": 2} + ), + }), + (re.compile(r"^FDR$|^fdr$|^q[_-]?value$"), { + "data_type": DataType.FLOAT, + "categories": ["statistics"], + "width": "100px", + "transform": TransformConfig( + type="number", + options={"notation": "scientific", "decimals": 2} + ), + }), + + # Boolean indicators + (re.compile(r"^Significant$|^is_.*|^has_.*"), { + "data_type": DataType.BOOLEAN, + "categories": ["statistics"], + "width": "90px", + "transform": TransformConfig( + type="boolean", + options={ + "trueIcon": "bi-check-circle-fill", + "falseIcon": "bi-x-circle", + "trueColor": "#22c55e", + "falseColor": "#94a3b8" + } + ), + }), + + # Temperature with unit + (re.compile(r".*Temperature.*|.*_in_C$"), { + "data_type": DataType.FLOAT, + "categories": ["experimental"], + "width": "120px", + "transform": TransformConfig( + type="number", + options={"decimals": 1, "suffix": "°C"} + ), + }), + + # Concentration fields + (re.compile(r".*Concentration.*|.*_in_mM$|.*_in_mg.*"), { + "data_type": DataType.FLOAT, + "categories": ["media"], + "width": "120px", + "transform": TransformConfig( + type="number", + options={"decimals": 2} + ), + }), + + # Name fields + (re.compile(r"^Name$|^name$|.*_Name$|.*_name$"), { + "data_type": 
DataType.STRING, + "categories": ["core"], + "width": "200px", + }), + + # URL fields + (re.compile(r".*_URL$|.*_url$|.*Link$|.*link$"), { + "data_type": DataType.URL, + "categories": ["external"], + "width": "150px", + }), +] + +# Value patterns for detecting types from sample data +VALUE_PATTERNS: list[tuple[re.Pattern, DataType]] = [ + # URLs + (re.compile(r"^https?://"), DataType.URL), + # Email + (re.compile(r"^[\w.+-]+@[\w-]+\.[\w.-]+$"), DataType.EMAIL), + # GO terms + (re.compile(r"^GO:\d{7}"), DataType.ONTOLOGY), + # ISO dates + (re.compile(r"^\d{4}-\d{2}-\d{2}$"), DataType.DATE), + (re.compile(r"^\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}"), DataType.DATETIME), + # Colors + (re.compile(r"^#[0-9a-fA-F]{6}$|^rgb\("), DataType.COLOR), + # DNA/RNA sequences (long strings of ATCGU only) + (re.compile(r"^[ATCGU]{20,}$", re.IGNORECASE), DataType.SEQUENCE), + # Protein sequences (amino acid codes) + (re.compile(r"^[ACDEFGHIKLMNPQRSTVWY]{20,}$", re.IGNORECASE), DataType.SEQUENCE), +] + + +# ============================================================================= +# TYPE INFERENCE ENGINE +# ============================================================================= + +class TypeInferenceEngine: + """ + Rule-based type inference engine. + + Analyzes column names and sample values to infer data types, + display configurations, and rendering transforms without AI. + """ + + def __init__(self) -> None: + self._name_patterns = NAME_PATTERNS + self._value_patterns = VALUE_PATTERNS + + def infer_from_name(self, column_name: str) -> InferredType | None: + """ + Infer column type from column name patterns. + + Args: + column_name: The name of the column + + Returns: + InferredType if a pattern matches, None otherwise + """ + for pattern, config in self._name_patterns: + if pattern.match(column_name): + return InferredType( + data_type=config.get("data_type", DataType.STRING), + display_name=self._format_display_name(column_name), + categories=config.get("categories", []), + transform=config.get("transform"), + width=config.get("width", "auto"), + pin=config.get("pin"), + sortable=config.get("sortable", True), + filterable=config.get("filterable", True), + copyable=config.get("copyable", False), + confidence=0.9, # High confidence for name pattern match + source="rules", + ) + return None + + def infer_from_values( + self, + column_name: str, + sample_values: list[Any], + sqlite_type: str = "TEXT" + ) -> InferredType: + """ + Infer column type from sample values. 
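+ + Example (illustrative; assumes the column name matches no name pattern): + + engine = TypeInferenceEngine() + result = engine.infer_from_values("flag", ["yes", "no", "yes"], sqlite_type="TEXT") + # result.data_type is DataType.BOOLEAN (detected from the sample values)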
+ + Args: + column_name: The name of the column + sample_values: List of sample values from the column + sqlite_type: The SQLite column type + + Returns: + InferredType with inferred configuration + """ + # First, try name-based inference + name_inference = self.infer_from_name(column_name) + if name_inference: + return name_inference + + # Filter out None/empty values for analysis + valid_values = [v for v in sample_values if v is not None and str(v).strip()] + + if not valid_values: + return self._default_inference(column_name, sqlite_type) + + # Check for boolean values + if self._is_boolean(valid_values): + return InferredType( + data_type=DataType.BOOLEAN, + display_name=self._format_display_name(column_name), + categories=["metadata"], + confidence=0.95, + ) + + # Check for numeric types based on SQLite type and values + if sqlite_type in ("INTEGER", "REAL") or self._is_numeric(valid_values): + return self._infer_numeric(column_name, valid_values, sqlite_type) + + # Check value patterns + str_values = [str(v) for v in valid_values] + for pattern, data_type in self._value_patterns: + matches = sum(1 for v in str_values if pattern.match(v)) + if matches / len(str_values) > 0.5: # >50% match threshold + return InferredType( + data_type=data_type, + display_name=self._format_display_name(column_name), + categories=self._default_category(data_type), + confidence=0.8, + ) + + # Default to string + return self._default_inference(column_name, sqlite_type) + + def infer( + self, + column_name: str, + sample_values: list[Any] | None = None, + sqlite_type: str = "TEXT" + ) -> InferredType: + """ + Full inference combining name and value analysis. + + Args: + column_name: The name of the column + sample_values: Optional list of sample values + sqlite_type: The SQLite column type + + Returns: + InferredType with best inference + """ + if sample_values: + return self.infer_from_values(column_name, sample_values, sqlite_type) + + name_inference = self.infer_from_name(column_name) + if name_inference: + return name_inference + + return self._default_inference(column_name, sqlite_type) + + # ─── Helper Methods ───────────────────────────────────────────────────── + + def _format_display_name(self, column_name: str) -> str: + """Convert column name to human-readable display name.""" + # Replace underscores and handle camelCase + name = re.sub(r"_", " ", column_name) + name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name) + # Title case but preserve acronyms + words = name.split() + formatted = [] + for word in words: + if word.isupper() and len(word) <= 4: # Likely acronym + formatted.append(word) + else: + formatted.append(word.capitalize()) + return " ".join(formatted) + + def _is_boolean(self, values: list[Any]) -> bool: + """Check if values represent boolean data.""" + bool_values = {"true", "false", "yes", "no", "1", "0", "t", "f", "y", "n"} + str_values = {str(v).lower() for v in values} + return str_values.issubset(bool_values) and len(str_values) <= 2 + + def _is_numeric(self, values: list[Any]) -> bool: + """Check if all values are numeric.""" + for v in values: + if v is None: + continue + try: + float(v) + except (ValueError, TypeError): + return False + return True + + def _infer_numeric( + self, + column_name: str, + values: list[Any], + sqlite_type: str + ) -> InferredType: + """Infer numeric type details.""" + # Check if all values are integers + is_integer = all( + isinstance(v, int) or (isinstance(v, float) and v.is_integer()) + for v in values if v is not None + ) + + data_type = 
DataType.INTEGER if (sqlite_type == "INTEGER" or is_integer) else DataType.FLOAT + + return InferredType( + data_type=data_type, + display_name=self._format_display_name(column_name), + categories=["data"], + width="100px", + transform=TransformConfig( + type="number", + options={"decimals": 0 if is_integer else 2} + ) if data_type == DataType.FLOAT else None, + confidence=0.85, + ) + + def _default_inference(self, column_name: str, sqlite_type: str) -> InferredType: + """Return default string inference.""" + # Map SQLite types to data types + type_map = { + "INTEGER": DataType.INTEGER, + "REAL": DataType.FLOAT, + "BLOB": DataType.CUSTOM, + } + + return InferredType( + data_type=type_map.get(sqlite_type, DataType.STRING), + display_name=self._format_display_name(column_name), + categories=["data"], + confidence=0.5, + ) + + def _default_category(self, data_type: DataType) -> list[str]: + """Get default categories for a data type.""" + category_map = { + DataType.ID: ["core"], + DataType.URL: ["external"], + DataType.EMAIL: ["external"], + DataType.ONTOLOGY: ["functional"], + DataType.SEQUENCE: ["sequence"], + DataType.DATE: ["metadata"], + DataType.DATETIME: ["metadata"], + } + return category_map.get(data_type, ["data"]) diff --git a/app/services/data/validation.py b/app/services/data/validation.py new file mode 100644 index 0000000..b8e6b59 --- /dev/null +++ b/app/services/data/validation.py @@ -0,0 +1,395 @@ +""" +Configuration Validation Module. + +Provides JSON schema validation for generated DataTables Viewer configurations +to ensure compatibility with the frontend viewer. +""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +try: + from jsonschema import validate, Draft7Validator + HAS_JSONSCHEMA = True +except ImportError: + HAS_JSONSCHEMA = False + # Dummy objects if needed + validate = None + Draft7Validator = None + + +# ============================================================================= +# JSON SCHEMAS +# ============================================================================= + +# Schema for individual column configuration +COLUMN_SCHEMA = { + "type": "object", + "required": ["column", "displayName"], + "properties": { + "column": {"type": "string", "minLength": 1}, + "displayName": {"type": "string", "minLength": 1}, + "dataType": { + "type": "string", + "enum": [ + "string", "number", "integer", "float", "boolean", + "date", "datetime", "timestamp", "duration", + "id", "url", "email", "phone", + "percentage", "currency", "filesize", + "sequence", "ontology", "json", "array" + ] + }, + "visible": {"type": "boolean"}, + "sortable": {"type": "boolean"}, + "filterable": {"type": "boolean"}, + "searchable": {"type": "boolean"}, + "copyable": {"type": "boolean"}, + "width": {"type": "string"}, + "align": {"type": "string", "enum": ["left", "center", "right"]}, + "pin": {"type": ["string", "null"], "enum": ["left", "right", None]}, + "categories": { + "type": "array", + "items": {"type": "string"} + }, + "transform": { + "type": ["object", "null"], + "properties": { + "type": {"type": "string"}, + "options": {"type": "object"} + } + } + }, + "additionalProperties": True # Allow future extensions +} + +# Schema for table configuration +TABLE_SCHEMA = { + "type": "object", + "required": ["displayName", "columns"], + "properties": { + "displayName": {"type": "string", "minLength": 1}, + "description": {"type": "string"}, + "icon": {"type": "string"}, + "settings": {"type": "object"}, + 
"categories": { + "type": "array", + "items": {"type": "object"} + }, + "columns": { + "type": "array", + "items": COLUMN_SCHEMA, + "minItems": 1 + } + } +} + +# Schema for complete DataTypeConfig +DATATYPE_CONFIG_SCHEMA = { + "type": "object", + "required": ["id", "name", "tables"], + "properties": { + "id": {"type": "string", "minLength": 1}, + "name": {"type": "string", "minLength": 1}, + "description": {"type": "string"}, + "version": {"type": "string", "pattern": r"^\d+\.\d+\.\d+$"}, + "icon": {"type": "string"}, + "color": {"type": "string"}, + "objectType": {"type": "string"}, + "defaults": { + "type": "object", + "properties": { + "pageSize": {"type": "integer", "minimum": 1, "maximum": 1000}, + "density": {"type": "string", "enum": ["compact", "default", "comfortable"]}, + "showRowNumbers": {"type": "boolean"}, + "enableSelection": {"type": "boolean"}, + "enableExport": {"type": "boolean"} + } + }, + "sharedCategories": { + "type": "array", + "items": { + "type": "object", + "required": ["id", "name"], + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"}, + "icon": {"type": "string"}, + "color": {"type": "string"}, + "defaultVisible": {"type": "boolean"}, + "order": {"type": "integer"} + } + } + }, + "tables": { + "type": "object", + "additionalProperties": TABLE_SCHEMA, + "minProperties": 1 + } + } +} + +# Schema for AI-generated column response (single table) +AI_RESPONSE_SCHEMA = { + "type": "object", + "required": ["columns"], + "properties": { + "columns": { + "type": "array", + "items": COLUMN_SCHEMA, + "minItems": 1 + } + } +} + + +# ============================================================================= +# VALIDATION FUNCTIONS +# ============================================================================= + +def validate_config(config: dict[str, Any]) -> tuple[bool, str | None]: + """ + Validate a complete DataTypeConfig against the schema. + + Args: + config: The configuration dictionary to validate + + Returns: + Tuple of (is_valid, error_message) + """ + try: + if not HAS_JSONSCHEMA: + raise ImportError("jsonschema not available") + + validator = Draft7Validator(DATATYPE_CONFIG_SCHEMA) + errors = list(validator.iter_errors(config)) + + if not errors: + return True, None + + # Format first error + first_error = errors[0] + path = ".".join(str(p) for p in first_error.absolute_path) or "root" + return False, f"Validation error at '{path}': {first_error.message}" + + except ImportError: + # jsonschema not available, do basic validation + return _basic_validation(config) + except Exception as e: + logger.warning(f"Validation error: {e}") + return False, str(e) + + +def validate_table_config(table_config: dict[str, Any]) -> tuple[bool, str | None]: + """ + Validate a single table configuration. + + Args: + table_config: Table configuration dictionary + + Returns: + Tuple of (is_valid, error_message) + """ + try: + if not HAS_JSONSCHEMA: + raise ImportError("jsonschema not available") + + validate(instance=table_config, schema=TABLE_SCHEMA) + return True, None + + except ImportError: + return _basic_table_validation(table_config) + except Exception as e: + return False, str(e) + + +def validate_ai_response(response: dict[str, Any]) -> tuple[bool, str | None]: + """ + Validate AI-generated column response. 
+ + Args: + response: AI response dictionary + + Returns: + Tuple of (is_valid, error_message) + """ + try: + if not HAS_JSONSCHEMA: + raise ImportError("jsonschema not available") + + validate(instance=response, schema=AI_RESPONSE_SCHEMA) + return True, None + + except ImportError: + # Basic validation + if not isinstance(response, dict): + return False, "Response must be a dictionary" + if "columns" not in response: + return False, "Response must have 'columns' key" + if not isinstance(response["columns"], list): + return False, "'columns' must be an array" + if len(response["columns"]) == 0: + return False, "'columns' array must not be empty" + return True, None + + except Exception as e: + return False, str(e) + + +def validate_column_config(column: dict[str, Any]) -> tuple[bool, str | None]: + """ + Validate a single column configuration. + + Args: + column: Column configuration dictionary + + Returns: + Tuple of (is_valid, error_message) + """ + if not isinstance(column, dict): + return False, "Column must be a dictionary" + + if "column" not in column: + return False, "Column must have 'column' key" + + if "displayName" not in column: + return False, "Column must have 'displayName' key" + + # Validate transform structure if present + if "transform" in column and column["transform"] is not None: + transform = column["transform"] + if not isinstance(transform, dict): + return False, "Transform must be a dictionary" + if "type" not in transform: + return False, "Transform must have 'type' key" + + return True, None + + +# ============================================================================= +# BASIC VALIDATION (fallback when jsonschema unavailable) +# ============================================================================= + +def _basic_validation(config: dict[str, Any]) -> tuple[bool, str | None]: + """Basic validation without jsonschema library.""" + if not isinstance(config, dict): + return False, "Config must be a dictionary" + + # Check required fields + for field in ["id", "name", "tables"]: + if field not in config: + return False, f"Missing required field: {field}" + + if not isinstance(config["tables"], dict): + return False, "'tables' must be a dictionary" + + if len(config["tables"]) == 0: + return False, "'tables' must not be empty" + + # Validate each table + for table_name, table_config in config["tables"].items(): + is_valid, error = _basic_table_validation(table_config) + if not is_valid: + return False, f"Table '{table_name}': {error}" + + return True, None + + +def _basic_table_validation(table_config: dict[str, Any]) -> tuple[bool, str | None]: + """Basic table validation without jsonschema library.""" + if not isinstance(table_config, dict): + return False, "Table config must be a dictionary" + + if "displayName" not in table_config: + return False, "Missing 'displayName'" + + if "columns" not in table_config: + return False, "Missing 'columns'" + + if not isinstance(table_config["columns"], list): + return False, "'columns' must be an array" + + if len(table_config["columns"]) == 0: + return False, "'columns' must not be empty" + + # Validate each column + for i, column in enumerate(table_config["columns"]): + is_valid, error = validate_column_config(column) + if not is_valid: + return False, f"Column {i}: {error}" + + return True, None + + +# ============================================================================= +# SANITIZATION +# ============================================================================= + +def sanitize_config(config: 
dict[str, Any]) -> dict[str, Any]: + """ + Sanitize and normalize a config, fixing common issues. + + Args: + config: Raw configuration dictionary + + Returns: + Sanitized configuration + """ + sanitized = dict(config) + + # Ensure version format + if "version" not in sanitized or not sanitized["version"]: + sanitized["version"] = "1.0.0" + + # Normalize tables + if "tables" in sanitized: + for table_name, table_config in sanitized["tables"].items(): + sanitized["tables"][table_name] = _sanitize_table(table_config) + + return sanitized + + +def _sanitize_table(table_config: dict[str, Any]) -> dict[str, Any]: + """Sanitize a table configuration.""" + sanitized = dict(table_config) + + # Ensure columns exist + if "columns" not in sanitized: + sanitized["columns"] = [] + + # Sanitize each column + sanitized["columns"] = [ + _sanitize_column(col) for col in sanitized["columns"] + ] + + return sanitized + + +def _sanitize_column(column: dict[str, Any]) -> dict[str, Any]: + """Sanitize a column configuration.""" + sanitized = dict(column) + + # Default display name to column name + if "displayName" not in sanitized and "column" in sanitized: + col_name = sanitized["column"] + # Convert snake_case to Title Case + sanitized["displayName"] = col_name.replace("_", " ").title() + + # Default data type + if "dataType" not in sanitized: + sanitized["dataType"] = "string" + + # Ensure categories is a list + if "categories" not in sanitized: + sanitized["categories"] = [] + elif not isinstance(sanitized["categories"], list): + sanitized["categories"] = [sanitized["categories"]] + + # Normalize null transform + if "transform" in sanitized and sanitized["transform"] is None: + del sanitized["transform"] + + return sanitized diff --git a/app/services/db_helper.py b/app/services/db_helper.py new file mode 100644 index 0000000..7457ef2 --- /dev/null +++ b/app/services/db_helper.py @@ -0,0 +1,144 @@ +""" +Database helper service to consolidate retrieval and validation logic. +Reduces code duplication in API routes. +""" +import logging +from pathlib import Path +from uuid import uuid4 + +from fastapi import HTTPException + +from app.config import settings +from app.utils.workspace import KBaseClient, download_pangenome_db +from app.utils.sqlite import validate_table_exists, list_tables +from app.utils.async_utils import run_sync_in_thread +from app.utils.cache import sanitize_id + +logger = logging.getLogger(__name__) + +async def get_handle_db_path( + handle_ref: str, + token: str, + kb_env: str, + cache_dir: Path +) -> Path: + """ + Get (and download if needed) a SQLite database from a handle reference. 
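+ + Example (illustrative; the handle ref, token, and cache directory are placeholders): + + db_path = await get_handle_db_path("KBH_12345", token, "appdev", Path("/tmp/tablescanner_cache"))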
+ + Args: + handle_ref: Handle reference string + token: KBase auth token + kb_env: KBase environment + cache_dir: Cache directory path + + Returns: + Path to the local SQLite database file + """ + def _download_handle_db(): + # Cache path based on handle + safe_handle = handle_ref.replace(":", "_").replace("/", "_") + db_dir = cache_dir / "handles" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / f"{safe_handle}.db" + + # Atomic download if missing + if not db_path.exists(): + client = KBaseClient(token, kb_env, cache_dir) + temp_path = db_path.with_suffix(f".{uuid4().hex}.tmp") + try: + client.download_blob_file(handle_ref, temp_path) + temp_path.rename(db_path) + except Exception: + temp_path.unlink(missing_ok=True) + raise + return db_path + + try: + return await run_sync_in_thread(_download_handle_db) + except Exception as e: + logger.error(f"Error accessing handle database {handle_ref}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to access database: {str(e)}") + + +async def get_object_db_path( + berdl_table_id: str, + token: str, + kb_env: str, + cache_dir: Path +) -> Path: + """ + Get (and download if needed) a SQLite database from a BERDL object. + + Args: + berdl_table_id: KBase workspace reference OR 'local:{uuid}' for uploaded files + token: KBase auth token + kb_env: KBase environment + cache_dir: Cache directory path + + Returns: + Path to the local SQLite database file + """ + # Handle local uploads + if berdl_table_id.startswith("local:"): + # Expect format local:UUID + handle_parts = berdl_table_id.split(":", 1) + if len(handle_parts) != 2: + raise HTTPException(status_code=400, detail="Invalid local database handle format") + + filename = sanitize_id(handle_parts[1]) + # sanitize_id allows only [a-zA-Z0-9_.-]; if sanitization changed the value, + # the handle contained disallowed characters, so reject it outright. + if filename != handle_parts[1]: + raise HTTPException(status_code=400, detail="Invalid characters in local database handle") + + db_path = cache_dir / "uploads" / f"{filename}.db" + + if not db_path.exists(): + raise HTTPException(status_code=404, detail=f"Local database not found: {berdl_table_id}") + + return db_path + + try: + # download_pangenome_db already handles caching logic + return await run_sync_in_thread( + download_pangenome_db, + berdl_table_id, + token, + cache_dir, + kb_env + ) + except TimeoutError: + logger.error(f"Database download timed out for {berdl_table_id}") + raise HTTPException( + status_code=504, + detail="Database download timed out. Please try again later." + ) + except Exception as e: + logger.error(f"Error accessing object database {berdl_table_id}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to access database: {str(e)}") + + +async def ensure_table_accessible(db_path: Path, table_name: str) -> bool: + """ + Validate that a table exists in the database. + Raises HTTPException 404 if not found. + + Args: + db_path: Path to SQLite database + table_name: Name of table to check + + Returns: + True if exists + """ + exists = await run_sync_in_thread(validate_table_exists, db_path, table_name) + + if not exists: + available = await run_sync_in_thread(list_tables, db_path) + raise HTTPException( + status_code=404, + detail=f"Table '{table_name}' not found.
Available: {available}" + ) + return True diff --git a/app/utils/__init__.py b/app/utils/__init__.py index bbf18c3..850c42e 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -10,7 +10,6 @@ from app.utils.workspace import ( get_berdl_table_data, - list_pangenomes_from_object, download_pangenome_db, get_object_info, @@ -26,20 +25,15 @@ cleanup_old_caches, ) from app.utils.sqlite import ( - convert_to_sqlite, - query_sqlite, - get_table_data, list_tables, get_table_columns, get_table_row_count, validate_table_exists, - ensure_indices, ) __all__ = [ # Workspace utilities "get_berdl_table_data", - "list_pangenomes_from_object", "download_pangenome_db", "get_object_info", @@ -55,12 +49,8 @@ "cleanup_old_caches", # SQLite utilities - "convert_to_sqlite", - "query_sqlite", - "get_table_data", "list_tables", "get_table_columns", "get_table_row_count", "validate_table_exists", - "ensure_indices", ] diff --git a/app/utils/async_utils.py b/app/utils/async_utils.py new file mode 100644 index 0000000..0cd0d03 --- /dev/null +++ b/app/utils/async_utils.py @@ -0,0 +1,27 @@ +""" +Async utilities for standardized execution. +""" +import asyncio +from typing import TypeVar, Any, Callable + +T = TypeVar("T") + +async def run_sync_in_thread(func: Callable[..., T], *args: Any) -> T: + """ + Run a synchronous function in a separate thread. + + Handles compatibility between Python 3.9+ (asyncio.to_thread) + and older versions (loop.run_in_executor). + + Args: + func: The synchronous function to run + *args: Arguments to pass to the function + + Returns: + The result of the function call + """ + if hasattr(asyncio, 'to_thread'): + return await asyncio.to_thread(func, *args) + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, func, *args) diff --git a/app/utils/cache.py b/app/utils/cache.py index 04d6bd8..b36e097 100644 --- a/app/utils/cache.py +++ b/app/utils/cache.py @@ -19,13 +19,35 @@ def sanitize_id(id_string: str) -> str: """ Sanitize an ID string for use as a filesystem path. + Uses a strict allow-list approach to prevent path traversal. + Only allows alphanumeric characters, underscores, hyphens, and dots. + Args: - id_string: Raw ID (may contain / : and other special chars) + id_string: Raw ID Returns: Safe string for filesystem use """ - return id_string.replace("/", "_").replace(":", "_").replace(" ", "_") + import re + # First replace common separators with underscore to maintain readability + # (e.g. "123/4" -> "123_4") + safe = id_string.replace("/", "_").replace("\\", "_").replace(":", "_").replace(" ", "_") + + # Remove any characters that aren't allowed (strict allow-list) + # Allowed: a-z, A-Z, 0-9, -, _, . + safe = re.sub(r"[^a-zA-Z0-9_.-]", "", safe) + + # Prevent empty strings + if not safe: + # Fallback for completely invalid IDs + import hashlib + return hashlib.md5(id_string.encode()).hexdigest() + + # Prevent specific directory traversal names if they somehow remain + if safe in (".", ".."): + safe = safe + "_safe" + + return safe def get_upa_cache_path( @@ -151,6 +173,25 @@ def get_cache_info(cache_path: Path) -> dict[str, Any] | None: + +def save_cache_metadata(cache_subdir: Path, metadata: dict[str, Any]) -> None: + """ + Save metadata to cache directory. 
+ + Args: + cache_subdir: Cache subdirectory + metadata: Metadata dictionary to save + """ + metadata_path = get_metadata_path(cache_subdir) + try: + ensure_cache_dir(metadata_path) + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + logger.debug(f"Saved metadata to {metadata_path}") + except Exception as e: + logger.warning(f"Failed to save metadata to {metadata_path}: {e}") + + def load_cache_metadata(cache_subdir: Path) -> dict[str, Any] | None: """ Load cache metadata. @@ -221,11 +262,11 @@ def clear_cache(cache_dir: Path, berdl_table_id: str | None = None) -> dict[str, def cleanup_old_caches(cache_dir: Path, max_age_days: int = 7) -> dict[str, Any]: """ - Remove cache directories older than max_age_days. + Remove cache directories older than max_age_days, AND uploads older than 1 hour. Args: cache_dir: Base cache directory - max_age_days: Maximum age in days + max_age_days: Maximum age in days for standard caches Returns: Summary of cleanup operation @@ -237,8 +278,9 @@ def cleanup_old_caches(cache_dir: Path, max_age_days: int = 7) -> dict[str, Any] max_age_seconds = max_age_days * 24 * 3600 removed = [] + # Clean standard ID-based subdirectories for subdir in cache_dir.iterdir(): - if not subdir.is_dir(): + if not subdir.is_dir() or subdir.name == "uploads": continue try: @@ -249,10 +291,26 @@ def cleanup_old_caches(cache_dir: Path, max_age_days: int = 7) -> dict[str, Any] logger.info(f"Removed old cache: {subdir.name}") except Exception as e: logger.warning(f"Failed to clean {subdir}: {e}") + + # Clean uploads directory (aggressive 1 hour expiry for temp availability) + uploads_dir = cache_dir / "uploads" + uploads_removed = 0 + if uploads_dir.exists(): + upload_max_age = 3600 # 1 hour + for f in uploads_dir.glob("*.db"): + try: + mtime = f.stat().st_mtime + if now - mtime > upload_max_age: + f.unlink() + uploads_removed += 1 + logger.debug(f"Removed expired upload: {f.name}") + except Exception as e: + logger.warning(f"Failed to clean upload {f}: {e}") return { "status": "success", "removed": len(removed), + "uploads_removed": uploads_removed, "items": removed } diff --git a/app/utils/request_utils.py b/app/utils/request_utils.py new file mode 100644 index 0000000..628558c --- /dev/null +++ b/app/utils/request_utils.py @@ -0,0 +1,165 @@ +""" +Request processing utilities for TableScanner routes. +""" + +from __future__ import annotations + +import time +import logging +from typing import Any +from pathlib import Path + +from fastapi import HTTPException +from app.services.data.query_service import get_query_service, FilterSpec +from app.utils.async_utils import run_sync_in_thread +from app.exceptions import TableNotFoundError, InvalidFilterError +from app.config_constants import MAX_LIMIT + +logger = logging.getLogger(__name__) + +class TableRequestProcessor: + """ + Handles common logic for table data requests: + - Parameter extraction + - Database access (via helper/callback) + - Query execution via QueryService + - Response formatting + """ + + @staticmethod + async def process_data_request( + db_path: Path, + table_name: str, + limit: int, + offset: int, + sort_column: str | None = None, + sort_order: str = "ASC", + search_value: str | None = None, + columns: list[str] | str | None = None, + filters: dict[str, Any] | list[Any] | None = None, + aggregations: list[Any] | None = None, + group_by: list[str] | None = None, + handle_ref_or_id: str | None = None + ) -> dict[str, Any]: + """ + Process a generic table data request. 
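+ + Example (illustrative; db_path points at an already-cached SQLite file and the table and column names are placeholders): + + result = await TableRequestProcessor.process_data_request( + db_path=db_path, + table_name="genes", + limit=100, + offset=0, + filters=[{"column": "contig_id", "operator": "eq", "value": "contig_1"}], + )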
+ """ + # Defensive check for limit + if limit > MAX_LIMIT: + limit = MAX_LIMIT + + start_time = time.time() + + # Prepare filters + service_filters = [] + if filters: + if isinstance(filters, dict): + # Legacy dict filters + for col, val in filters.items(): + service_filters.append(FilterSpec(column=col, operator="like", value=val)) + elif isinstance(filters, list): + # Advanced filters (list of FilterRequest or dicts) + for f in filters: + if hasattr(f, "column"): # Pydantic model + service_filters.append(FilterSpec( + column=f.column, + operator=f.operator, + value=f.value, + value2=f.value2 + )) + elif isinstance(f, dict): # Dict + service_filters.append(FilterSpec( + column=f.get("column"), + operator=f.get("operator"), + value=f.get("value"), + value2=f.get("value2") + )) + + # Prepare aggregations + service_aggregations = [] + if aggregations: + from app.services.data.query_service import AggregationSpec + for agg in aggregations: + if hasattr(agg, "column"): + service_aggregations.append(AggregationSpec( + column=agg.column, + function=agg.function, + alias=agg.alias + )) + elif isinstance(agg, dict): + service_aggregations.append(AggregationSpec( + column=agg.get("column"), + function=agg.get("function"), + alias=agg.get("alias") + )) + + # Determine sort direction + direction = "ASC" + if sort_order and sort_order.lower() == "desc": + direction = "DESC" + + # Handle columns (string vs list) compatibility + columns_list = None + if columns: + if isinstance(columns, str): + if columns.lower() != "all": + columns_list = [c.strip() for c in columns.split(",") if c.strip()] + elif isinstance(columns, list): + columns_list = columns + + def _execute(): + query_service = get_query_service() + return query_service.execute_query( + db_path=db_path, + table_name=table_name, + limit=limit, + offset=offset, + columns=columns_list, + sort_column=sort_column, + sort_order=direction, + search_value=search_value, + filters=service_filters, + aggregations=service_aggregations, + group_by=group_by, + use_cache=True + ) + + try: + result = await run_sync_in_thread(_execute) + except (TableNotFoundError, InvalidFilterError): + # Allow specific exceptions to bubble up to global handlers + raise + except ValueError as e: + # Handle validation errors (e.g. 
invalid numeric conversion) from QueryService + raise HTTPException(status_code=422, detail=str(e)) + except Exception as e: + logger.error(f"Query execution failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + response_time_ms = (time.time() - start_time) * 1000 + + # Format response + return { + "berdl_table_id": handle_ref_or_id, + "handle_ref": handle_ref_or_id, + "table_name": table_name, + "headers": result["headers"], + "data": result["data"], + "row_count": len(result["data"]), + "total_count": result["total_count"], + "filtered_count": result["total_count"], # Matches logic in routes.py + "response_time_ms": response_time_ms, + "db_query_ms": result["execution_time_ms"], + "conversion_ms": 0.0, # Deprecated metric + "sqlite_file": str(db_path), + + # System Overhaul / Advanced Metadata + "column_types": result.get("column_types"), + "column_schema": result.get("column_types"), # Alias + "query_metadata": result.get("query_metadata"), + "cached": result.get("cached", False), + "execution_time_ms": result.get("execution_time_ms"), + "limit": limit, + "offset": offset, + "database_path": str(db_path) + } diff --git a/app/utils/sqlite.py b/app/utils/sqlite.py index f304265..041e6b2 100644 --- a/app/utils/sqlite.py +++ b/app/utils/sqlite.py @@ -1,9 +1,11 @@ +""" +Low-level SQLite utilities. +""" from __future__ import annotations import sqlite3 import logging import time from pathlib import Path -from typing import Any # Configure module logger logger = logging.getLogger(__name__) @@ -12,37 +14,20 @@ def _validate_table_name(cursor, table_name: str) -> None: """ Validate that table_name corresponds to an existing table in the database. - Prevents SQL injection by ensuring table_name is a valid identifier. """ - # Parameterized query is safe from injection cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) if not cursor.fetchone(): - # Check for case-insensitive match or just fail raise ValueError(f"Invalid table name: {table_name}") -# ============================================================================= -# TABLE LISTING & METADATA -# ============================================================================= - def list_tables(db_path: Path) -> list[str]: """ List all user tables in a SQLite database. - - Args: - db_path: Path to the SQLite database file - - Returns: - List of table names (excludes sqlite_ system tables) - - Raises: - sqlite3.Error: If database access fails """ try: conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() - # Query for user tables (exclude sqlite_ system tables) cursor.execute(""" SELECT name FROM sqlite_master WHERE type='table' @@ -52,8 +37,6 @@ def list_tables(db_path: Path) -> list[str]: tables = [row[0] for row in cursor.fetchall()] conn.close() - - logger.info(f"Found {len(tables)} tables in database: {tables}") return tables except sqlite3.Error as e: @@ -64,23 +47,20 @@ def list_tables(db_path: Path) -> list[str]: def get_table_columns(db_path: Path, table_name: str) -> list[str]: """ Get column names for a specific table. 
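+
+ The table name is first looked up in sqlite_master and the stored name is
+ double-quoted before being interpolated, so arbitrary identifiers never reach
+ the SQL text.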
- - Args: - db_path: Path to the SQLite database file - table_name: Name of the table to query - - Returns: - List of column names """ try: conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() - # Validate table name to prevent injection - _validate_table_name(cursor, table_name) - - # Use PRAGMA to get table info - cursor.execute(f"PRAGMA table_info({table_name})") + # Validate table existence and get validated table name + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + result = cursor.fetchone() + if not result: + raise ValueError(f"Invalid table name: {table_name}") + + # Use validated table name from sqlite_master to prevent SQL injection + validated_table_name = result[0] + cursor.execute(f"PRAGMA table_info(\"{validated_table_name}\")") columns = [row[1] for row in cursor.fetchall()] conn.close() @@ -94,21 +74,20 @@ def get_table_columns(db_path: Path, table_name: str) -> list[str]: def get_table_row_count(db_path: Path, table_name: str) -> int: """ Get the total row count for a table. - - Args: - db_path: Path to the SQLite database file - table_name: Name of the table - - Returns: - Number of rows in the table """ try: conn = sqlite3.connect(str(db_path)) cursor = conn.cursor() - _validate_table_name(cursor, table_name) - - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + # Validate table existence and get validated table name + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + result = cursor.fetchone() + if not result: + raise ValueError(f"Invalid table name: {table_name}") + + # Use validated table name from sqlite_master to prevent SQL injection + validated_table_name = result[0] + cursor.execute(f"SELECT COUNT(*) FROM \"{validated_table_name}\"") count = cursor.fetchone()[0] conn.close() @@ -119,306 +98,132 @@ def get_table_row_count(db_path: Path, table_name: str) -> int: raise + def validate_table_exists(db_path: Path, table_name: str) -> bool: """ Check if a table exists in the database. - - Args: - db_path: Path to the SQLite database file - table_name: Name of the table to check - - Returns: - True if table exists, False otherwise - """ - tables = list_tables(db_path) - return table_name in tables - - -# ============================================================================= -# INDEX OPTIMIZATION -# ============================================================================= - -def ensure_indices(db_path: Path, table_name: str) -> None: - """ - Ensure indices exist for all columns in the table to optimize filtering. - - This is an optimization step - failures are logged but not raised. 
- - Args: - db_path: Path to the SQLite database file - table_name: Name of the table """ try: - conn = sqlite3.connect(str(db_path)) - cursor = conn.cursor() - - _validate_table_name(cursor, table_name) - - # Get columns - cursor.execute(f"PRAGMA table_info({table_name})") - columns = [row[1] for row in cursor.fetchall()] - - # Create index for each column - for col in columns: - index_name = f"idx_{table_name}_{col}" - # Sanitize column name for SQL safety - safe_col = col.replace('"', '""') - cursor.execute( - f'CREATE INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" ("{safe_col}")' - ) - - conn.commit() - conn.close() - logger.info(f"Ensured indices for table {table_name}") - - except sqlite3.Error as e: - # Don't raise, just log warning as this is an optimization step - logger.warning(f"Error creating indices for {table_name}: {e}") + tables = list_tables(db_path) + return table_name in tables + except Exception: + return False -# ============================================================================= -# DATA RETRIEVAL - SIMPLE QUERY -# ============================================================================= - -def query_sqlite(sqlite_file: Path, query_id: str) -> dict[str, Any]: - """ - Query SQLite database by ID. Legacy compatibility function. - - Args: - sqlite_file: Path to SQLite database - query_id: Query identifier - - Returns: - Query results as dictionary - """ - return { - "stub": "SQLite query results would go here", - "query_id": query_id, - "sqlite_file": str(sqlite_file) - } - - -# ============================================================================= -# DATA RETRIEVAL - FULL FEATURED -# ============================================================================= - -def get_table_data( - sqlite_file: Path, - table_name: str, - limit: int = 100, - offset: int = 0, - sort_column: str | None = None, - sort_order: str = "ASC", - search_value: str | None = None, - query_filters: dict[str, str] | None = None, - columns: str | None = "all", - order_by: list[dict[str, str]] | None = None -) -> tuple[list[str], list[Any], int, int, float, float]: +def get_table_statistics(db_path: Path, table_name: str) -> dict: """ - Get paginated and filtered data from a table. - - Supports two filtering APIs for flexibility: - 1. `filters`: List of FilterSpec-style dicts with column, op, value - 2. `query_filters`: Simple dict of column -> search_value (LIKE matching) - - Args: - sqlite_file: Path to SQLite database - table_name: Name of the table to query - limit: Maximum number of rows to return - offset: Number of rows to skip - sort_column: Single column to sort by (alternative to order_by) - sort_order: Sort direction 'asc' or 'desc' (with sort_column) - search_value: Global search term for all columns - query_filters: Dict of column-specific search terms - columns: Comma-separated list of columns to select - order_by: List of order specifications [{column, direction}] - - Returns: - Tuple of (headers, data, total_count, filtered_count, db_query_ms, conversion_ms) - - Raises: - sqlite3.Error: If database query fails - ValueError: If invalid operator is specified + Calculate statistics for all columns in a table. 
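+
+ Returns a dict shaped roughly like (illustrative, not exhaustive):
+ {"table": ..., "row_count": ..., "columns": [{"column", "type", "null_count",
+ "distinct_count", "min", "max", "mean", "sample_values"}, ...], "last_updated": ...}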
""" - start_time = time.time() - - # Initialize legacy filters to None since removed from signature - filters = None - try: - conn = sqlite3.connect(str(sqlite_file)) + conn = sqlite3.connect(str(db_path)) + # Use row factory to access columns by name if needed, though we use indices here conn.row_factory = sqlite3.Row cursor = conn.cursor() - # Validate table name - _validate_table_name(cursor, table_name) + # Validate table + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + if not cursor.fetchone(): + raise ValueError(f"Invalid table name: {table_name}") - # Get all column names first for validation - all_headers = get_table_columns(sqlite_file, table_name) + validated_table = table_name - if not all_headers: - logger.warning(f"Table {table_name} has no columns or doesn't exist") - return [], [], 0, 0, 0.0, 0.0 + # Get total row count + cursor.execute(f"SELECT COUNT(*) FROM \"{validated_table}\"") + row_count = cursor.fetchone()[0] - # Parse requested columns - selected_headers = all_headers - select_clause = "*" + # Get columns and types + cursor.execute(f"PRAGMA table_info(\"{validated_table}\")") + columns_info = cursor.fetchall() - if columns and columns.lower() != "all": - requested = [c.strip() for c in columns.split(',') if c.strip()] - valid = [c for c in requested if c in all_headers] - if valid: - selected_headers = valid - safe_cols = [f'"{c}"' for c in selected_headers] - select_clause = ", ".join(safe_cols) - - headers = selected_headers - - # 1. Get total count (before filtering) - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - total_count = cursor.fetchone()[0] - - # 2. Build WHERE clause - conditions = [] - params = [] - - # 2a. Global Search (OR logic across all columns) - if search_value: - search_conditions = [] - term = f"%{search_value}%" - for col in headers: - search_conditions.append(f'"{col}" LIKE ?') - params.append(term) - - if search_conditions: - conditions.append(f"({' OR '.join(search_conditions)})") - - # 2b. Column Filters via query_filters dict (AND logic) - if query_filters: - for col, val in query_filters.items(): - if col in headers and val: - conditions.append(f'"{col}" LIKE ?') - params.append(f"%{val}%") - - # 2c. Structured filters via filters list (AND logic) - if filters: - allowed_ops = ["=", "!=", "<", ">", "<=", ">=", "LIKE", "IN"] - for filter_spec in filters: - column = filter_spec.get("column") - op = filter_spec.get("op", "LIKE") - value = filter_spec.get("value") - - if not column or column not in headers: - continue - - if op not in allowed_ops: - raise ValueError(f"Invalid operator: {op}") - - conditions.append(f'"{column}" {op} ?') - params.append(value) - - where_clause = "" - if conditions: - where_clause = " WHERE " + " AND ".join(conditions) - - # 3. Get filtered count - if where_clause: - cursor.execute(f"SELECT COUNT(*) FROM {table_name} {where_clause}", params) - filtered_count = cursor.fetchone()[0] - else: - filtered_count = total_count - - # 4. 
Build final query - query = f"SELECT {select_clause} FROM {table_name}{where_clause}" - - # Add ORDER BY clause - order_clauses = [] - - # Handle order_by list - if order_by: - for order_spec in order_by: - col = order_spec.get("column") - direction = order_spec.get("direction", "ASC").upper() - - if col and col in headers: - if direction not in ["ASC", "DESC"]: - direction = "ASC" - order_clauses.append(f'"{col}" {direction}') - - # Handle single sort_column (alternative API) - if sort_column and sort_column in headers: - direction = "DESC" if sort_order and sort_order.lower() == "desc" else "ASC" - order_clauses.append(f'"{sort_column}" {direction}') - - if order_clauses: - query += " ORDER BY " + ", ".join(order_clauses) - elif headers: - # Default sort for consistent pagination - query += f' ORDER BY "{headers[0]}" ASC' - - # Add LIMIT clause - if limit is not None: - query += f" LIMIT {int(limit)}" - - # Add OFFSET clause - if offset is not None: - query += f" OFFSET {int(offset)}" - - # Execute query with timing - query_start = time.time() - cursor.execute(query, params) - rows = cursor.fetchall() - db_query_ms = (time.time() - query_start) * 1000 + stats_columns = [] + + for col in columns_info: + col_name = col['name'] + col_type = col['type'] + + # Base stats query + # We use SUM(CASE WHEN ... IS NULL) instead of COUNT(col) logic sometimes to be explicit + # but COUNT(col) counts non-nulls. So Nulls = Total - COUNT(col). + cursor.execute(f""" + SELECT + COUNT("{col_name}") as non_null_count, + COUNT(DISTINCT "{col_name}") as distinct_count + FROM "{validated_table}" + """) + basic_stats = cursor.fetchone() + non_null_count = basic_stats['non_null_count'] + null_count = row_count - non_null_count + distinct_count = basic_stats['distinct_count'] + + col_stats = { + "column": col_name, + "type": col_type, + "null_count": null_count, + "distinct_count": distinct_count, + "sample_values": [] + } + + # Extended stats for numeric types + # Heuristic: simplistic check for INT, REAL, FLO, DOUB, NUM + is_numeric = any(t in col_type.upper() for t in ['INT', 'REAL', 'FLO', 'DOUB', 'NUM', 'DEC']) + + if is_numeric and non_null_count > 0: + try: + cursor.execute(f""" + SELECT + MIN("{col_name}"), + MAX("{col_name}"), + AVG("{col_name}") + FROM "{validated_table}" + WHERE "{col_name}" IS NOT NULL + """) + num_stats = cursor.fetchone() + if num_stats[0] is not None: + col_stats["min"] = num_stats[0] + col_stats["max"] = num_stats[1] + col_stats["mean"] = num_stats[2] + except Exception: + # Ignore errors in numeric aggregate (e.g. 
if column declared int but has strings) + pass + elif non_null_count > 0: + # For non-numeric, just get Min/Max + try: + cursor.execute(f""" + SELECT MIN("{col_name}"), MAX("{col_name}") + FROM "{validated_table}" + WHERE "{col_name}" IS NOT NULL + """) + str_stats = cursor.fetchone() + if str_stats[0] is not None: + col_stats["min"] = str_stats[0] + col_stats["max"] = str_stats[1] + except Exception: + pass + + # Get sample values (first 5 non-null distinct preferred, or just first 5) + try: + cursor.execute(f""" + SELECT DISTINCT "{col_name}" + FROM "{validated_table}" + WHERE "{col_name}" IS NOT NULL + LIMIT 5 + """) + samples = [row[0] for row in cursor.fetchall()] + col_stats["sample_values"] = samples + except Exception: + col_stats["sample_values"] = [] + + stats_columns.append(col_stats) conn.close() - - # Convert rows to string arrays with timing - conversion_start = time.time() - data = [] - for row in rows: - string_row = [ - str(value) if value is not None else "" - for value in row - ] - data.append(string_row) - conversion_ms = (time.time() - conversion_start) * 1000 - - return headers, data, total_count, filtered_count, db_query_ms, conversion_ms + + return { + "table": table_name, + "row_count": row_count, + "columns": stats_columns, + "last_updated": int(time.time() * 1000) + } except sqlite3.Error as e: - logger.error(f"Error extracting data from {table_name}: {e}") + logger.error(f"Error calculating stats for {table_name}: {e}") raise - - -# ============================================================================= -# CONVERSION (PLACEHOLDER) -# ============================================================================= - -def convert_to_sqlite(binary_file: Path, sqlite_file: Path) -> None: - """ - Convert binary file to SQLite database. - - This function handles conversion of various binary formats - to SQLite for efficient querying. 
- - Args: - binary_file: Path to binary file - sqlite_file: Path to output SQLite file - - Raises: - NotImplementedError: Conversion logic depends on binary format - """ - # Check if file is already a SQLite database - if binary_file.suffix == '.db': - # Just copy/link the file - import shutil - shutil.copy2(binary_file, sqlite_file) - logger.info(f"Copied SQLite database to {sqlite_file}") - return - - # TODO: Implement conversion logic based on binary file format - # The BERDLTables object stores SQLite directly, so this may not be needed - raise NotImplementedError( - f"SQLite conversion not implemented for format: {binary_file.suffix}" - ) - diff --git a/app/utils/workspace.py b/app/utils/workspace.py index b5cf86b..5cc4783 100644 --- a/app/utils/workspace.py +++ b/app/utils/workspace.py @@ -12,6 +12,19 @@ if str(LIB_PATH) not in sys.path: sys.path.insert(0, str(LIB_PATH)) +# Try conditional imports at top level +try: + from kbutillib.kb_ws_utils import KBWSUtils + from kbutillib.notebook_utils import NotebookUtils + HAS_KBUTILLIB = True +except ImportError: + HAS_KBUTILLIB = False + # Define dummy classes if needed for type hinting or logic check + KBWSUtils = object + NotebookUtils = object + +from app.config import settings + # Configure module logger logger = logging.getLogger(__name__) @@ -54,8 +67,8 @@ def __init__( def _init_client(self): """Initialize the appropriate client.""" try: - from kbutillib.kb_ws_utils import KBWSUtils - from kbutillib.notebook_utils import NotebookUtils + if not HAS_KBUTILLIB: + raise ImportError("KBUtilLib not found") # Create a proper combined class cache_dir = self.cache_dir @@ -70,6 +83,9 @@ def __init__(self): kb_version=kb_env, token=token ) + # Ensure token is saved in the token hash + if hasattr(self, 'save_token') and token: + self.save_token(token, namespace="kbase", save_file=False) self._client = NotebookUtil() self._use_kbutillib = True @@ -116,11 +132,14 @@ def download_blob_file(self, handle_ref: str, target_path: Path) -> Path: if self._use_kbutillib and self._client: try: + # Ensure KBUtilLib has the token set + if hasattr(self._client, 'save_token'): + self._client.save_token(self.token, namespace="kbase") result = self._client.download_blob_file(handle_ref, str(target_path)) if result: return Path(result) except Exception as e: - logger.warning(f"KBUtilLib download_blob_file failed: {e}. Using fallback.") + logger.warning(f"KBUtilLib download_blob_file failed: {e}. 
Using fallback.", exc_info=True) return Path(self._download_blob_fallback(handle_ref, str(target_path))) @@ -130,6 +149,15 @@ def download_blob_file(self, handle_ref: str, target_path: Path) -> Path: def _get_endpoints(self) -> dict[str, str]: """Get endpoints for current environment.""" + # If the requested env matches the configured env, use the configured URLs + if self.kb_env == settings.KB_ENV: + return { + "workspace": settings.WORKSPACE_URL, + "shock": settings.BLOBSTORE_URL, + "handle": f"{settings.KBASE_ENDPOINT}/handle_service", + } + + # Fallback for other environments endpoints = { "appdev": { "workspace": "https://appdev.kbase.us/services/ws", @@ -168,23 +196,137 @@ def _get_object_fallback(self, ref: str, ws: int | None = None) -> dict[str, Any } endpoints = self._get_endpoints() - response = requests.post( - endpoints["workspace"], - json=payload, - headers=headers, - timeout=60 - ) - response.raise_for_status() - result = response.json() - - if "error" in result: - raise ValueError(result["error"].get("message", "Unknown error")) + try: + response = requests.post( + endpoints["workspace"], + json=payload, + headers=headers, + timeout=30 # Reduced from 60 to fail faster + ) + response.raise_for_status() + result = response.json() + + if "error" in result: + error_msg = result["error"].get("message", "Unknown error") + error_code = result["error"].get("code", "Unknown") + logger.error(f"Workspace API error for {ref}: [{error_code}] {error_msg}") + raise ValueError(f"Workspace API error: [{error_code}] {error_msg}") + except requests.exceptions.HTTPError as e: + # Capture response body for better error messages + error_detail = f"HTTP {e.response.status_code}" + try: + error_body = e.response.json() + if "error" in error_body: + error_detail = error_body["error"].get("message", str(error_body)) + else: + error_detail = str(error_body) + except: + error_detail = e.response.text[:500] if e.response.text else str(e) + logger.error(f"Workspace API HTTP error for {ref}: {error_detail}") + raise ValueError(f"Workspace service error: {error_detail}") + except requests.exceptions.RequestException as e: + logger.error(f"Workspace API request failed for {ref}: {e}") + raise ValueError(f"Failed to connect to workspace service: {str(e)}") data_list = result.get("result", [{}])[0].get("data", []) if not data_list: raise ValueError(f"No data for: {ref}") return data_list[0] + + def get_object_with_type(self, ref: str, ws: int | None = None) -> tuple[dict[str, Any], str]: + """ + Get workspace object data along with its type. + + Args: + ref: Object reference or name + ws: Workspace ID (optional if ref is full reference) + + Returns: + Tuple of (object_data, object_type) + object_type is the full KBase type string (e.g., "KBaseFBA.GenomeDataLakeTables-2.0") + """ + # Build reference + if ws and "/" not in str(ref): + ref = f"{ws}/{ref}" + + # First get the object type using get_object_info3 + object_type = self._get_object_type(ref) + + # Then get the data using standard method + obj_data = self.get_object(ref) + + return obj_data, object_type + + def _get_object_type(self, ref: str) -> str: + """ + Get the KBase object type using Workspace.get_object_info3. 
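+ Lookup failures are logged as warnings and the string "Unknown" is returned,
+ so a failed type lookup never blocks data access.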
+ + Args: + ref: Object reference + + Returns: + Object type string (e.g., "KBaseFBA.GenomeDataLakeTables-2.0") + """ + headers = { + "Authorization": self.token, + "Content-Type": "application/json" + } + + payload = { + "method": "Workspace.get_object_info3", + "params": [{"objects": [{"ref": ref}]}], + "version": "1.1", + "id": "tablescanner-type" + } + + endpoints = self._get_endpoints() + try: + response = requests.post( + endpoints["workspace"], + json=payload, + headers=headers, + timeout=30 + ) + response.raise_for_status() + result = response.json() + + if "error" in result: + error_msg = result["error"].get("message", "Unknown error") + logger.warning(f"Error getting object type for {ref}: {error_msg}") + return "Unknown" + except requests.exceptions.HTTPError as e: + error_detail = f"HTTP {e.response.status_code}" + try: + error_body = e.response.json() + if "error" in error_body: + error_detail = error_body["error"].get("message", str(error_body)) + except: + error_detail = e.response.text[:200] if e.response.text else str(e) + logger.warning(f"Error getting object type for {ref}: {error_detail}") + return "Unknown" + except Exception as e: + logger.warning(f"Error getting object type for {ref}: {e}") + return "Unknown" + + # get_object_info3 returns: {"result": [{"infos": [[objid, name, type, ...]]}]} + infos = result.get("result", [{}])[0].get("infos", []) + if infos and infos[0] and len(infos[0]) > 2: + return infos[0][2] + + return "Unknown" + + def get_object_type_only(self, ref: str) -> str: + """ + Public method to get object type without fetching full data. + + Args: + ref: Object reference + + Returns: + Object type string + """ + return self._get_object_type(ref) def _download_blob_fallback(self, handle_ref: str, target_path: str) -> str: """Download from blobstore via direct API.""" @@ -204,7 +346,7 @@ def _download_blob_fallback(self, handle_ref: str, target_path: str) -> str: resp = requests.post( endpoints["handle"], json=handle_payload, - headers={"Authorization": self.token, "Content-Type": "application/json"}, + headers={"Authorization": f"OAuth {self.token}", "Content-Type": "application/json"}, timeout=30 ) resp.raise_for_status() @@ -275,45 +417,27 @@ def get_berdl_table_data( return obj -def list_pangenomes_from_object( +def get_object_type( berdl_table_id: str, auth_token: str, kb_env: str = "appdev" -) -> list[dict[str, Any]]: +) -> str: """ - List all pangenomes from a BERDLTables object. + Get the KBase object type for a workspace object. 
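+ Local upload handles (refs starting with "local:") short-circuit to the
+ "LocalDatabase" type without any KBase call.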
Args: - berdl_table_id: KBase workspace reference + berdl_table_id: KBase workspace reference (e.g., "76990/7/2") auth_token: KBase authentication token kb_env: KBase environment Returns: - List of pangenome info dictionaries with: - - pangenome_id - - pangenome_taxonomy - - handle_ref - - user_genomes - - berdl_genomes + Object type string (e.g., "KBaseGeneDataLakes.BERDLTables-1.0") """ - obj_data = get_berdl_table_data(berdl_table_id, auth_token, kb_env) - - pangenome_data = obj_data.get("pangenome_data", []) - - pangenomes = [] - for pg in pangenome_data: - pangenomes.append({ - - "pangenome_taxonomy": pg.get("pangenome_taxonomy", ""), - "user_genomes": pg.get("user_genomes", []), - "berdl_genomes": pg.get("berdl_genomes", []), - "genome_count": len(pg.get("user_genomes", [])) + len(pg.get("berdl_genomes", [])), - "handle_ref": pg.get("sqllite_tables_handle_ref", ""), - }) - - return pangenomes - - + if berdl_table_id.startswith("local:"): + return "LocalDatabase" + + client = KBaseClient(auth_token, kb_env) + return client.get_object_type_only(berdl_table_id) @@ -353,12 +477,16 @@ def download_pangenome_db( return db_path # Fetch object metadata to get handle reference - pangenomes = list_pangenomes_from_object(berdl_table_id, auth_token, kb_env) - if not pangenomes: - raise ValueError(f"No pangenomes found in {berdl_table_id}") + obj_data = get_berdl_table_data(berdl_table_id, auth_token, kb_env) + pangenome_data = obj_data.get("pangenome_data", []) + if not pangenome_data: + raise ValueError(f"No pangenomes found in {berdl_table_id}") + # Take the first (and only expected) pangenome's handle - handle_ref = pangenomes[0]["handle_ref"] + handle_ref = pangenome_data[0].get("sqllite_tables_handle_ref") + if not handle_ref: + raise ValueError(f"No handle reference found in {berdl_table_id}") # Create cache directory db_dir.mkdir(parents=True, exist_ok=True) diff --git a/docker-compose.yml b/docker-compose.yml index a7db56a..b6c0a75 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,4 +20,4 @@ services: interval: 30s timeout: 10s retries: 3 - start_period: 40s + start_period: 40s \ No newline at end of file diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 0000000..8d64c75 --- /dev/null +++ b/docs/API.md @@ -0,0 +1,123 @@ +# TableScanner API + +The **TableScanner** service provides read-only access to SQLite databases stored in KBase (via Workspace objects). It supports listing tables, inspecting schemas, and querying data with filtering, sorting, and pagination. + +## Base URL +- **Development**: `http://localhost:8000` +- **Production**: `https://kbase.us/services/berdl_table_scanner` (or similar) + +## Authentication + +**Each user must provide their own KBase authentication token.** The service does not use a shared/service-level token for production access. + +- **Header (recommended)**: `Authorization: ` or `Authorization: Bearer ` +- **Cookie**: `kbase_session=` (useful for browser-based clients like DataTables Viewer) + +> **Note for Developers**: The `KB_SERVICE_AUTH_TOKEN` environment variable is available as a legacy fallback for local testing only. It should NOT be relied upon in production. + +--- +## Performance +- **Gzip Support**: Responses >1KB are automatically compressed if the `Accept-Encoding: gzip` header is present. +- **Fast JSON**: All responses use optimized JSON serialization. +--- + +## 1. Service Status + +### `GET /` +Basic service check. 
+- **Response**: `{"service": "TableScanner", "version": "1.0.0", "status": "running"}` + +### `GET /health` +Detailed health check including connection pool stats. + +--- + +## 2. Object Access +Access databases via KBase Workspace Object Reference (UPA, e.g., `76990/7/2`). + +### Example curl (with auth) + +```bash +# List tables for an object (replace WS_REF with a real UPA like 76990/7/2) +curl -X GET \ + "http://localhost:8000/object/WS_REF/tables?kb_env=appdev" \ + -H "accept: application/json" \ + -H "Authorization: Bearer $KB_TOKEN" +``` + +### `GET /object/{ws_ref}/tables` +List tables for a BERDLTables object. +- **Response**: Table list with schema overviews. + +### `GET /object/{ws_ref}/tables/{table_name}/data` +Query table data. +- **Query Params**: + - `limit` (default: 100) + - `offset` (default: 0) + - `sort_column`, `sort_order` (`ASC`/`DESC`) + - `search` (Global text search) +- **Response**: Headers, data rows, total count. + +```bash +# Query table data (replace TABLE_NAME with a real table like Genes) +curl -X GET \ + "http://localhost:8000/object/WS_REF/tables/TABLE_NAME/data?limit=10&kb_env=appdev" \ + -H "accept: application/json" \ + -H "Authorization: Bearer $KB_TOKEN" +``` + +### `GET /object/{ws_ref}/tables/{table_name}/stats` +Get detailed statistics for all columns in a table. +- **Response**: Column statistics including null counts, distinct counts, min/max/mean, and samples. + + +--- + +## 3. Data Access + +### `POST /table-data` +Complex query endpoint supporting advanced filtering. +- **Body**: + ```json + { + "berdl_table_id": "...", + "table_name": "Genes", + "limit": 100, + "filters": [ + {"column": "contigs", "operator": "gt", "value": 50}, + {"column": "gene_name", "operator": "like", "value": "kinase"} + ] + } + ``` +- **Supported Operators**: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, `not_in`, `between`, `is_null`, `is_not_null`. + +--- + +## 4. Local Database Upload + +### `POST /upload` +Upload a temporary SQLite database file to the server. Useful for testing or serving local files. + +- **Request**: Multipart form data with key `file` +- **Response**: + ```json + { + "handle": "local:uuid-string", + "filename": "my_db.db", + "size_bytes": 10240, + "message": "Database uploaded successfully" + } + ``` + +### Usage Workflow +1. **Upload File**: + ```bash + curl -X POST "http://localhost:8000/upload" \ + -H "accept: application/json" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/test.db" + ``` +2. **Use Handle**: The returned `handle` (e.g., `local:abc-123`) can be used as the `berdl_table_id` or `ws_ref` in any other endpoint. + - List tables: `GET /object/local:abc-123/tables` + - Query data: `POST /table-data` with `"berdl_table_id": "local:abc-123"` + diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 8d3cc5a..3919872 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -1,85 +1,77 @@ # TableScanner Architecture -TableScanner is a high-performance middleware service designed to provide fast, filtered, and paginated access to large tabular data stored in KBase. It solves the performance bottleneck of loading massive objects into memory by leveraging local SQLite caching and efficient indexing. +## Overview +TableScanner is a high-performance, read-only microservice designed to provide efficient access to tabular data stored in KBase (Workspace Objects or Blobstore Handles). 
It serves as a backend for the DataTables Viewer and other applications requiring filtered, paginated, and aggregated views of large datasets. ---- - -## High-Level Architecture +## System Architecture ```mermaid graph TD - User([User / API Client]) - TS[TableScanner Service] - KBaseWS[KBase Workspace] - KBaseBlob[KBase Blobstore] - LocalCache[(Local SQLite Cache)] - - User -->|API Requests| TS - TS -->|1. Resolve Metadata| KBaseWS - TS -->|2. Download Blob| KBaseBlob - TS -->|3. Store & Index| LocalCache - TS -->|4. SQL Query| LocalCache - LocalCache -->|5. Result| TS - TS -->|6. JSON Response| User + Client[Client Application] --> API[FastAPI Layer] + API --> Service[Query Service] + API --> DBHelper[DB Helper] + + subgraph Core Services + Service --> Pool[Connection Pool] + Pool --> SQLite[SQLite Cache] + Service --> FTS[FTS5 Search] + end + + subgraph Infrastructure + DBHelper --> WS[Workspace Client] + WS --> KBase[KBase Services] + WS --> Blob[Blobstore] + end ``` ---- - -## Caching Strategy: One DB per UPA - -TableScanner employs a strict **one-database-per-object** caching policy. Each KBase object reference (UPA, e.g., `76990/7/2`) is mapped to a unique local directory. - -- **Path Structure**: `{CACHE_DIR}/{sanitized_UPA}/tables.db` -- **Sanitization**: Special characters like `/`, `:`, and spaces are replaced with underscores to ensure filesystem compatibility. -- **Granularity**: Caching is performed at the object level. If multiple tables exist within a single SQLite blob, they are all cached together, improving subsequent access to related data. - ---- - -## Race Condition and Atomic Handling - -To ensure reliability in high-concurrency environments (multiple users requesting the same data simultaneously), TableScanner implements **Atomic File Operations**: - -### 1. Atomic Downloads -When a database needs to be downloaded, TableScanner does **not** download directly to the final path. -1. A unique temporary filename is generated using a UUID: `tables.db.{uuid}.tmp`. -2. The file is downloaded from the KBase Blobstore into this temporary file. -3. Once the download is successful and verified, a **filesystem-level atomic rename** (`os.rename`) is performed to move it to `tables.db`. -4. This ensures that if a process crashes or a network error occurs, the cache directory will not contain a partially-downloaded, corrupt database. - -### 2. Concurrent Request Handling -If two requests for the same UPA arrive at the same time: -- Both will check for the existence of `tables.db`. -- If it's missing, both may start a download to their own unique `temp` files. -- The first one to finish will atomically rename its temp file to `tables.db`. -- The second one to finish will also rename its file, overwriting the first. Since the content is identical (same UPA), the final state remains consistent and the database is never in a corrupt state during the swap. - ---- - -## Performance Optimization: Automatic Indexing - -TableScanner doesn't just store the data; it optimizes it. Upon the **first access** to any table: -- The service scans the table schema. -- It automatically generates a `idx_{table}_{column}` index for **every single column** in the table. -- This "Indexing on Demand" strategy ensures that even complex global searches or specific column filters remain sub-millisecond, regardless of the table size. - ---- - -## Data Lifecycle in Detail - -1. **Request**: User provides a KBase UPA and query parameters. -2. 
**Cache Verification**: Service checks if `{sanitized_UPA}/tables.db` exists and is valid. -3. **Metadata Resolution**: If not cached, `KBUtilLib` fetches the object from KBase to extract the Blobstore handle. -4. **Secure Download**: The blob is streamed to a temporary UUID file and then atomically renamed. -5. **Schema Check**: TableScanner verifies the requested table exists in the SQLite file. -6. **Index Check**: If it's the first time this table is being queried, indices are created for all columns. -7. **SQL Execution**: A standard SQL query with `LIMIT`, `OFFSET`, and `LIKE` filters is executed. -8. **Streaming Serialization**: Results are converted into a compact JSON list-of-lists and returned to the user. - ---- - -## Tech Stack and Key Components - -- **FastAPI**: Provides the high-performance async web layer. -- **SQLite**: The storage engine for tabular data, chosen for its zero-configuration and high performance with indices. -- **KBUtilLib**: Handles complex KBase Workspace and Blobstore interactions. -- **UUID-based Temp Storage**: Prevents race conditions during file I/O. +## Core Components + +### 1. API Layer (`app/routes.py`) +The entry point for all requests. It handles: +- **Object Access**: `/object/{ws_ref}/tables` +- **Data Queries**: `/table-data` (Advanced filtering) + +### 2. Query Service (`app/services/data/query_service.py`) +The heart of the application. It orchestrates query execution: +- **Type-Aware Filtering**: Automatically detects column types (text vs numeric) and applies correct SQL operators. +- **Advanced Aggregations**: Supports `GROUP BY`, `SUM`, `AVG`, `COUNT`, etc. +- **Full-Text Search**: Leverages SQLite FTS5 for fast global searching. +- **Result Caching**: Caches query results to minimize database I/O for repeated requests. + +### 3. Connection Pool (`app/services/data/connection_pool.py`) +Manages SQLite database connections efficiently: +- **Pooling**: Reuses connections to avoid open/close overhead. +- **Lifecycle**: Automatically closes idle connections after a timeout. +- **Optimization**: Configures PRAGMAs (WAL mode, memory mapping) for performance. + +### 4. Infrastructure Layer +- **DB Helper (`app/services/db_helper.py`)**: Resolves "Handle Refs" or "Workspace Refs" into local file paths, handling download and caching transparently. +- **Workspace Client (`app/utils/workspace.py`)**: Interacts with KBase services, falling back to direct HTTP queries if SDK clients are unavailable. + +## Data Flow + +1. **Request**: Client requests data (e.g., `GET /object/123/1/1/tables/Genes/data?limit=100`). +2. **Resolution**: `DB Helper` checks if the database for `123/1/1` is in the local cache. + - *Miss*: Downloads file from KBase Blobstore/Workspace. + - *Hit*: Returns path to local `.db` file. +3. **Connection**: `QueryService` requests a connection from `ConnectionPool`. +4. **Query Plan**: + - Checks schema for column types. + - Builds SQL query with parameterized filters. + - Ensures necessary indexes exist. +5. **Execution**: SQLite executes the query (using FTS or B-Tree indexes). +6. **Response**: Data is returned to the client as JSON. + +## Design Decisions + +- **Read-Only**: The service never modifies the source SQLite files. This simplifies concurrency control (WAL mode). +- **Synchronous I/O in Async App**: We use `run_sync_in_thread` to offload blocking SQLite operations to a thread pool, keeping the FastAPI event loop responsive. 
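As an illustration of the bullet above, here is a minimal sketch (not part of the service code) of wrapping a blocking `sqlite3` call with the `run_sync_in_thread` helper added in `app/utils/async_utils.py`; the database path and table name are placeholders.

```python
import asyncio
import sqlite3
from pathlib import Path

from app.utils.async_utils import run_sync_in_thread


def _count_rows(db_path: Path, table_name: str) -> int:
    # Blocking sqlite3 work: runs in a worker thread, not on the event loop
    conn = sqlite3.connect(str(db_path))
    try:
        return conn.execute(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[0]
    finally:
        conn.close()


async def main() -> None:
    # Placeholder path/table for illustration only
    total = await run_sync_in_thread(_count_rows, Path("cache/example/tables.db"), "Genes")
    print(f"Genes rows: {total}")


if __name__ == "__main__":
    asyncio.run(main())
```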
+- **Local Caching**: We aggressively cache database files locally to avoid the high latency of downloading multi-GB files from KBase for every request. +- **Metadata Caching**: Object types are cached locally to minimize redundant KBase API calls. +- **Concurrency**: Table listing uses parallel metadata fetching (`asyncio.gather`) to resolve "N+1" query issues. +- **Compression & High-Performance serialization**: Production-ready configuration uses Gzip and ORJSON for maximum throughput. + +## Security +- **Authentication**: All data access endpoints require a valid KBase Auth Token (`Authorization` header). +- **Authorization**: The service relies on KBase Services to validate if the token has access to the requested Workspace Object or Handle. +- **Input Validation**: Strict validation of table and column names prevents SQL injection. Parameterized queries are used for all values. diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 0000000..ef06c56 --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1,81 @@ +# Contributing to TableScanner + +## Development Setup + +### Prerequisites +- Python 3.10+ +- KBase authentication token +- Access to KBase services (Workspace, Blobstore) + +### Quick Start +1. **Clone & Venv**: + ```bash + git clone + cd tablescanner + python3 -m venv venv + source venv/bin/activate + pip install -r requirements.txt + ``` + +2. **Configuration**: + Copy `.env.example` to `.env` and set `KB_SERVICE_AUTH_TOKEN`. + +3. **Run Locally**: + You can use the provided helper script: + ```bash + ./scripts/dev.sh + ``` + This script handles: + - Activating the virtual environment (`.venv`) + - Loading environment variables from `.env` + - Setting `PYTHONPATH` + - Starting the server via `fastapi dev` + + Alternatively, run manually: + ```bash + uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 + ``` + +4. **Run with Docker**: + ```bash + docker-compose up --build + ``` + +--- + +## Project Structure +- `app/`: Core application code. + - `main.py`: Entry point. + - `routes.py`: API endpoints. + - `services/`: Business logic (Data queries, schema). + - `utils/`: Helpers (SQLite, KBase Client). + - `models.py`: Pydantic data models. +- `tests/`: Test suite. +- `docs/`: Documentation. + +--- + +## Testing + +### Running Tests +We use `unittest` (compatible with `pytest`). + +```bash +# Run all tests +python -m unittest discover tests + +# Or using pytest (recommended) +pytest tests/ -v +``` + +### Writing Tests +- Place unit tests in `tests/unit/`. +- Place integration tests in `tests/integration/`. +- Use `app/services/data/query_service.py` tests as a reference for mocking SQLite. + +--- + +## Code Style +- Follow PEP 8. +- Use type hints. +- Ensure purely synchronous I/O (like `sqlite3`) is wrapped in `run_sync_in_thread`. diff --git a/docs/QUICKSTART_DEMO.md b/docs/QUICKSTART_DEMO.md deleted file mode 100644 index b06de7d..0000000 --- a/docs/QUICKSTART_DEMO.md +++ /dev/null @@ -1,50 +0,0 @@ -# Quickstart Demo - -This guide walks you through running the TableScanner demo locally. - -## Prerequisites - -- Python 3.9+ -- KBase Auth Token (for accessing workspace objects) - -## Setup - -1. **Install Dependencies** - ```bash - pip install -r requirements.txt - ``` - -2. **Start the Service** - ```bash - uv run fastapi dev app/main.py - ``` - Server will start at `http://localhost:8000`. - -## Running the Demo - -1. Open the [Viewer](http://localhost:8000/static/viewer.html) in your browser. - -2. 
**Configuration:** - - **Environment**: Select `AppDev` (or appropriate env). - - **Auth Token**: Enter your KBase token. - -3. **Load Data:** - - **BERDL Table ID**: Enter `76990/ADP1Test`. - - Click the **Search** icon. - -4. **Explore:** - - Since `76990/ADP1Test` contains only one pangenome, it will be **auto-selected**. - - Tables will load automatically. - - Select a table (e.g., "Genome attributes") to view data. - - Hover over cells with IDs (UniProt, KEGG, etc.) to see tooltips. - - Click IDs to visit external databases. - -## Multi-Pangenome Demo - -To test loading multiple identifiers: - -1. **BERDL Table ID**: Enter `76990/ADP1Test, 76990/ADP1Test` (simulating two sources). -2. Click **Search**. -3. The **Pangenome** dropdown will appear. -4. Options will show as: `ADP1 [76990/ADP1Test]`. -5. Select different options to toggle between datasets (if they were different). diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..5eee708 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,31 @@ +# TableScanner + +**TableScanner** is a high-performance, read-only API service for querying SQLite databases stored in [KBase](https://kbase.us). It powers the DataTables Viewer and other applications requiring fast access to tabular data. + +## Documentation + +- **[API Reference](API.md)**: Endpoints, authentication, and usage examples. +- **[Architecture](ARCHITECTURE.md)**: System design and technical overview. +- **[Contributing Guide](CONTRIBUTING.md)**: Setup, testing, and development standards. + +## Quick Start + +### Run with Docker +```bash +docker-compose up --build +``` +The API will be available at `http://localhost:8000`. + +### Run Locally +```bash +# 1. Setup environment +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +cp .env.example .env # Edit with your KBase Token + +# 2. Run using helper script +./scripts/dev.sh +``` + +The `./scripts/dev.sh` script is the recommended way to run locally as it handles environment loading and PYTHONPATH setup automatically. diff --git a/docs/USAGE_GUIDE.md b/docs/USAGE_GUIDE.md deleted file mode 100644 index 6cb87b4..0000000 --- a/docs/USAGE_GUIDE.md +++ /dev/null @@ -1,107 +0,0 @@ -# Usage Guide - -This guide covers production usage of the TableScanner service. - -## API Endpoint -The service is deployed at: -``` -https://appdev.kbase.us/services/berdl_table_scanner -``` - -## Authentication -All requests require a valid KBase authentication token passed in the `Authorization` header. - -```bash -Authorization: -``` - ---- - -## 1. Using the Hierarchical REST API (Browser-friendly) - -This style uses hierarchical paths and standard GET requests. It is ideal for web applications or simple data navigation. - -### List Available Tables -Get a list of all tables found in a KBase object. - -**Endpoint:** `GET /object/{upa}/tables` - -**Example:** -```bash -curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables" -``` - -### Query Table Data -Retrieve paginated data from a specific table. 
- -**Endpoint:** `GET /object/{upa}/tables/{table_name}/data` - -**Parameters:** -- `limit`: (int) Maximum rows (default 100) -- `offset`: (int) Skip rows (default 0) -- `search`: (string) Global search term -- `sort_column`: (string) Column to sort by -- `sort_order`: (string) "ASC" or "DESC" - -**Example:** -```bash -curl -H "Authorization: $KB_TOKEN" \ - "https://appdev.kbase.us/services/berdl_table_scanner/object/76990/7/2/tables/Genes/data?limit=5" -``` - ---- - -## 2. Using the Flat POST API (Script-friendly) - -The Flat POST API is recommended for Python scripts and programmatic access. It allows sending complex query parameters in a single JSON body. - -**Endpoint:** `POST /table-data` - -### Implementation Example (Python) - -```python -import requests -import json - -url = "https://appdev.kbase.us/services/berdl_table_scanner/table-data" -headers = {"Authorization": "YOUR_KBASE_TOKEN"} - -payload = { - "berdl_table_id": "76990/7/2", - "table_name": "Metadata_Conditions", - "limit": 50, - "offset": 0, - "search_value": "glucose", - "col_filter": { - "organism": "E. coli" - }, - "sort_column": "yield", - "sort_order": "DESC" -} - -response = requests.post(url, json=payload, headers=headers) -data = response.json() - -print(f"Retrieved {len(data['data'])} rows.") -``` - ---- - -## Pro Tips - -### Multi-Source Search -The metadata endpoints support comma-separated IDs to aggregate pangenomes across multiple objects. - -```bash -GET /pangenomes?berdl_table_id=76990/7/2,76990/8/1 -``` - -### Performance -The first request for a large dataset may take a few seconds as the service downloads and indexes the database. Subsequent requests will be near-instant. - ---- - -## Web Viewer -Access the interactive viewer at: -`https://appdev.kbase.us/services/berdl_table_scanner/static/viewer.html` # TODO: implement this diff --git a/pyproject.toml b/pyproject.toml index 2e6923c..3fc8ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,7 @@ dependencies = [ "minio>=7.2.20", "pydantic-settings>=2.0.0", "requests>=2.31.0", - "pandas>=2.2.0", - "PyYAML>=6.0", "tqdm>=4.64.0", - "itables>=1.5.0", - "ipywidgets>=8.0.0", ] [build-system] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d444106 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +# TableScanner Requirements +# Install with: pip install -r requirements.txt + +fastapi[standard]>=0.124.4 +uvicorn>=0.38.0 +pydantic>=2.11.0 +pydantic-settings>=2.0.0 +requests>=2.31.0 +python-multipart>=0.0.20 +tqdm>=4.64.0 +minio>=7.2.20 +orjson>=3.9.10 diff --git a/scripts/api_client.py b/scripts/api_client.py deleted file mode 100644 index 4143abc..0000000 --- a/scripts/api_client.py +++ /dev/null @@ -1,86 +0,0 @@ -import requests -import json -import os - -# Set your KBase authentication token -TOKEN = os.environ.get("KBASE_TOKEN") -if not TOKEN: - raise RuntimeError("KBASE_TOKEN environment variable is not set.") -HEADERS = {"Authorization": TOKEN} -BASE_URL = "http://127.0.0.1:8000" - -# ---------------------------------------------------------- -# STYLE 1: HIERARCHICAL REST (GET) -# Ideal for simple navigation and web viewers -# ---------------------------------------------------------- - -print("\n--- REST: List Tables ---") -# Literal path: /object/{upa}/tables -res = requests.get(f"{BASE_URL}/object/76990/7/2/tables", headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["tables"][:3], indent=2)) - - - -print("\n--- REST: Get Top 3 Genes ---") -# Literal path: 
/object/{upa}/tables/{table_name}/data -res = requests.get(f"{BASE_URL}/object/76990/7/2/tables/Genes/data", params={"limit": 3}, headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["data"], indent=2)) - - - -print("\n--- REST: Filtered Search (kinase) ---") -# Literal path with query parameters -params = {"limit": 3, "search": "kinase"} -res = requests.get(f"{BASE_URL}/object/76990/7/2/tables/Genes/data", params=params, headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["data"], indent=2)) - - -# ---------------------------------------------------------- -# STYLE 2: FLAT POST -# Ideal for complex queries and production scripts -# ---------------------------------------------------------- - -print("\n--- POST: Basic Fetch (3 rows) ---") -# Single endpoint for all data: /table-data -payload = { - "berdl_table_id": "76990/7/2", - "table_name": "Conditions", - "limit": 3 -} -res = requests.post(f"{BASE_URL}/table-data", json=payload, headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["data"], indent=2)) - - - -print("\n--- POST: Column-Specific Filter (Carbon_source=pyruvate) ---") -# Precise AND-logic filtering via col_filter -payload = { - "berdl_table_id": "76990/7/2", - "table_name": "Conditions", - "limit": 3, - "col_filter": {"Carbon_source": "pyruvate"} -} -res = requests.post(f"{BASE_URL}/table-data", json=payload, headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["data"], indent=2)) - - - -print("\n--- POST: Sorted Multi-column Query ---") -# Support for complex ordering -payload = { - "berdl_table_id": "76990/7/2", - "table_name": "Genes", - "limit": 3, - "order_by": [ - {"column": "Length", "direction": "DESC"}, - {"column": "ID", "direction": "ASC"} - ] -} -res = requests.post(f"{BASE_URL}/table-data", json=payload, headers=HEADERS) -res.raise_for_status() -print(json.dumps(res.json()["data"], indent=2)) diff --git a/static/viewer.html b/static/viewer.html index 463bb62..8732990 100644 --- a/static/viewer.html +++ b/static/viewer.html @@ -1,962 +1,1310 @@ - + - TableScanner - BERDL Table Viewer + TableScanner - Research Data Explorer - - - + - -
[static/viewer.html rewritten (962 → 1310 lines): the BERDL Table Viewer page is replaced by the "Research Data Explorer" UI (v2.0 badge, connection panel). The HTML/JS markup of this hunk did not survive extraction and is omitted here.]
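Since this change removes `scripts/api_client.py` and `docs/USAGE_GUIDE.md`, a minimal client sketch for the consolidated `POST /table-data` endpoint (documented in `docs/API.md` above) is included here for orientation. The base URL, token handling, and filter values are assumptions for local development, not a fixed interface.

```python
import os

import requests

BASE_URL = "http://127.0.0.1:8000"            # local dev server (assumption)
TOKEN = os.environ.get("KBASE_TOKEN", "")     # same env var the removed script used

payload = {
    "berdl_table_id": "76990/7/2",
    "table_name": "Genes",
    "limit": 10,
    "filters": [
        {"column": "contigs", "operator": "gt", "value": 50},
        {"column": "gene_name", "operator": "like", "value": "kinase"},
    ],
}

resp = requests.post(
    f"{BASE_URL}/table-data",
    json=payload,
    headers={"Authorization": TOKEN},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
print(f"Retrieved {len(body['data'])} of {body['total_count']} rows")
```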
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..3a8be3e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package for TableScanner diff --git a/tests/integration/test_concurrency.py b/tests/integration/test_concurrency.py new file mode 100644 index 0000000..c23cb43 --- /dev/null +++ b/tests/integration/test_concurrency.py @@ -0,0 +1,166 @@ + +import threading +import pytest +import sqlite3 +import time +import random +from concurrent.futures import ThreadPoolExecutor, as_completed + +from app.services.data.connection_pool import get_connection_pool +from app.services.data.query_service import get_query_service, AggregationSpec + +# Use a temporary database for testing +@pytest.fixture +def test_db(tmp_path): + db_path = tmp_path / "test_concurrency.db" + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + cursor.execute("CREATE TABLE test_data (id INTEGER PRIMARY KEY, value INTEGER, text_col TEXT)") + + # Insert some data + data = [(i, i * 10, f"row_{i}") for i in range(100)] + cursor.executemany("INSERT INTO test_data (id, value, text_col) VALUES (?, ?, ?)", data) + conn.commit() + conn.close() + return db_path + +def test_connection_pool_concurrency(test_db): + """ + Test that the connection pool handles concurrent access correctly without + raising 'database is locked' errors or other threading issues. + """ + pool = get_connection_pool() + query_service = get_query_service() + + # Reset pool for this test to ensure clean state + # (Note: In a real app, the pool is global, but here we want to test isolation if possible. + # The pool uses path as key, so unique tmp_path helps.) + + def worker_task(worker_id): + results = [] + errors = [] + try: + # Simulate random delay to interleave requests + time.sleep(random.random() * 0.1) + + # 1. Simple Select + res = query_service.execute_query( + test_db, + "test_data", + limit=10, + offset=worker_id * 2, + use_cache=False # Disable cache to force DB hits + ) + results.append(len(res["data"])) + + # 2. Schema Info (uses pool independently) + types = query_service.get_column_types(test_db, "test_data") + results.append(len(types)) + + # 3. Aggregation (heavier query) + agg_res = query_service.execute_query( + test_db, + "test_data", + aggregations=[AggregationSpec(column="value", function="sum", alias="total_val")], + use_cache=False + ) + results.append(agg_res["data"][0][0]) + + except Exception as e: + errors.append(str(e)) + + return results, errors + + # Run 20 concurrent threads + # Max connections per pool is default 5. This forces queuing. + num_threads = 20 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = {executor.submit(worker_task, i): i for i in range(num_threads)} + + all_errors = [] + execution_counts = 0 + + for future in as_completed(futures): + res, errs = future.result() + if errs: + all_errors.extend(errs) + else: + execution_counts += 1 + + # Assertions + if all_errors: + pytest.fail(f"Concurrent execution failed with errors: {all_errors[:5]}...") + + assert execution_counts == num_threads, f"Expected {num_threads} successful executions, got {execution_counts}" + + # Verify pool cleanup or state if possible, though internals are private. 
+    # Verify pool state via the public stats API.
+    stats = pool.get_stats()
+    # The pool keyed by our db_path should be present.
+    assert any(p["db_path"] == str(test_db) for p in stats["pools"])
+
+
+def test_pool_exhaustion_timeout(test_db):
+    """
+    Test that connection acquisition times out if all connections are held.
+    """
+    pool = get_connection_pool()
+    db_path = test_db
+
+    try:
+        # The pool allows 5 connections per database by default (see
+        # connection_pool.py). Hold all 5 from separate threads, then attempt
+        # a 6th acquisition through the context manager.
+
+        def holder_thread(event_start, event_stop):
+            try:
+                with pool.connection(db_path):
+                    event_start.set()
+                    # Hold the connection until told to stop
+                    event_stop.wait(timeout=5)
+            except Exception as e:
+                print(f"Holder thread error: {e}")
+
+        # Start 5 threads to hold connections
+        threads = []
+        stop_events = []
+
+        for _ in range(5):
+            start_evt = threading.Event()
+            stop_evt = threading.Event()
+            t = threading.Thread(target=holder_thread, args=(start_evt, stop_evt))
+            t.start()
+            # Wait for the thread to acquire a connection; it may be queued
+            # if the pool limit has already been reached.
+            if not start_evt.wait(timeout=2):
+                pass
+
+            threads.append(t)
+            stop_events.append(stop_evt)
+
+        # Give all holders a moment to settle
+        time.sleep(0.5)
+
+        # With every connection held, a further acquisition should block and
+        # eventually time out (the pool's default acquisition timeout is 5s),
+        # surfacing as a timeout-style exception rather than a deadlock.
+        try:
+            with pool.connection(db_path):
+                # Acquisition unexpectedly succeeded: a holder released early
+                # or the pool limit is larger than assumed.
+                pass
+        except Exception:
+            # Expected path: a timeout-style exception (e.g. queue.Empty).
+            pass
+        finally:
+            # Release the holder threads
+            for evt in stop_events:
+                evt.set()
+            for t in threads:
+                t.join()
+
+    except Exception as e:
+        pytest.fail(f"Test setup failed: {e}")
diff --git a/tests/integration/test_local_upload.py b/tests/integration/test_local_upload.py
new file mode 100644
index 0000000..df4b139
--- /dev/null
+++ b/tests/integration/test_local_upload.py
@@ -0,0 +1,154 @@
+import os
+import pytest
+import sqlite3
+import tempfile
+from pathlib import Path
+from fastapi.testclient import TestClient
+from app.main import app
+from app.config import settings
+
+client = TestClient(app)
+
+@pytest.fixture
+def dummy_sqlite_db():
+    """Create a temporary SQLite database with some data."""
+    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
+        conn = sqlite3.connect(tmp.name)
+        cursor = conn.cursor()
+        cursor.execute("CREATE TABLE TestTable (id INTEGER PRIMARY KEY, name TEXT)")
+        cursor.execute("INSERT INTO TestTable (name) VALUES ('Alpha')")
+        cursor.execute("INSERT INTO TestTable (name) VALUES ('Beta')")
+        conn.commit()
+        conn.close()
+        tmp_path = Path(tmp.name)
+    yield tmp_path
+    # Cleanup
+    if tmp_path.exists():
+        tmp_path.unlink()
+
+def test_upload_and_query_flow(dummy_sqlite_db):
+    """
+    Test the full flow:
+    1. Upload DB -> Get handle
+    2. List tables -> Success
+    3. Query data -> Success
+    """
+    # 1. Upload
+    with open(dummy_sqlite_db, "rb") as f:
+        response = client.post(
+            "/upload",
+            files={"file": ("my_test.db", f, "application/vnd.sqlite3")}
+        )
+
+    assert response.status_code == 200
+    data = response.json()
+    assert "handle" in data
+    assert data["handle"].startswith("local:")
+    assert data["message"] == "Database uploaded successfully"
+
+    handle = data["handle"]
+
+    # 2. List tables
+    # Local handles bypass the KBase download path, but the Authorization
+    # header is still checked by get_auth_token, so a dummy token is supplied.
+    headers = {"Authorization": "Bearer dummy_token"}
+
+    # No mocking of get_object_type is required: a "local:" handle is not a
+    # valid UPA, so the KBase lookup fails, routes.py catches the exception,
+    # and object_type falls back to None (or is reported as "LocalDatabase").
+    response = client.get(f"/object/{handle}/tables", headers=headers)
+    assert response.status_code == 200, response.text
+    tables_data = response.json()
+
+    assert tables_data["object_type"] == "LocalDatabase" or tables_data["object_type"] is None
+    names = [t["name"] for t in tables_data["tables"]]
+    assert "TestTable" in names
+
+    # 3. Query data
+    query_payload = {
+        "berdl_table_id": handle,
+        "table_name": "TestTable",
+        "limit": 10
+    }
+    response = client.post("/table-data", json=query_payload, headers=headers)
+    assert response.status_code == 200
+    query_data = response.json()
+    assert len(query_data["data"]) == 2
+    assert query_data["data"][0][1] == "Alpha"
+
+def test_upload_security_traversal():
+    """Test that a crafted handle cannot be used for directory traversal."""
+    headers = {"Authorization": "Bearer dummy_token"}
+
+    # Attempt to read a file outside the uploads directory via path traversal;
+    # get_object_db_path rejects handles containing "..".
+    bad_handle = "local:../../../../etc/passwd"
+
+    response = client.get(f"/object/{bad_handle}/tables", headers=headers)
+
+    # The {ws_ref:path} route parameter captures slashes, so also exercise a
+    # URL-encoded variant; the db_helper validation must catch both forms.
+    response = client.get("/object/local:..%2F..%2Fetc%2Fpasswd/tables", headers=headers)
+
+    # Any non-success status is acceptable; db_helper validation raises 400.
+ assert response.status_code in (400, 404, 500) + +def test_upload_invalid_file_format(): + """Test that uploading a non-SQLite file is rejected.""" + # Create a dummy text file + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + tmp.write(b"This is not a SQLite database") + tmp_path = Path(tmp.name) + + try: + with open(tmp_path, "rb") as f: + response = client.post( + "/upload", + files={"file": ("fake.db", f, "application/vnd.sqlite3")} + ) + + # Should be rejected due to header mismatch + assert response.status_code == 400 + assert "Invalid SQLite file format" in response.json()["detail"] + + finally: + if tmp_path.exists(): + tmp_path.unlink() + +def test_upload_and_get_stats(dummy_sqlite_db): + """Test getting statistics for an uploaded table.""" + # 1. Upload + with open(dummy_sqlite_db, "rb") as f: + response = client.post( + "/upload", + files={"file": ("stats_test.db", f, "application/vnd.sqlite3")} + ) + handle = response.json()["handle"] + + # 2. Get Stats + headers = {"Authorization": "Bearer dummy_token"} + response = client.get(f"/object/{handle}/tables/TestTable/stats", headers=headers) + + assert response.status_code == 200 + stats = response.json() + assert stats["table"] == "TestTable" + assert stats["row_count"] == 2 + + # Check column stats + cols = {c["column"]: c for c in stats["columns"]} + assert "name" in cols + assert cols["name"]["distinct_count"] == 2 + assert "Alpha" in cols["name"]["sample_values"] diff --git a/tests/integration/test_routes.py b/tests/integration/test_routes.py new file mode 100644 index 0000000..13c03f9 --- /dev/null +++ b/tests/integration/test_routes.py @@ -0,0 +1,36 @@ +import unittest +from fastapi.testclient import TestClient +from app.main import app + +class TestRoutes(unittest.TestCase): + def setUp(self): + self.client = TestClient(app) + + def test_health_check(self): + response = self.client.get("/health") + # 500/503 is NOT acceptable. Integration tests must ensure the application can start. + # The ConnectionPool does not require external connectivity to initialize. 
+ self.assertEqual(response.status_code, 200) + + def test_api_docs_accessible(self): + response = self.client.get("/docs") + self.assertEqual(response.status_code, 200) + + def test_openapi_schema_structure(self): + response = self.client.get("/openapi.json") + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertIn("paths", data) + # Verify Key Endpoints exist + self.assertIn("/object/{ws_ref}/tables", data["paths"]) + self.assertIn("/table-data", data["paths"]) + + # Verify Deprecated Endpoints are GONE + self.assertNotIn("/handle/{handle_ref}/tables", data["paths"]) + self.assertNotIn("/pangenomes", data["paths"]) + self.assertNotIn("/tables", data["paths"]) + self.assertNotIn("/config/providers", data["paths"]) + self.assertNotIn("/config/resolve", data["paths"]) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_routes_advanced.py b/tests/integration/test_routes_advanced.py new file mode 100644 index 0000000..44aafa6 --- /dev/null +++ b/tests/integration/test_routes_advanced.py @@ -0,0 +1,193 @@ +import unittest +import sqlite3 +from pathlib import Path +from fastapi.testclient import TestClient +from app.main import app +from app.config import settings + +def create_test_db(db_path: Path): + """Create a comprehensive test database with various types.""" + db_path.parent.mkdir(parents=True, exist_ok=True) + if db_path.exists(): + db_path.unlink() + + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Create Genes table + cursor.execute(""" + CREATE TABLE Genes ( + gene_id TEXT PRIMARY KEY, + gene_name TEXT, + score REAL, + count INTEGER, + is_active BOOLEAN, + features TEXT, -- JSON-like + created_at TEXT + ) + """) + + data = [ + ("G1", "dnaA", 95.5, 10, 1, '{"type": "init"}', "2023-01-01"), + ("G2", "dnaN", 45.2, 5, 0, '{"type": "pol"}', "2023-01-02"), + ("G3", "gyrA", 88.0, 20, 1, '{"type": "top"}', "2023-01-03"), + ("G4", "gyrB", 87.5, 15, 1, '{"type": "top"}', "2023-01-03"), + ("G5", "recA", 12.5, 2, 0, None, "2023-01-04"), + ] + + cursor.executemany("INSERT INTO Genes VALUES (?,?,?,?,?,?,?)", data) + + # Text search table + cursor.execute("CREATE TABLE TextContents (id INTEGER PRIMARY KEY, title TEXT, body TEXT)") + cursor.execute("INSERT INTO TextContents VALUES (1, 'Hello World', 'This is a test document')") + cursor.execute("INSERT INTO TextContents VALUES (2, 'Foo Bar', 'Another document with different content')") + cursor.execute("INSERT INTO TextContents VALUES (3, 'Baz Qux', 'Hello again, world!')") + + conn.commit() + conn.close() + return db_path + +def setup_cache_with_db(cache_dir: Path, upa: str) -> Path: + """Setup a cache directory with the test DB for a specific UPA.""" + # From app/utils/cache.py logic: cache_dir / sanitized_upa / tables.db + safe_upa = upa.replace("/", "_").replace(":", "_").replace(" ", "_") + target_dir = cache_dir / safe_upa + target_dir.mkdir(parents=True, exist_ok=True) + + db_path = target_dir / "tables.db" + create_test_db(db_path) + return db_path + +class TestAdvancedFeatures(unittest.TestCase): + def setUp(self): + self.client = TestClient(app) + # Setup a real database in the configured cache directory + self.test_upa = "12345/Test/1" + self.db_path = setup_cache_with_db(Path(settings.CACHE_DIR), self.test_upa) + + def test_advanced_filtering(self): + """Test strict filtering capabilities.""" + # 1. 
Greater Than + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "score", "operator": "gt", "value": 90} + ] + }) + self.assertEqual(response.status_code, 200, response.text) + data = response.json() + self.assertEqual(data["total_count"], 1) + self.assertEqual(data["data"][0][0], "G1") # G1 has score 95.5 + + # 2. IN operator (list) + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "gene_name", "operator": "in", "value": ["dnaA", "gyrA"]} + ] + }) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["total_count"], 2) + names = sorted([r[1] for r in data["data"]]) + self.assertEqual(names, ["dnaA", "gyrA"]) + + # 3. Like (text search on specific column) + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "gene_name", "operator": "like", "value": "gyr"} + ] + }) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["total_count"], 2) # gyrA, gyrB + + def test_aggregations(self): + """Test aggregation capabilities.""" + # 1. Simple Count + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "aggregations": [ + {"column": "*", "function": "count", "alias": "total"} + ] + }) + self.assertEqual(response.status_code, 200) + data = response.json() + # Expecting one row with count + self.assertEqual(data["headers"], ["total"]) + self.assertEqual(int(data["data"][0][0]), 5) + + # 2. Group By + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "group_by": ["is_active"], + "aggregations": [ + {"column": "*", "function": "count", "alias": "cnt"}, + {"column": "score", "function": "avg", "alias": "avg_score"} + ], + "sort_column": "is_active", + "sort_order": "ASC" + }) + self.assertEqual(response.status_code, 200) + data = response.json() + # 0 (inactive): G2(45.2), G5(12.5) -> avg ~28.85 + # 1 (active): G1(95.5), G3(88.0), G4(87.5) -> avg ~90.33 + self.assertEqual(len(data["data"]), 2) + self.assertEqual(data["data"][0][0], "0") # is_active=0 + self.assertEqual(data["data"][1][0], "1") # is_active=1 + + def test_sorting_and_pagination(self): + """Test sorting and pagination.""" + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "sort_column": "score", + "sort_order": "DESC", + "limit": 2, + "offset": 1 + }) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(len(data["data"]), 2) + # Scores descending: 95.5 (G1), 88.0 (G3), 87.5 (G4), 45.2 (G2), 12.5 (G5) + # Offset 1 means we skip G1. + # Should get G3 and G4. + self.assertEqual(data["data"][0][0], "G3") + self.assertEqual(data["data"][1][0], "G4") + + # Not testing global search heavily as it relies on FTS5 which might be optional/missing in some sqlite builds, + # though QueryService attempts to create it. 
+ def test_global_search_fallback(self): + """Test global search matches text columns.""" + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "search_value": "dna*" # Use FTS5 prefix syntax + }) + self.assertEqual(response.status_code, 200) + data = response.json() + # dnaA, dnaN should match 'dna*' + self.assertTrue(len(data["data"]) >= 2, f"Expected >=2 matches for 'dna*', got {len(data['data'])}") + + def test_legacy_compatibility(self): + """Test that legacy fields still work.""" + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "columns": "gene_id, gene_name", # String format + "col_filter": {"gene_name": "dna"} # Legacy filter dict + }) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["headers"], ["gene_id", "gene_name"]) + # Should match dnaA, dnaN + self.assertEqual(data["total_count"], 2) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_security.py b/tests/integration/test_security.py new file mode 100644 index 0000000..4e1b9c3 --- /dev/null +++ b/tests/integration/test_security.py @@ -0,0 +1,51 @@ +import unittest +from pathlib import Path +from fastapi.testclient import TestClient +from app.main import app +from app.config import settings +from app.utils.cache import sanitize_id + +class TestSecurity(unittest.TestCase): + def setUp(self): + self.client = TestClient(app) + self.cache_dir = Path(settings.CACHE_DIR) + + def test_sanitize_id_security(self): + """Test that ID sanitization prevents traversal.""" + # Standard ID + self.assertEqual(sanitize_id("123/456"), "123_456") + + # Path traversal attempts + self.assertNotEqual(sanitize_id("../../../etc/passwd"), "../../../etc/passwd") + # "a/../b" -> "a_.._b" (this is safe as a filename because / is removed) + self.assertEqual(sanitize_id("a/../b"), "a_.._b") + + # What about just ".." 
+ self.assertNotEqual(sanitize_id(".."), "..") + # Ensure it was modified to be safe + self.assertTrue(sanitize_id("..").endswith("_safe")) + + def test_path_traversal_api(self): + """Test API prevents accessing files outside cache.""" + # Attempt to access a file that definitely exists outside cache but relative + # This test relies on the fact that the code uses sanitize_id internally + + malicious_id = "../../../etc/passwd" + + # This should fail because it will look for "......etcpasswd" (or similar) in cache + # and not find it, returning 404 or empty list, NOT 500 or file content + response = self.client.get(f"/object/{malicious_id}/tables") + + # Accept 404 (Not Found) or 400 (Bad Request) or 422 + # BUT should definitively NOT return 200 with file content + self.assertNotEqual(response.status_code, 200) + + def test_cors_middleware(self): + """Verify CORS headers are present (default configuration).""" + response = self.client.get("/", headers={"Origin": "http://example.com"}) + self.assertEqual(response.status_code, 200) + # Default config allows * + self.assertEqual(response.headers.get("access-control-allow-origin"), "*") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/integration/test_security_fixes.py b/tests/integration/test_security_fixes.py new file mode 100644 index 0000000..535e5bf --- /dev/null +++ b/tests/integration/test_security_fixes.py @@ -0,0 +1,173 @@ + +import unittest +import sqlite3 +import shutil +from pathlib import Path +from unittest.mock import patch, MagicMock +from fastapi.testclient import TestClient +from app.main import app +from app.config import settings +from app.services.data.query_service import get_query_service + +# Reusing DB setup logic +def create_test_db(db_path: Path): + """Create a comprehensive test database.""" + db_path.parent.mkdir(parents=True, exist_ok=True) + if db_path.exists(): + db_path.unlink() + + conn = sqlite3.connect(str(db_path)) + cursor = conn.cursor() + + # Create Genes table + cursor.execute(""" + CREATE TABLE Genes ( + gene_id TEXT PRIMARY KEY, + gene_name TEXT, + score REAL, + count INTEGER + ) + """) + + data = [ + ("G1", "dnaA", 95.5, 10), + ("G2", "dnaN", 45.2, 5), + ("G3", "gyrA", 88.0, 20), + ] + cursor.executemany("INSERT INTO Genes VALUES (?,?,?,?)", data) + + # Create a dummy large table for FTS5 test (no data needed if we mock count) + cursor.execute("CREATE TABLE LargeTable (id INTEGER PRIMARY KEY, text TEXT)") + + conn.commit() + conn.close() + return db_path + +def setup_cache_with_db(cache_dir: Path, upa: str) -> Path: + safe_upa = upa.replace("/", "_").replace(":", "_").replace(" ", "_") + target_dir = cache_dir / safe_upa + target_dir.mkdir(parents=True, exist_ok=True) + + db_path = target_dir / "tables.db" + create_test_db(db_path) + return db_path + +class TestSecurityFixes(unittest.TestCase): + def setUp(self): + self.client = TestClient(app) + self.test_upa = "99999/Security/1" + self.db_path = setup_cache_with_db(Path(settings.CACHE_DIR), self.test_upa) + + def tearDown(self): + # Clean up + safe_upa = self.test_upa.replace("/", "_") + target_dir = Path(settings.CACHE_DIR) / safe_upa + if target_dir.exists(): + shutil.rmtree(target_dir) + + def test_variable_limit_enforcement(self): + """Test that IN operator with >900 items raises 422.""" + # Create a list of 901 items + many_items = [f"item_{i}" for i in range(901)] + + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "gene_name", 
"operator": "in", "value": many_items} + ] + }) + + self.assertEqual(response.status_code, 422) + self.assertIn("Too many values", response.json()["detail"]) + + def test_variable_limit_under_threshold(self): + """Test that IN operator with <900 items works.""" + items = ["dnaA", "dnaN"] + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "gene_name", "operator": "in", "value": items} + ] + }) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["total_count"], 2) + + def test_strict_numeric_validation(self): + """Test that invalid numeric inputs return 422 instead of 0.""" + # 1. String in numeric filter + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "score", "operator": "gt", "value": "high_score"} + ] + }) + self.assertEqual(response.status_code, 422) + self.assertIn("Invalid numeric value", response.json()["detail"]) + + # 2. String in integer filter + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "filters": [ + {"column": "count", "operator": "gt", "value": "not_an_int"} + ] + }) + self.assertEqual(response.status_code, 422) + self.assertIn("Invalid numeric value", response.json()["detail"]) + + @patch("app.services.data.connection_pool.ConnectionPool.get_connection") + def test_fts5_safety_logic_mocked_pool(self, mock_get_conn): + """Mocked unit test for FTS5 safety limit Logic.""" + qs = get_query_service() + + # Setup mock connection and cursor + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_get_conn.return_value = mock_conn + mock_conn.cursor.return_value = mock_cursor + + # Call sequence in ensure_fts5_table: + # 1. execute(check_table) -> fetchone() -> None (not exists) + # 2. execute(compile_options) -> fetchall() -> ["ENABLE_FTS5"] + # 3. execute(count) -> fetchone() -> [150000] (Too large) + + mock_cursor.fetchone.side_effect = [ + None, # 1. FTS5 table check + [150000], # 3. Row count + ] + # mock fetchall for compile options + mock_cursor.fetchall.return_value = [("ENABLE_FTS5",)] + + # Call + result = qs.ensure_fts5_table(Path("dummy.db"), "LargeTable", ["text"]) + + # Assert + self.assertFalse(result, "Should return False for tables > 100k rows") + # Ensure we didn't try to create it + # The CREATE VIRTUAL TABLE call should NOT have happened + # We can check the execute calls + execute_calls = [args[0] for args, _ in mock_cursor.execute.call_args_list] + self.assertFalse(any("CREATE VIRTUAL TABLE" in cmd for cmd in execute_calls)) + + def test_fts5_creation_small_table(self): + """Verify FTS5 IS created for small tables.""" + response = self.client.post("/table-data", json={ + "berdl_table_id": self.test_upa, + "table_name": "Genes", + "search_value": "dna" + }) + self.assertEqual(response.status_code, 200) + # Check logs or side effects? + # We can check if `Genes_fts5` table exists in the DB file. 
+ + conn = sqlite3.connect(self.db_path) + cur = conn.cursor() + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='Genes_fts5'") + self.assertIsNotNone(cur.fetchone(), "Genes_fts5 should be created for small table") + conn.close() + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_query_service.py b/tests/unit/test_query_service.py new file mode 100644 index 0000000..c57b719 --- /dev/null +++ b/tests/unit/test_query_service.py @@ -0,0 +1,107 @@ +import unittest +import sqlite3 +import tempfile +import shutil +import logging +from pathlib import Path +from app.services.data.query_service import QueryService, FilterSpec, AggregationSpec +from app.exceptions import TableNotFoundError + +# Configure logging +logging.basicConfig(level=logging.ERROR) + +class TestQueryService(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test.db" + self.service = QueryService() + + # Create a test database + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER, salary REAL, status TEXT)") + data = [ + (1, "Alice", 30, 50000.0, "active"), + (2, "Bob", 25, 45000.5, "inactive"), + (3, "Charlie", 35, 70000.0, "active"), + (4, "David", 30, 52000.0, "active"), + (5, "Eve", 28, 49000.0, "inactive"), + ] + cursor.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?)", data) + conn.commit() + conn.close() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_simple_select(self): + result = self.service.execute_query(self.db_path, "users", limit=10) + self.assertEqual(len(result["data"]), 5) + self.assertEqual(result["total_count"], 5) + self.assertEqual(result["headers"], ["id", "name", "age", "salary", "status"]) + + def test_filter_numeric(self): + filters = [FilterSpec(column="age", operator="gt", value=28)] + result = self.service.execute_query(self.db_path, "users", filters=filters) + # Should be Alice(30), Charlie(35), David(30) + self.assertEqual(len(result["data"]), 3) + self.assertEqual(result["total_count"], 3) + + def test_filter_text(self): + filters = [FilterSpec(column="status", operator="eq", value="active")] + result = self.service.execute_query(self.db_path, "users", filters=filters) + self.assertEqual(len(result["data"]), 3) + + def test_sorting(self): + # Sort by age DESC + result = self.service.execute_query(self.db_path, "users", sort_column="age", sort_order="DESC") + data = result["data"] + # Charlie(35) first + self.assertEqual(data[0][1], "Charlie") + # Bob(25) last + self.assertEqual(data[4][1], "Bob") + + def test_aggregation(self): + aggs = [ + AggregationSpec(column="salary", function="avg", alias="avg_salary"), + AggregationSpec(column="status", function="count", alias="count") + ] + result = self.service.execute_query( + self.db_path, "users", + aggregations=aggs, + group_by=["status"], + sort_column="status" + ) + + self.assertEqual(len(result["data"]), 2) + row_active = next(r for r in result["data"] if r[0] == "active") + + # Active: Alice(50k), Charlie(70k), David(52k) -> Avg 57333.33 + self.assertAlmostEqual(float(row_active[1]), 57333.33, delta=0.1) + self.assertEqual(int(row_active[2]), 3) + + def test_sql_injection_sort_ignored(self): + """Ensure sort column injection attacks are ignored (fallback to default).""" + bad_col = "age; DROP TABLE users; --" + result = self.service.execute_query(self.db_path, "users", sort_column=bad_col) + 
self.assertEqual(len(result["data"]), 5) + + # Verify table still exists + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT count(*) FROM users") + self.assertEqual(cursor.fetchone()[0], 5) + conn.close() + + def test_sql_injection_filter_safe(self): + """Ensure filter value injection is handled safely as literal string.""" + filters = [FilterSpec(column="name", operator="eq", value="Alice' OR '1'='1")] + result = self.service.execute_query(self.db_path, "users", filters=filters) + self.assertEqual(len(result["data"]), 0) + + def test_missing_table(self): + with self.assertRaises(TableNotFoundError): + self.service.execute_query(self.db_path, "non_existent_table") + +if __name__ == "__main__": + unittest.main()
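The suites above are plain pytest modules, so they can be run directly once the application dependencies are available. A minimal invocation might look like the sketch below, assuming pytest is installed in the project's uv-managed environment and that `CACHE_DIR` points at a writable directory (the integration tests create fixture databases under it):

```bash
# Unit tests: no network access or cache state required
uv run pytest tests/unit -v

# Integration tests: exercise the FastAPI app in-process via TestClient
uv run pytest tests/integration -v
```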