From ec7d9ac8bee273d82d9cd186f4706c50e3d451ac Mon Sep 17 00:00:00 2001 From: Vivian Peng Date: Fri, 12 Dec 2025 07:33:07 +0900 Subject: [PATCH 1/3] feat: add dbt integration (Phase 1) Implement core dbt module for query composition with Jinja templating, external dataset reference resolution, and project initialization. Includes CLI commands (, , ), model compilation engine, and comprehensive documentation. --- .amp-dbt/state.db | Bin 0 -> 28672 bytes AMP_DBT_DESIGN.md | 784 ++++++++++++++++++++++++++++++++++++ USER_WALKTHROUGH.md | 126 ++++++ apps/execute_query.py | 18 +- models/example_model.sql | 12 + models/staging/stg_test.sql | 6 + pyproject.toml | 7 + src/amp/dbt/README.md | 113 ++++++ src/amp/dbt/__init__.py | 25 ++ src/amp/dbt/cli.py | 461 +++++++++++++++++++++ src/amp/dbt/compiler.py | 502 +++++++++++++++++++++++ src/amp/dbt/config.py | 126 ++++++ src/amp/dbt/dependencies.py | 215 ++++++++++ src/amp/dbt/exceptions.py | 32 ++ src/amp/dbt/models.py | 29 ++ src/amp/dbt/monitor.py | 121 ++++++ src/amp/dbt/project.py | 205 ++++++++++ src/amp/dbt/state.py | 318 +++++++++++++++ src/amp/dbt/tracker.py | 178 ++++++++ 19 files changed, 3277 insertions(+), 1 deletion(-) create mode 100644 .amp-dbt/state.db create mode 100644 AMP_DBT_DESIGN.md create mode 100644 USER_WALKTHROUGH.md create mode 100644 models/example_model.sql create mode 100644 models/staging/stg_test.sql create mode 100644 src/amp/dbt/README.md create mode 100644 src/amp/dbt/__init__.py create mode 100644 src/amp/dbt/cli.py create mode 100644 src/amp/dbt/compiler.py create mode 100644 src/amp/dbt/config.py create mode 100644 src/amp/dbt/dependencies.py create mode 100644 src/amp/dbt/exceptions.py create mode 100644 src/amp/dbt/models.py create mode 100644 src/amp/dbt/monitor.py create mode 100644 src/amp/dbt/project.py create mode 100644 src/amp/dbt/state.py create mode 100644 src/amp/dbt/tracker.py diff --git a/.amp-dbt/state.db b/.amp-dbt/state.db new file mode 100644 index 0000000000000000000000000000000000000000..ce55f584bbe06c9597321f4d1a118baaa64957bf GIT binary patch literal 28672 zcmeI%?`zXQ7{Kvl>kpZo+Y6y@5BjDB5ky3MXWgO{+pcyQY_CGvWGpjHYLnu8>CiX6 z_U{q=uY4t5(pIlc+u_^rJxF?&Jh{vB$u)4~^j+JFU2z#sMs_ToYKNMxYcGY+G_5GN zvfLIWX%`j^`Kv4SP1U0I_}jz%gI`+d{txXz@8D=VE=hIOz)I zlKQNZ#&xk=eJ)2TuN`R=IhUSRvy zzz>J-7dM~nY!ZHqtnnlqx>4l*O>!x(#%-XdqF-yNsA8^Eca}tz6Kh0Ql(v%&!ys^n zu@?sEOxi%_%V`wFcV#<|ioKBxwnyV_%)TxAn~ojXvGXrZ92>8jXKhm)o%Q=h*R-}g zsO5>1wO8iCGWfHdQngjnwe)`(eemUK-m<50n6@p&)_S5ecIAuaUh=ZpPj-jNj)DLJ z2q1s}0tg_000IagfB*t{71&D>WBs4k$4g!aAb {{ var('last_processed_block', 0) }} +{% endif %} +``` + +**State Management:** +- Store `last_processed_block` per model in state database +- Auto-detect from destination table if available +- Support `--full-refresh` flag to reset + +--- + +## 5. Monitoring & Tracking + +### 5.1 Job Monitoring + +**Integration with Amp Jobs API:** +```python +# Uses existing JobsClient +class JobMonitor: + def monitor_job(self, job_id: int, model_name: str): + job = self.client.jobs.get(job_id) + + # Track progress + if job.status == 'Running': + latest_block = self._get_latest_block(model_name) + self.tracker.update_progress(model_name, latest_block) + + return job +``` + +**Job States:** +- `Pending` → Job queued +- `Running` → Actively processing +- `Completed` → Successfully finished +- `Failed` → Error occurred +- `Stopped` → Manually stopped + +### 5.2 Data Progress Tracking + +**State Database Schema:** +```sql +-- .amp-dbt/state.db +CREATE TABLE model_state ( + model_name TEXT PRIMARY KEY, + connection_name TEXT, + latest_block INTEGER, + latest_timestamp TIMESTAMP, + last_updated TIMESTAMP, + job_id INTEGER, + status TEXT -- 'fresh', 'stale', 'error' +); + +CREATE TABLE job_history ( + job_id INTEGER PRIMARY KEY, + model_name TEXT, + status TEXT, + started_at TIMESTAMP, + completed_at TIMESTAMP, + final_block INTEGER, + rows_processed INTEGER +); +``` + +**Tracking Methods:** + +1. **Query Destination Table** (if loading to database): + ```sql + SELECT MAX(block_num) FROM stg_erc20_transfers; + ``` + +2. **Job Metadata** (from Amp server): + ```python + job = client.jobs.get(job_id) + # Extract latest block from job metadata + ``` + +3. **Stream State** (for streaming queries): + ```python + # Use existing stream state tracking + checkpoint = get_checkpoint(connection_name, table_name) + latest_block = checkpoint.end_block + ``` + +### 5.3 Freshness Monitoring + +**Freshness Check:** +```python +class FreshnessMonitor: + def check_freshness(self, model_name: str) -> FreshnessResult: + latest = self.tracker.get_latest_timestamp(model_name) + if not latest: + return FreshnessResult(stale=True, reason="No data") + + age = datetime.now() - latest + threshold = self.config.get_alert_threshold(model_name) + + return FreshnessResult( + stale=age > threshold, + age=age, + latest_block=self.tracker.get_latest_block(model_name) + ) +``` + +**Alert Thresholds:** +- Per-model configuration +- Default: 30 minutes +- Configurable in `dbt_project.yml` + +--- + +## 6. CLI Design + +### 6.1 Core Commands + +```bash +# Compilation +amp-dbt compile # Compile all models +amp-dbt compile --select stg_* # Compile specific models +amp-dbt compile --show-sql # Show compiled SQL + +# Execution +amp-dbt run # Execute all models +amp-dbt run --select stg_* # Run specific models +amp-dbt run --register # Register as Amp datasets +amp-dbt run --register --deploy # Register and deploy +amp-dbt run --incremental # Use incremental strategy +amp-dbt run --full-refresh # Reset incremental state + +# Testing +amp-dbt test # Run all tests +amp-dbt test --select stg_* # Test specific models + +# Monitoring +amp-dbt monitor # Interactive dashboard +amp-dbt monitor --watch # Auto-refresh +amp-dbt status # Data freshness check +amp-dbt status --all # All models freshness + +# Job Management +amp-dbt jobs list # List all jobs +amp-dbt jobs status # Job details +amp-dbt jobs logs # View job logs +amp-dbt jobs stop # Stop running job +amp-dbt jobs resume # Resume failed job + +# Utilities +amp-dbt list # List all models +amp-dbt docs generate # Generate documentation +amp-dbt docs serve # Serve docs locally +amp-dbt clean # Clean compiled cache +``` + +### 6.2 Command Examples + +**Compile and Check:** +```bash +$ amp-dbt compile --select stg_erc20_transfers + +Compiling stg_erc20_transfers... +✅ Compiled successfully + +Dependencies: + - eth: _/eth_firehose@1.0.0 + +Compiled SQL (first 500 chars): +SELECT l.block_num, l.block_hash, ... +``` + +**Run with Monitoring:** +```bash +$ amp-dbt run --select stg_erc20_transfers --monitor + +Compiling stg_erc20_transfers... +✅ Compiled successfully + +Registering dataset... +✅ Registered: _/stg_erc20_transfers@1.0.0 + +Deploying... +✅ Job started: 12345 + +Monitoring job 12345... +[████████░░] 45% | Block: 18,500,000 | ETA: 5m +``` + +**Status Check:** +```bash +$ amp-dbt status + +┌─────────────────────────────────────────────────────────┐ +│ Model │ Status │ Latest Block │ Age │ +├─────────────────────────────────────────────────────────┤ +│ stg_erc20_transfers│ ✅ Fresh │ 18,500,000 │ 5 min │ +│ stg_erc721_transfers│ ⚠️ Stale │ 18,400,000 │ 2 hours │ +│ token_analytics │ ❌ Error │ - │ - │ +└─────────────────────────────────────────────────────────┘ +``` + +**Monitor Dashboard:** +```bash +$ amp-dbt monitor + +┌─────────────────────────────────────────────────────────┐ +│ Amp DBT Job Monitor (Refreshing every 5s) │ +├─────────────────────────────────────────────────────────┤ +│ Model │ Job ID │ Status │ Progress │ Block │ +├─────────────────────────────────────────────────────────┤ +│ stg_erc20_transfers│ 12345 │ Running │ 45% │ 18.5M │ +│ stg_erc721_transfers│ 12346 │ Completed │ 100% │ 18.5M │ +│ token_analytics │ 12347 │ Failed │ 0% │ - │ +└─────────────────────────────────────────────────────────┘ + +Press 'q' to quit, 'r' to refresh +``` + +--- + +## 7. Implementation Phases + +### Phase 1: Core Compilation (MVP) +**Goal:** Get basic query compilation working + +**Features:** +- ✅ Project initialization (`amp-dbt init`) +- ✅ Model loading and parsing +- ✅ Jinja templating support +- ✅ `ref()` resolution (external datasets only) +- ✅ Basic config parsing +- ✅ `amp-dbt compile` command + +**Deliverables:** +- Compiler engine +- Project structure +- Basic CLI + +### Phase 2: Dependency Resolution +**Goal:** Full dependency graph support + +**Features:** +- ✅ Internal model `ref()` resolution (CTE inlining) +- ✅ Dependency graph building +- ✅ Topological sort for execution order +- ✅ Circular dependency detection +- ✅ `amp-dbt run` command + +**Deliverables:** +- Dependency resolver +- Execution orchestrator +- Error handling + +### Phase 3: Monitoring & Tracking +**Goal:** Job monitoring and data tracking + +**Features:** +- ✅ Job status tracking +- ✅ State database (SQLite) +- ✅ Latest block/timestamp tracking +- ✅ Freshness monitoring +- ✅ `amp-dbt monitor` dashboard +- ✅ `amp-dbt status` command + +**Deliverables:** +- Monitoring system +- State tracker +- Dashboard UI + +### Phase 4: Advanced Features +**Goal:** Production-ready features + +**Features:** +- ✅ Macros system +- ✅ Testing framework +- ✅ Documentation generation +- ✅ Incremental query support +- ✅ Dataset registration automation +- ✅ Alerts/notifications +- ✅ Performance metrics + +**Deliverables:** +- Complete feature set +- Production-ready tool + +--- + +## 8. Technical Specifications + +### 8.1 Core Classes + +```python +# amp_dbt/core.py +class AmpDbtProject: + """Main project class""" + def compile_model(self, model_name: str) -> CompiledModel + def compile_all(self) -> Dict[str, CompiledModel] + def build_dag(self) -> Dict[str, List[str]] + def resolve_dependencies(self, model: str) -> Dict[str, str] + +class Compiler: + """Query compilation engine""" + def compile(self, sql: str, context: dict) -> str + def resolve_ref(self, ref_name: str) -> str + def apply_macros(self, sql: str) -> str + +class Executor: + """Query execution""" + def execute(self, model: CompiledModel) -> ExecutionResult + def register_dataset(self, model: CompiledModel) -> int + def deploy_dataset(self, namespace: str, name: str, version: str) -> int + +class JobMonitor: + """Job monitoring""" + def monitor_job(self, job_id: int, model_name: str) -> JobInfo + def monitor_all(self) -> List[JobInfo] + +class ModelTracker: + """Data progress tracking""" + def get_latest_block(self, model_name: str) -> Optional[int] + def update_progress(self, model_name: str, block_num: int) + def check_freshness(self, model_name: str) -> FreshnessResult +``` + +### 8.2 State Management + +**State Database:** +- SQLite database in `.amp-dbt/state.db` +- Tables: `model_state`, `job_history`, `compiled_cache` + +**Jobs Mapping:** +- JSON file: `.amp-dbt/jobs.json` +- Maps model names to active job IDs + +**Compiled Cache:** +- Directory: `.amp-dbt/compiled/` +- Stores compiled SQL for faster recompilation + +### 8.3 Error Handling + +**Error Types:** +- `CompilationError`: SQL compilation failed +- `DependencyError`: Missing or circular dependency +- `ExecutionError`: Query execution failed +- `JobError`: Job monitoring error +- `ConfigError`: Invalid configuration + +**Error Messages:** +- Clear, actionable error messages +- Suggestions for fixes +- Link to documentation + +--- + +## 9. Example Workflows + +### 9.1 New Project Setup + +```bash +# 1. Initialize project +amp-dbt init my-project +cd my-project + +# 2. Configure profiles.yml +# Edit profiles.yml with Amp server URLs + +# 3. Create first model +# models/staging/stg_erc20_transfers.sql +{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT * FROM {{ ref('eth') }}.logs WHERE ... + +# 4. Compile and test +amp-dbt compile +amp-dbt test + +# 5. Run +amp-dbt run --select stg_erc20_transfers +``` + +### 9.2 Daily Operations + +```bash +# Morning: Check status +amp-dbt status + +# Run incremental updates +amp-dbt run --incremental + +# Monitor progress +amp-dbt monitor --watch + +# Evening: Check for stale data +amp-dbt alerts +``` + +### 9.3 Debugging + +```bash +# See compiled SQL +amp-dbt compile --select stg_erc20_transfers --show-sql + +# Check dependencies +amp-dbt list --select stg_erc20_transfers + +# View job logs +amp-dbt jobs logs + +# Test specific model +amp-dbt test --select stg_erc20_transfers +``` + +--- + +## 10. Success Metrics + +### 10.1 Developer Experience +- ✅ Time to create new model: < 5 minutes +- ✅ Compilation time: < 1 second per model +- ✅ Clear error messages: 100% actionable +- ✅ Documentation coverage: 100% + +### 10.2 Reliability +- ✅ Compilation success rate: > 99% +- ✅ Job monitoring accuracy: 100% +- ✅ Freshness detection accuracy: > 95% + +### 10.3 Performance +- ✅ State database queries: < 10ms +- ✅ Dashboard refresh: < 1 second +- ✅ CLI command response: < 2 seconds + +--- + +## 11. Future Enhancements + +### 11.1 Advanced Features +- Query performance optimization hints +- Automatic query rewriting for optimization +- Multi-environment support (dev/staging/prod) +- CI/CD integration +- Slack/email notifications + +### 11.2 Integration +- dbt Cloud integration +- Airflow/Dagster integration +- Data quality frameworks (Great Expectations) +- BI tool connectors + +--- + +## 12. Appendix + +### 12.1 File Format Examples + +**Model File:** +```sql +-- models/staging/stg_erc20_transfers.sql +{{ config( + dependencies={'eth': '_/eth_firehose@1.0.0'}, + track_progress=true, + track_column='block_num', + description='Decoded ERC20 Transfer events' +) }} + +SELECT ... +FROM {{ ref('eth') }}.logs +WHERE ... +``` + +**Macro File:** +```sql +-- macros/evm_decode.sql +{% macro evm_decode(topic1, topic2, topic3, data, signature) %} + evm_decode({{ topic1 }}, {{ topic2 }}, {{ topic3 }}, {{ data }}, '{{ signature }}') +{% endmacro %} +``` + +**Test File:** +```sql +-- tests/assert_not_null.sql +SELECT COUNT(*) as null_count +FROM {{ ref('stg_erc20_transfers') }} +WHERE token_address IS NULL +-- Fails if null_count > 0 +``` + +### 12.2 Configuration Reference + +**dbt_project.yml:** +```yaml +name: 'my_project' +version: '1.0.0' + +models: + staging: + +dependencies: + eth: '_/eth_firehose@1.0.0' + +track_progress: true + +track_column: 'block_num' + + marts: + +register: true + +deploy: false + +monitoring: + alert_threshold_minutes: 30 + check_interval_seconds: 60 +``` + +--- + +## Summary + +Amp DBT provides: +1. **Query Composition**: DBT-like SQL organization with dependency management +2. **Monitoring**: Real-time job tracking and data freshness monitoring +3. **Developer Experience**: Simple CLI, clear errors, fast iteration +4. **Production Ready**: State tracking, error handling, comprehensive tooling + +The design balances familiarity (DBT patterns) with practicality (Amp constraints) to create a tool that makes working with Amp queries easy and reliable. + diff --git a/USER_WALKTHROUGH.md b/USER_WALKTHROUGH.md new file mode 100644 index 0000000..500fe2a --- /dev/null +++ b/USER_WALKTHROUGH.md @@ -0,0 +1,126 @@ +# Complete User Walkthrough: Amp DBT from Start to Finish + +## Step-by-Step Guide + +### Step 1: Initialize Project + +```bash +# Create directory +mkdir -p /tmp/my-amp-project + +# Initialize (from project root) +cd /Users/vivianpeng/Work/amp-python +uv run python -m amp.dbt.cli init --project-dir /tmp/my-amp-project +``` + +**Creates:** models/, macros/, tests/, docs/, dbt_project.yml + +--- + +### Step 2: Create Models + +#### Model 1: Staging (External Dependency) + +```bash +cd /tmp/my-amp-project +mkdir -p models/staging + +cat > models/staging/stg_erc20.sql << 'EOF' +{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num, tx_hash FROM {{ ref('eth') }}.logs LIMIT 100 +EOF +``` + +#### Model 2: Intermediate (Internal Dependency) + +```bash +mkdir -p models/intermediate + +cat > models/intermediate/int_stats.sql << 'EOF' +SELECT COUNT(*) as count FROM {{ ref('stg_erc20') }} +EOF +``` + +#### Model 3: Marts (Internal Dependency) + +```bash +mkdir -p models/marts + +cat > models/marts/analytics.sql << 'EOF' +SELECT * FROM {{ ref('int_stats') }} ORDER BY count DESC LIMIT 10 +EOF +``` + +--- + +### Step 3: Test/Compile + +```bash +cd /Users/vivianpeng/Work/amp-python + +# List models +uv run python -m amp.dbt.cli list --project-dir /tmp/my-amp-project + +# Compile +uv run python -m amp.dbt.cli compile --project-dir /tmp/my-amp-project + +# See compiled SQL +uv run python -m amp.dbt.cli compile --show-sql --project-dir /tmp/my-amp-project +``` + +**Shows:** Dependencies, execution order, compiled SQL with CTEs + +--- + +### Step 4: Run + +```bash +# Dry run (see plan) +uv run python -m amp.dbt.cli run --dry-run --project-dir /tmp/my-amp-project + +# Actually run +uv run python -m amp.dbt.cli run --project-dir /tmp/my-amp-project +``` + +**Creates:** State tracking in .amp-dbt/state.db + +--- + +### Step 5: Monitor + +```bash +# Check status +uv run python -m amp.dbt.cli status --project-dir /tmp/my-amp-project + +# Monitor dashboard +uv run python -m amp.dbt.cli monitor --project-dir /tmp/my-amp-project + +# Auto-refresh +uv run python -m amp.dbt.cli monitor --watch --project-dir /tmp/my-amp-project +``` + +--- + +## Quick Command Reference + +All commands run from: `/Users/vivianpeng/Work/amp-python` + +```bash +# Initialize +uv run python -m amp.dbt.cli init --project-dir /tmp/my-amp-project + +# List +uv run python -m amp.dbt.cli list --project-dir /tmp/my-amp-project + +# Compile +uv run python -m amp.dbt.cli compile --project-dir /tmp/my-amp-project +uv run python -m amp.dbt.cli compile --show-sql --project-dir /tmp/my-amp-project + +# Run +uv run python -m amp.dbt.cli run --dry-run --project-dir /tmp/my-amp-project +uv run python -m amp.dbt.cli run --project-dir /tmp/my-amp-project + +# Monitor +uv run python -m amp.dbt.cli status --project-dir /tmp/my-amp-project +uv run python -m amp.dbt.cli monitor --project-dir /tmp/my-amp-project +``` diff --git a/apps/execute_query.py b/apps/execute_query.py index d73d459..f23ed90 100644 --- a/apps/execute_query.py +++ b/apps/execute_query.py @@ -2,7 +2,23 @@ from amp.client import Client -client = Client('grpc://127.0.0.1:80') +# Replace with your remote server URL +# Format: grpc://hostname:port or grpc+tls://hostname:port for TLS +SERVER_URL = "grpc://34.27.238.174:80" + +# Option 1: No authentication (if server doesn't require it) +# client = Client(url=SERVER_URL) + +# Option 2: Use explicit auth token +# client = Client(url=SERVER_URL, auth_token='your-token-here') + +# Option 3: Use environment variable AMP_AUTH_TOKEN +# export AMP_AUTH_TOKEN="your-token-here" +# client = Client(url=SERVER_URL) + +# Option 4: Use auto-refreshing auth from shared auth file (recommended) +# Uses ~/.amp/cache/amp_cli_auth (shared with TypeScript CLI) +client = Client(url=SERVER_URL, auth=True) df = client.get_sql('select * from eth_firehose.logs limit 1', read_all=True).to_pandas() print(df) diff --git a/models/example_model.sql b/models/example_model.sql new file mode 100644 index 0000000..76f35c3 --- /dev/null +++ b/models/example_model.sql @@ -0,0 +1,12 @@ +-- Example model +{{ config( + dependencies={'eth': '_/eth_firehose@1.0.0'}, + description='Example model showing how to use ref()' +) }} + +SELECT + block_num, + block_hash, + timestamp +FROM {{ ref('eth') }}.blocks +LIMIT 10 diff --git a/models/staging/stg_test.sql b/models/staging/stg_test.sql new file mode 100644 index 0000000..255546d --- /dev/null +++ b/models/staging/stg_test.sql @@ -0,0 +1,6 @@ +{{ config( + dependencies={'eth': '_/eth_firehose@1.0.0'}, + track_progress=True, + track_column='block_num' +) }} +SELECT block_num, tx_hash FROM {{ ref('eth') }}.logs LIMIT 10 diff --git a/pyproject.toml b/pyproject.toml index 258bc24..8ed64a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,15 @@ dependencies = [ # Admin API client support "httpx>=0.27.0", "pydantic>=2.0,<2.12", # Constrained for PyIceberg compatibility + # DBT support + "jinja2>=3.1.0", + "pyyaml>=6.0", + "rich>=13.0.0", ] +[project.scripts] +amp-dbt = "amp.dbt.cli:main" + [dependency-groups] dev = [ "altair>=5.5.0", # Data visualization for notebooks diff --git a/src/amp/dbt/README.md b/src/amp/dbt/README.md new file mode 100644 index 0000000..0580cb4 --- /dev/null +++ b/src/amp/dbt/README.md @@ -0,0 +1,113 @@ +# Amp DBT - Phase 1 Implementation + +## Overview + +Phase 1 implements the core compilation engine for Amp DBT, providing basic query composition with Jinja templating and external dataset reference resolution. + +## Features Implemented + +### ✅ Project Initialization +- `amp-dbt init` command creates a new DBT project structure +- Generates `dbt_project.yml`, directory structure, and example model + +### ✅ Model Loading and Parsing +- Loads SQL model files from `models/` directory +- Parses `{{ config() }}` blocks from model SQL +- Extracts configuration (dependencies, track_progress, etc.) + +### ✅ Jinja Templating Support +- Full Jinja2 template rendering +- Custom `ref()` function for dependency resolution +- Support for variables and macros (basic) + +### ✅ ref() Resolution (External Datasets Only) +- Resolves `{{ ref('eth') }}` to dataset references like `_/eth_firehose@1.0.0` +- Validates dependencies are defined in config +- Replaces ref() calls in compiled SQL + +### ✅ Basic Config Parsing +- Parses `{{ config() }}` blocks from model SQL +- Loads `dbt_project.yml` (optional) +- Supports dependencies, track_progress, register, deploy flags + +### ✅ CLI Commands +- `amp-dbt init` - Initialize new project +- `amp-dbt compile` - Compile models +- `amp-dbt list` - List all models + +## Usage + +### Initialize a Project + +```bash +amp-dbt init my-project +cd my-project +``` + +### Create a Model + +Create `models/staging/stg_erc20_transfers.sql`: + +```sql +{{ config( + dependencies={'eth': '_/eth_firehose@1.0.0'}, + track_progress=true, + track_column='block_num', + description='Decoded ERC20 Transfer events' +) }} + +SELECT + l.block_num, + l.block_hash, + l.timestamp, + l.tx_hash, + l.address as token_address +FROM {{ ref('eth') }}.logs l +WHERE + l.topic0 = evm_topic('Transfer(address indexed from, address indexed to, uint256 value)') + AND l.topic3 IS NULL +``` + +### Compile Models + +```bash +# Compile all models +amp-dbt compile + +# Compile specific models +amp-dbt compile --select stg_* + +# Show compiled SQL +amp-dbt compile --show-sql +``` + +## Project Structure + +``` +my-project/ +├── dbt_project.yml # Project configuration +├── models/ # SQL model files +│ ├── staging/ +│ │ └── stg_erc20_transfers.sql +│ └── marts/ +│ └── token_analytics.sql +├── macros/ # Reusable SQL macros (future) +├── tests/ # Data quality tests (future) +└── docs/ # Documentation (future) +``` + +## Limitations (Phase 1) + +- ❌ Internal model references (model-to-model dependencies) not supported +- ❌ Macros system not fully implemented +- ❌ No execution (`amp-dbt run`) - compilation only +- ❌ No monitoring or tracking +- ❌ No testing framework + +## Next Steps (Phase 2) + +- Internal model dependency resolution (CTE inlining) +- Dependency graph building +- Topological sort for execution order +- `amp-dbt run` command for execution + diff --git a/src/amp/dbt/__init__.py b/src/amp/dbt/__init__.py new file mode 100644 index 0000000..0c56f0c --- /dev/null +++ b/src/amp/dbt/__init__.py @@ -0,0 +1,25 @@ +"""Amp DBT - Query composition and orchestration framework for Amp.""" + +from amp.dbt.project import AmpDbtProject +from amp.dbt.compiler import Compiler +from amp.dbt.models import CompiledModel, ModelConfig +from amp.dbt.exceptions import ( + AmpDbtError, + CompilationError, + ConfigError, + DependencyError, + ProjectNotFoundError, +) + +__all__ = [ + 'AmpDbtProject', + 'Compiler', + 'CompiledModel', + 'ModelConfig', + 'AmpDbtError', + 'CompilationError', + 'ConfigError', + 'DependencyError', + 'ProjectNotFoundError', +] + diff --git a/src/amp/dbt/cli.py b/src/amp/dbt/cli.py new file mode 100644 index 0000000..fcc9930 --- /dev/null +++ b/src/amp/dbt/cli.py @@ -0,0 +1,461 @@ +"""CLI for Amp DBT.""" + +import sys +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.table import Table + +from amp.dbt.exceptions import AmpDbtError, DependencyError, ProjectNotFoundError +from amp.dbt.monitor import JobMonitor +from amp.dbt.project import AmpDbtProject +from amp.dbt.tracker import FreshnessMonitor, ModelTracker + +app = typer.Typer(name='amp-dbt', help='Amp DBT - Query composition and orchestration framework') +console = Console() + + +@app.command() +def init( + project_name: Optional[str] = typer.Argument(None, help='Project name (default: current directory name)'), + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), +): + """Initialize a new Amp DBT project.""" + if project_dir is None: + project_dir = Path.cwd() + else: + project_dir = Path(project_dir).resolve() + + if project_name is None: + project_name = project_dir.name + + console.print(f'[bold green]Initializing Amp DBT project:[/bold green] {project_name}') + + # Create directory structure + models_dir = project_dir / 'models' + macros_dir = project_dir / 'macros' + tests_dir = project_dir / 'tests' + docs_dir = project_dir / 'docs' + + # Create parent directories if they don't exist + project_dir.mkdir(parents=True, exist_ok=True) + + models_dir.mkdir(parents=True, exist_ok=True) + macros_dir.mkdir(parents=True, exist_ok=True) + tests_dir.mkdir(parents=True, exist_ok=True) + docs_dir.mkdir(parents=True, exist_ok=True) + + # Create dbt_project.yml + project_config = { + 'name': project_name, + 'version': '1.0.0', + 'models': {}, + 'monitoring': { + 'alert_threshold_minutes': 30, + 'check_interval_seconds': 60, + }, + } + + import yaml + + config_path = project_dir / 'dbt_project.yml' + if not config_path.exists(): + with open(config_path, 'w') as f: + yaml.dump(project_config, f, default_flow_style=False, sort_keys=False) + console.print(f' [green]✓[/green] Created {config_path}') + else: + console.print(f' [yellow]⚠[/yellow] {config_path} already exists, skipping') + + # Create .gitignore + gitignore_path = project_dir / '.gitignore' + gitignore_content = """# Amp DBT +.amp-dbt/ +""" + if not gitignore_path.exists(): + gitignore_path.write_text(gitignore_content) + console.print(f' [green]✓[/green] Created {gitignore_path}') + + # Create example model + example_model_path = models_dir / 'example_model.sql' + if not example_model_path.exists(): + example_model = '''-- Example model +{{ config( + dependencies={'eth': '_/eth_firehose@1.0.0'}, + description='Example model showing how to use ref()' +) }} + +SELECT + block_num, + block_hash, + timestamp +FROM {{ ref('eth') }}.blocks +LIMIT 10 +''' + example_model_path.write_text(example_model) + console.print(f' [green]✓[/green] Created example model: {example_model_path}') + + console.print(f'\n[bold green]✓ Project initialized successfully![/bold green]') + console.print(f'\nNext steps:') + console.print(f' 1. Edit models in {models_dir}/') + console.print(f' 2. Run [bold]amp-dbt compile[/bold] to compile models') + console.print(f' 3. Run [bold]amp-dbt run[/bold] to execute models') + + +@app.command() +def compile( + select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to select models (e.g., stg_*)'), + show_sql: bool = typer.Option(False, '--show-sql', help='Show compiled SQL'), + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), +): + """Compile models.""" + try: + project = AmpDbtProject(project_dir) + + console.print('[bold]Compiling models...[/bold]\n') + + compiled_models = project.compile_all(select) + + if not compiled_models: + console.print('[yellow]No models found to compile[/yellow]') + return + + # Create results table + table = Table(title='Compilation Results') + table.add_column('Model', style='cyan') + table.add_column('Status', style='green') + table.add_column('Dependencies', style='yellow') + + for name, compiled in compiled_models.items(): + # Separate internal and external dependencies + internal_deps = [k for k, v in compiled.dependencies.items() if k == v] + external_deps = [f'{k}: {v}' for k, v in compiled.dependencies.items() if k != v] + + deps_parts = [] + if internal_deps: + deps_parts.append(f'Internal: {", ".join(internal_deps)}') + if external_deps: + deps_parts.append(f'External: {", ".join(external_deps)}') + + deps_str = ' | '.join(deps_parts) if deps_parts else 'None' + table.add_row(name, '✓ Compiled', deps_str) + + console.print(table) + + # Show execution order if there are internal dependencies + has_internal = any( + any(k == v for k, v in compiled.dependencies.items()) + for compiled in compiled_models.values() + ) + if has_internal: + try: + execution_order = project.get_execution_order(select) + console.print(f'\n[bold]Execution order:[/bold] {" → ".join(execution_order)}') + except Exception: + pass # Skip if we can't determine order + + if show_sql: + console.print('\n[bold]Compiled SQL:[/bold]\n') + for name, compiled in compiled_models.items(): + console.print(f'[bold cyan]{name}:[/bold cyan]') + console.print(f'[dim]{compiled.sql[:500]}{"..." if len(compiled.sql) > 500 else ""}[/dim]\n') + + except ProjectNotFoundError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + console.print('Run [bold]amp-dbt init[/bold] to initialize a project') + sys.exit(1) + except AmpDbtError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + sys.exit(1) + except Exception as e: + console.print(f'[bold red]Unexpected error:[/bold red] {e}') + sys.exit(1) + + +@app.command() +def list( + select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to filter models'), + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), +): + """List all models in the project.""" + try: + project = AmpDbtProject(project_dir) + + models = project.find_models(select) + + if not models: + console.print('[yellow]No models found[/yellow]') + return + + table = Table(title='Models') + table.add_column('Model', style='cyan') + table.add_column('Path', style='dim') + + for model_path in models: + relative_path = model_path.relative_to(project.project_root) + table.add_row(model_path.stem, str(relative_path)) + + console.print(table) + + except ProjectNotFoundError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + console.print('Run [bold]amp-dbt init[/bold] to initialize a project') + sys.exit(1) + except Exception as e: + console.print(f'[bold red]Error:[/bold red] {e}') + sys.exit(1) + + +@app.command() +def run( + select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to select models (e.g., stg_*)'), + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), + dry_run: bool = typer.Option(False, '--dry-run', help='Show what would be executed without running'), +): + """Execute models in dependency order.""" + try: + project = AmpDbtProject(project_dir) + + console.print('[bold]Compiling models...[/bold]\n') + + # Compile all models (with dependency resolution) + compiled_models = project.compile_all(select) + + if not compiled_models: + console.print('[yellow]No models found to execute[/yellow]') + return + + # Get execution order + execution_order = project.get_execution_order(select) + + if dry_run: + console.print('[bold yellow]Dry run mode - showing execution plan[/bold yellow]\n') + table = Table(title='Execution Plan') + table.add_column('Order', style='cyan') + table.add_column('Model', style='green') + table.add_column('Dependencies', style='yellow') + + for idx, model_name in enumerate(execution_order, 1): + compiled = compiled_models[model_name] + deps = [d for d in compiled.dependencies.keys() if d in compiled_models] + deps_str = ', '.join(deps) if deps else 'None' + table.add_row(str(idx), model_name, deps_str) + + console.print(table) + console.print(f'\n[dim]Would execute {len(execution_order)} models[/dim]') + return + + # Execute models in order + console.print(f'[bold]Executing {len(execution_order)} models...[/bold]\n') + + for idx, model_name in enumerate(execution_order, 1): + compiled = compiled_models[model_name] + console.print(f'[{idx}/{len(execution_order)}] [cyan]{model_name}[/cyan]') + + # In Phase 2, we just show the compiled SQL + # Actual execution will be added in later phases + console.print(f' [dim]Compiled SQL ({len(compiled.sql)} chars)[/dim]') + console.print(f' [green]✓[/green] Ready to execute\n') + + console.print('[bold green]✓ All models compiled successfully![/bold green]') + console.print('\n[dim]Note: Actual query execution will be implemented in Phase 3[/dim]') + + # Update state tracking (Phase 3) + tracker = ModelTracker(project.project_root) + for model_name in execution_order: + # Mark as ready (actual execution would update with real job_id and blocks) + tracker.update_progress(model_name, status='ready') + + except DependencyError as e: + console.print(f'[bold red]Dependency Error:[/bold red] {e}') + sys.exit(1) + except ProjectNotFoundError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + console.print('Run [bold]amp-dbt init[/bold] to initialize a project') + sys.exit(1) + except AmpDbtError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + sys.exit(1) + except Exception as e: + console.print(f'[bold red]Unexpected error:[/bold red] {e}') + import traceback + traceback.print_exc() + sys.exit(1) + + +@app.command() +def status( + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), + all: bool = typer.Option(False, '--all', help='Show status for all models'), +): + """Check data freshness status.""" + try: + project = AmpDbtProject(project_dir) + tracker = ModelTracker(project.project_root) + freshness_monitor = FreshnessMonitor(tracker) + + if all: + # Check all models + results = freshness_monitor.check_all_freshness() + else: + # Check models that have state + states = tracker.get_all_states() + results = {} + for model_name in states.keys(): + results[model_name] = freshness_monitor.check_freshness(model_name) + + if not results: + console.print('[yellow]No models with tracked state found[/yellow]') + console.print('Run [bold]amp-dbt run[/bold] to start tracking models') + return + + # Create status table + table = Table(title='Data Freshness Status') + table.add_column('Model', style='cyan') + table.add_column('Status', style='green') + table.add_column('Latest Block', style='yellow') + table.add_column('Age', style='dim') + + for model_name, result in results.items(): + if result.stale: + status_icon = '⚠️ Stale' + status_style = 'red' + elif result.latest_timestamp: + status_icon = '✅ Fresh' + status_style = 'green' + else: + status_icon = '❌ Error' + status_style = 'red' + + block_str = str(result.latest_block) if result.latest_block else '-' + age_str = str(result.age).split('.')[0] if result.age else '-' + + table.add_row(model_name, f'[{status_style}]{status_icon}[/{status_style}]', block_str, age_str) + + console.print(table) + + except ProjectNotFoundError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + console.print('Run [bold]amp-dbt init[/bold] to initialize a project') + sys.exit(1) + except Exception as e: + console.print(f'[bold red]Error:[/bold red] {e}') + import traceback + traceback.print_exc() + sys.exit(1) + + +@app.command() +def monitor( + project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'), + watch: bool = typer.Option(False, '--watch', help='Auto-refresh dashboard'), + interval: int = typer.Option(5, '--interval', help='Refresh interval in seconds (default: 5)'), +): + """Interactive job monitoring dashboard.""" + try: + project = AmpDbtProject(project_dir) + tracker = ModelTracker(project.project_root) + + # Load job mappings from .amp-dbt/jobs.json if it exists + jobs_file = project.project_root / '.amp-dbt' / 'jobs.json' + model_job_map = {} + + if jobs_file.exists(): + import json + + with open(jobs_file, 'r') as f: + model_job_map = json.load(f) + + console.print('[bold]Job Monitor[/bold]\n') + + if watch: + console.print(f'[dim]Refreshing every {interval} seconds. Press Ctrl+C to stop.[/dim]\n') + + import time + + try: + while True: + # Clear screen (works on most terminals) + if watch: + console.print('\033[2J\033[H', end='') + + # Get all model states + states = tracker.get_all_states() + + if not states: + console.print('[yellow]No tracked models found[/yellow]') + console.print('Run [bold]amp-dbt run[/bold] to start tracking models') + break + + # Create monitor table + table = Table(title='Amp DBT Job Monitor') + table.add_column('Model', style='cyan') + table.add_column('Job ID', style='yellow') + table.add_column('Status', style='green') + table.add_column('Latest Block', style='dim') + table.add_column('Last Updated', style='dim') + + for model_name, state in states.items(): + job_id_str = str(state.job_id) if state.job_id else '-' + block_str = str(state.latest_block) if state.latest_block else '-' + updated_str = ( + state.last_updated.strftime('%Y-%m-%d %H:%M:%S') + if state.last_updated + else '-' + ) + + # Status styling + if state.status == 'fresh': + status_icon = '✅ Fresh' + status_style = 'green' + elif state.status == 'stale': + status_icon = '⚠️ Stale' + status_style = 'yellow' + elif state.status == 'error': + status_icon = '❌ Error' + status_style = 'red' + elif state.status == 'running': + status_icon = '🔄 Running' + status_style = 'cyan' + else: + status_icon = '❓ Unknown' + status_style = 'dim' + + table.add_row( + model_name, + job_id_str, + f'[{status_style}]{status_icon}[/{status_style}]', + block_str, + updated_str, + ) + + console.print(table) + + if not watch: + break + + time.sleep(interval) + + except KeyboardInterrupt: + console.print('\n[yellow]Monitoring stopped[/yellow]') + + except ProjectNotFoundError as e: + console.print(f'[bold red]Error:[/bold red] {e}') + console.print('Run [bold]amp-dbt init[/bold] to initialize a project') + sys.exit(1) + except Exception as e: + console.print(f'[bold red]Error:[/bold red] {e}') + import traceback + traceback.print_exc() + sys.exit(1) + + +def main(): + """Entry point for CLI.""" + app() + + +if __name__ == '__main__': + main() + diff --git a/src/amp/dbt/compiler.py b/src/amp/dbt/compiler.py new file mode 100644 index 0000000..122d0a6 --- /dev/null +++ b/src/amp/dbt/compiler.py @@ -0,0 +1,502 @@ +"""SQL compilation engine for Amp DBT.""" + +import re +from pathlib import Path +from typing import Dict, Optional, Set + +from jinja2 import Environment, FileSystemLoader, Template, select_autoescape + +from amp.dbt.exceptions import CompilationError, DependencyError +from amp.dbt.models import CompiledModel, ModelConfig + + +class Compiler: + """Query compilation engine with Jinja templating support.""" + + def __init__(self, project_root: Path, macros_dir: Optional[Path] = None): + """Initialize compiler. + + Args: + project_root: Root directory of the DBT project + macros_dir: Optional directory containing macro files + """ + self.project_root = project_root + self.macros_dir = macros_dir or project_root / 'macros' + + # Set up Jinja environment + self.env = Environment( + loader=FileSystemLoader([str(project_root), str(self.macros_dir)]), + autoescape=select_autoescape(['html', 'xml']), + trim_blocks=True, + lstrip_blocks=True, + ) + + # Add custom functions to Jinja context + self.env.globals['ref'] = self._ref_function + + def compile( + self, + sql: str, + model_name: str, + config: ModelConfig, + available_models: Optional[Set[str]] = None, + ) -> CompiledModel: + """Compile a model SQL with Jinja templating. + + Args: + sql: Raw SQL with Jinja templates + model_name: Name of the model + config: Model configuration + available_models: Optional set of available internal model names + + Returns: + CompiledModel with compiled SQL + + Raises: + CompilationError: If compilation fails + """ + try: + # Parse config block if present + template = self.env.from_string(sql) + + # Build context for Jinja rendering + context = { + 'config': config, + 'model_name': model_name, + } + + # Render template + compiled_sql = template.render(**context) + + # Resolve dependencies (now supports both internal and external) + dependencies = self._resolve_dependencies(compiled_sql, config, available_models) + + # Replace ref() placeholders with actual references + final_sql = self._replace_refs(compiled_sql, dependencies) + + return CompiledModel( + name=model_name, + sql=final_sql, + config=config, + dependencies=dependencies, + raw_sql=sql, + ) + + except Exception as e: + raise CompilationError(f'Failed to compile model {model_name}: {e}') from e + + def _ref_function(self, name: str, *args, **kwargs) -> str: + """Jinja function for resolving ref() calls. + + Args: + name: Reference name (e.g., 'eth' or model name) + *args: Additional arguments (not used for external datasets) + **kwargs: Additional keyword arguments (not used) + + Returns: + SQL fragment for the reference (dataset reference format) + """ + # This will be resolved later based on config.dependencies + # For now, return a placeholder that we'll replace + # The placeholder includes the ref name so we can resolve it later + return f'__REF__{name}__' + + def _resolve_dependencies( + self, sql: str, config: ModelConfig, available_models: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Resolve ref() calls in compiled SQL. + + Args: + sql: Compiled SQL with ref() placeholders + config: Model configuration with dependencies + available_models: Optional set of available internal model names + + Returns: + Dictionary mapping ref names to either dataset references (external) or model names (internal) + """ + dependencies = {} + available_models = available_models or set() + + # Find all ref() calls in the SQL + ref_pattern = r'__REF__(\w+)__' + matches = re.findall(ref_pattern, sql) + + for ref_name in matches: + if ref_name in available_models: + # Internal model reference - store as model name + dependencies[ref_name] = ref_name + elif ref_name in config.dependencies: + # External dataset reference + dataset_ref = config.dependencies[ref_name] + dependencies[ref_name] = dataset_ref + else: + # Unknown reference + raise DependencyError( + f'Unknown reference "{ref_name}". ' + f'It must be either an internal model or defined in config.dependencies.' + ) + + return dependencies + + def _replace_refs( + self, sql: str, dependencies: Dict[str, str], model_sql_map: Optional[Dict[str, str]] = None + ) -> str: + """Replace ref() placeholders with actual references. + + For external dependencies: replaces with dataset reference + For internal dependencies: replaces with CTE name (if model_sql_map provided) or model name (first pass) + + Args: + sql: SQL with ref() placeholders + dependencies: Dictionary mapping ref names to references (dataset refs or model names) + model_sql_map: Optional dictionary mapping model names to their compiled SQL (for CTE inlining) + + Returns: + SQL with ref() calls replaced + """ + result = sql + model_sql_map = model_sql_map or {} + + for ref_name, ref_value in dependencies.items(): + placeholder = f'__REF__{ref_name}__' + + if ref_name == ref_value: + # Internal model reference + if ref_name in model_sql_map: + # Second pass: CTE will be inlined, just replace with model name + result = result.replace(placeholder, ref_name) + else: + # First pass: just replace with model name (CTE will be added later) + # This allows dependency discovery without requiring CTE SQL + result = result.replace(placeholder, ref_name) + else: + # External dataset reference + # Use alias format (e.g., 'arb_firehose') instead of full reference + # The full reference will be tracked in dependencies for with_dependency() + result = result.replace(placeholder, ref_name) + + return result + + def compile_with_ctes( + self, + sql: str, + model_name: str, + config: ModelConfig, + internal_deps: Dict[str, str], + available_models: Optional[Set[str]] = None, + ) -> CompiledModel: + """Compile a model with CTE inlining for internal dependencies. + + Args: + sql: Raw SQL with Jinja templates + model_name: Name of the model + config: Model configuration + internal_deps: Dictionary mapping internal model names to their compiled SQL + available_models: Optional set of available internal model names + + Returns: + CompiledModel with compiled SQL including CTEs + + Raises: + CompilationError: If compilation fails + """ + try: + template = self.env.from_string(sql) + context = {'config': config, 'model_name': model_name} + compiled_sql = template.render(**context) + + # Resolve dependencies + dependencies = self._resolve_dependencies(compiled_sql, config, available_models) + + # Replace ref() placeholders first (before extracting CTEs) + final_sql = self._replace_refs(compiled_sql, dependencies, internal_deps) + + # Build CTE section for internal dependencies + # Flatten all CTEs from dependencies into a single WITH clause + # Also check if final_sql already has a WITH clause and merge it + all_ctes = {} # Map of CTE name -> SQL + + if internal_deps: + for dep_name, dep_sql in internal_deps.items(): + # Extract CTEs and the final SELECT from dependency SQL + ctes_from_dep, select_part = self._extract_ctes_and_select(dep_sql) + # Add any CTEs from this dependency + all_ctes.update(ctes_from_dep) + # Add this dependency as a CTE using just its SELECT part + all_ctes[dep_name] = select_part + + # Check if final_sql already has a WITH clause (from the model's own SQL) + final_sql_upper = final_sql.upper().strip() + if final_sql_upper.startswith('WITH '): + # Extract CTEs from the model's own SQL + model_ctes, model_select = self._extract_ctes_and_select(final_sql) + # Merge with dependency CTEs + all_ctes.update(model_ctes) + # Use the model's SELECT as the final SQL + final_sql = model_select + + # Build final CTE section if we have any CTEs + if all_ctes: + cte_parts = [] + for cte_name, cte_sql in all_ctes.items(): + cte_parts.append(f'{cte_name} AS (\n{cte_sql}\n)') + cte_section = 'WITH ' + ',\n'.join(cte_parts) + '\n' + final_sql = cte_section + final_sql + + return CompiledModel( + name=model_name, + sql=final_sql, + config=config, + dependencies=dependencies, + raw_sql=sql, + ) + + except Exception as e: + raise CompilationError(f'Failed to compile model {model_name} with CTEs: {e}') from e + + def _extract_ctes_and_select(self, sql: str) -> tuple[Dict[str, str], str]: + """Extract CTEs and final SELECT from SQL that may contain WITH clauses. + + Returns: + Tuple of (dict of CTE name -> SQL, final SELECT statement) + """ + import re + + sql_upper = sql.upper().strip() + + # If no WITH clause, return empty CTEs and the SQL as SELECT + if 'WITH ' not in sql_upper: + return {}, sql.strip() + + ctes = {} + + # Use regex to find all CTE definitions: name AS (content) + # Pattern: word AS ( ... ) optionally followed by comma + # We need to handle nested parentheses correctly + with_match = re.search(r'^WITH\s+', sql_upper) + if not with_match: + return {}, sql.strip() + + after_with_start = with_match.end() + after_with = sql[after_with_start:] + + # Find where WITH clause ends (the SELECT that's not inside parentheses) + # Count parentheses to find the SELECT after all CTEs + paren_count = 0 + in_string = False + string_char = None + i = 0 + + while i < len(after_with): + char = after_with[i] + prev_char = after_with[i-1] if i > 0 else '' + + # Track string literals + if char in ("'", '"') and prev_char != '\\': + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + i += 1 + continue + + if in_string: + i += 1 + continue + + # Track parentheses + if char == '(': + paren_count += 1 + elif char == ')': + paren_count -= 1 + # When we're back to 0 or negative, check if SELECT follows + if paren_count <= 0: + remaining = after_with[i+1:].strip() + if remaining.upper().startswith('SELECT'): + # Found the final SELECT - extract CTEs from before this point + with_clause_part = after_with[:i+1] + ctes = self._parse_ctes_from_with('WITH ' + with_clause_part) + select_part = remaining + return ctes, select_part + + i += 1 + + # Fallback: use regex-based extraction + return self._simple_extract_ctes(sql) + + def _simple_extract_ctes(self, sql: str) -> tuple[Dict[str, str], str]: + """Simple fallback to extract CTEs using regex.""" + import re + + # Find WHERE WITH ends and SELECT begins + # Pattern: WITH ... ) SELECT (the SELECT after closing paren of last CTE) + # We want to find the SELECT that comes after the WITH clause ends + match = re.search(r'\)\s+(SELECT\s+)', sql, re.IGNORECASE | re.DOTALL) + if match: + # Find the position of SELECT (group 1 start) + select_start = match.start(1) + return {}, sql[select_start:].strip() + + # Alternative: look for SELECT after WITH + match2 = re.search(r'WITH\s+.*?(SELECT\s+)', sql, re.IGNORECASE | re.DOTALL) + if match2: + select_start = match2.start(1) + return {}, sql[select_start:].strip() + + return {}, sql.strip() + + def _parse_ctes_from_with(self, with_clause: str) -> Dict[str, str]: + """Parse CTE definitions from a WITH clause. + + Args: + with_clause: SQL starting with WITH and ending with closing paren + + Returns: + Dictionary mapping CTE names to their SQL + """ + import re + ctes = {} + + # Remove the leading "WITH " + content = with_clause[5:].strip() if with_clause.upper().startswith('WITH ') else with_clause + + # Use regex to find CTE patterns: name AS (content) + # Handle nested parentheses by matching balanced parens + # Pattern: (\w+) AS \( ... \) + i = 0 + while i < len(content): + # Find next " AS (" + as_match = re.search(r'\s+AS\s+\(', content[i:], re.IGNORECASE) + if not as_match: + break + + as_pos = i + as_match.start() + + # Find CTE name before " AS (" + name_end = as_pos + name_start = name_end + while name_start > 0 and (content[name_start-1].isalnum() or content[name_start-1] == '_'): + name_start -= 1 + cte_name = content[name_start:name_end].strip() + + if not cte_name: + i = as_pos + as_match.end() + continue + + # Find matching closing paren + # as_match positions are relative to content[i:], so add i to get absolute position + paren_start = i + as_match.end() - 1 # Position of ( + paren_count = 1 + in_string = False + string_char = None + j = paren_start + 1 + + while j < len(content) and paren_count > 0: + char = content[j] + prev_char = content[j-1] if j > 0 else '' + + # Track string literals + if char in ("'", '"') and prev_char != '\\': + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + j += 1 + continue + + if in_string: + j += 1 + continue + + if char == '(': + paren_count += 1 + elif char == ')': + paren_count -= 1 + + j += 1 + + if paren_count == 0: + # Found matching closing paren + cte_sql = content[paren_start + 1:j - 1].strip() + ctes[cte_name] = cte_sql + i = j # Continue after this CTE + else: + break + + return ctes + + def _extract_select_from_sql(self, sql: str) -> str: + """Extract the final SELECT statement from SQL that may contain CTEs. + + If the SQL has WITH clauses, returns just the SELECT part after all WITHs. + Otherwise returns the SQL as-is. + + Args: + sql: SQL that may contain WITH clauses + + Returns: + SQL with only the final SELECT statement (CTEs removed) + """ + sql_upper = sql.upper().strip() + + # If no WITH clause, return as-is + if 'WITH ' not in sql_upper: + return sql.strip() + + # Find the final SELECT that comes after all CTE definitions + # Pattern: WITH cte1 AS (...), cte2 AS (...) SELECT ... + # We need to find the SELECT that's not inside a CTE definition + + # Simple approach: find the last closing paren that's followed by SELECT + # This indicates the end of the CTE list + paren_count = 0 + in_string = False + string_char = None + + for i in range(len(sql)): + char = sql[i] + prev_char = sql[i-1] if i > 0 else '' + + # Track string literals + if char in ("'", '"') and prev_char != '\\': + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None + continue + + if in_string: + continue + + # Track parentheses + if char == '(': + paren_count += 1 + elif char == ')': + paren_count -= 1 + # Check if this closes the last CTE (paren_count back to 0 or negative) + if paren_count <= 0: + # Look ahead for SELECT (skip whitespace and newlines) + remaining = sql[i+1:].strip() + if remaining.upper().startswith('SELECT'): + return remaining + # Also check if there's a comma then more CTEs, then SELECT + if remaining.startswith(',') or remaining.startswith('\n'): + # Continue searching + continue + + # Fallback: if parsing failed, try regex to find SELECT after WITH + import re + # Match: WITH ... SELECT (capture the SELECT part) + match = re.search(r'WITH\s+.*?\)\s+SELECT\s+', sql, re.IGNORECASE | re.DOTALL) + if match: + return sql[match.end() - len('SELECT '):].strip() + + # Last resort: return as-is (better than failing) + return sql.strip() + diff --git a/src/amp/dbt/config.py b/src/amp/dbt/config.py new file mode 100644 index 0000000..25e9136 --- /dev/null +++ b/src/amp/dbt/config.py @@ -0,0 +1,126 @@ +"""Configuration parsing for Amp DBT.""" + +import re +from pathlib import Path +from typing import Dict, Optional + +import yaml + +from amp.dbt.exceptions import ConfigError +from amp.dbt.models import ModelConfig + + +def parse_config_block(sql: str) -> tuple[str, ModelConfig]: + """Parse config block from model SQL. + + Args: + sql: Model SQL with optional config block + + Returns: + Tuple of (sql_without_config, ModelConfig) + + Raises: + ConfigError: If config block is invalid + """ + # Pattern to match {{ config(...) }} block + # Need to handle nested braces in the config dict + config_pattern = r'\{\{\s*config\s*\((.*?)\)\s*\}\}' + + match = re.search(config_pattern, sql, re.DOTALL) + if not match: + # No config block, return default config + return sql, ModelConfig() + + config_str = match.group(1) + sql_without_config = sql[: match.start()] + sql[match.end() :] + + # Parse config as Python dict-like syntax + # Simple parser for key=value pairs + config = ModelConfig() + + # Extract dependencies + deps_match = re.search(r"dependencies\s*=\s*\{([^}]+)\}", config_str, re.DOTALL) + if deps_match: + deps_str = deps_match.group(1) + dependencies = {} + # Parse key: value pairs + for dep_match in re.finditer(r"['\"]?(\w+)['\"]?\s*:\s*['\"]?([^'\"]+)['\"]?", deps_str): + key = dep_match.group(1) + value = dep_match.group(2) + dependencies[key] = value + config.dependencies = dependencies + + # Extract boolean flags + for flag in ['track_progress', 'register', 'deploy']: + if re.search(rf'{flag}\s*=\s*True', config_str, re.IGNORECASE): + setattr(config, flag, True) + + # Extract string values + track_col_match = re.search(r"track_column\s*=\s*['\"](\w+)['\"]", config_str) + if track_col_match: + config.track_column = track_col_match.group(1) + + desc_match = re.search(r"description\s*=\s*['\"]([^'\"]+)['\"]", config_str) + if desc_match: + config.description = desc_match.group(1) + + incr_match = re.search(r"incremental_strategy\s*=\s*['\"](\w+)['\"]", config_str) + if incr_match: + config.incremental_strategy = incr_match.group(1) + + return sql_without_config, config + + +def load_project_config(project_root: Path) -> Dict: + """Load dbt_project.yml configuration. + + Args: + project_root: Root directory of the DBT project + + Returns: + Dictionary with project configuration + + Raises: + ConfigError: If config file is invalid or missing + """ + config_path = project_root / 'dbt_project.yml' + if not config_path.exists(): + return {} + + try: + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + return config or {} + except yaml.YAMLError as e: + raise ConfigError(f'Invalid YAML in dbt_project.yml: {e}') from e + except Exception as e: + raise ConfigError(f'Failed to load dbt_project.yml: {e}') from e + + +def load_profiles(profiles_path: Optional[Path] = None) -> Dict: + """Load profiles.yml configuration. + + Args: + profiles_path: Optional path to profiles.yml (default: ~/.amp-dbt/profiles.yml) + + Returns: + Dictionary with profiles configuration + + Raises: + ConfigError: If profiles file is invalid + """ + if profiles_path is None: + profiles_path = Path.home() / '.amp-dbt' / 'profiles.yml' + + if not profiles_path.exists(): + return {} + + try: + with open(profiles_path, 'r') as f: + profiles = yaml.safe_load(f) + return profiles or {} + except yaml.YAMLError as e: + raise ConfigError(f'Invalid YAML in profiles.yml: {e}') from e + except Exception as e: + raise ConfigError(f'Failed to load profiles.yml: {e}') from e + diff --git a/src/amp/dbt/dependencies.py b/src/amp/dbt/dependencies.py new file mode 100644 index 0000000..b66c156 --- /dev/null +++ b/src/amp/dbt/dependencies.py @@ -0,0 +1,215 @@ +"""Dependency resolution and graph building for Amp DBT.""" + +from collections import defaultdict, deque +from typing import Dict, List, Optional, Set, Tuple + +from amp.dbt.exceptions import DependencyError +from amp.dbt.models import CompiledModel + + +class DependencyGraph: + """Represents the dependency graph of models.""" + + def __init__(self): + """Initialize empty dependency graph.""" + # Maps model name -> set of dependencies (model names) + self._dependencies: Dict[str, Set[str]] = defaultdict(set) + # Maps model name -> set of dependents (models that depend on this one) + self._dependents: Dict[str, Set[str]] = defaultdict(set) + # All model names + self._models: Set[str] = set() + + def add_model(self, model_name: str, dependencies: Set[str]): + """Add a model and its dependencies to the graph. + + Args: + model_name: Name of the model + dependencies: Set of model names this model depends on + """ + self._models.add(model_name) + self._dependencies[model_name] = dependencies.copy() + + # Update reverse dependencies + for dep in dependencies: + self._dependents[dep].add(model_name) + + def get_dependencies(self, model_name: str) -> Set[str]: + """Get dependencies of a model. + + Args: + model_name: Name of the model + + Returns: + Set of model names this model depends on + """ + return self._dependencies.get(model_name, set()).copy() + + def get_dependents(self, model_name: str) -> Set[str]: + """Get models that depend on this model. + + Args: + model_name: Name of the model + + Returns: + Set of model names that depend on this model + """ + return self._dependents.get(model_name, set()).copy() + + def get_all_models(self) -> Set[str]: + """Get all models in the graph. + + Returns: + Set of all model names + """ + return self._models.copy() + + def detect_cycles(self) -> List[List[str]]: + """Detect cycles in the dependency graph. + + Returns: + List of cycles, where each cycle is a list of model names + """ + cycles = [] + visited = set() + rec_stack = set() + + def dfs(node: str, path: List[str]) -> None: + """Depth-first search to detect cycles.""" + if node in rec_stack: + # Found a cycle + cycle_start = path.index(node) + cycle = path[cycle_start:] + [node] + cycles.append(cycle) + return + + if node in visited: + return + + visited.add(node) + rec_stack.add(node) + path.append(node) + + # Visit all dependencies + for dep in self._dependencies.get(node, set()): + if dep in self._models: # Only check internal dependencies + dfs(dep, path) + + rec_stack.remove(node) + path.pop() + + for model in self._models: + if model not in visited: + dfs(model, []) + + return cycles + + def topological_sort(self) -> List[str]: + """Perform topological sort to get execution order. + + Returns: + List of model names in execution order (dependencies first) + + Raises: + DependencyError: If cycles are detected + """ + cycles = self.detect_cycles() + if cycles: + cycle_str = ' -> '.join(cycles[0]) + raise DependencyError(f'Circular dependency detected: {cycle_str}') + + # Kahn's algorithm for topological sort + in_degree = defaultdict(int) + for model in self._models: + in_degree[model] = 0 + + # Count incoming edges (only for internal dependencies) + for model in self._models: + for dep in self._dependencies.get(model, set()): + if dep in self._models: # Only internal dependencies + in_degree[model] += 1 + + # Queue of nodes with no incoming edges + queue = deque([model for model in self._models if in_degree[model] == 0]) + result = [] + + while queue: + node = queue.popleft() + result.append(node) + + # Remove edges from this node + for dependent in self._dependents.get(node, set()): + if dependent in self._models: # Only internal dependencies + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + # Check if all nodes were processed + if len(result) != len(self._models): + # This shouldn't happen if cycles are detected, but check anyway + remaining = self._models - set(result) + raise DependencyError(f'Could not resolve dependencies for: {remaining}') + + return result + + def get_execution_order(self, select: Optional[Set[str]] = None) -> List[str]: + """Get execution order for selected models (including dependencies). + + Args: + select: Optional set of model names to execute. If None, all models. + + Returns: + List of model names in execution order + """ + if select is None: + return self.topological_sort() + + # Build subgraph with selected models and their dependencies + to_execute = select.copy() + queue = deque(select) + + while queue: + model = queue.popleft() + deps = self._dependencies.get(model, set()) + for dep in deps: + if dep in self._models: # Only internal dependencies + if dep not in to_execute: + to_execute.add(dep) + queue.append(dep) + + # Create subgraph + subgraph = DependencyGraph() + for model in to_execute: + deps = self._dependencies.get(model, set()) + internal_deps = {d for d in deps if d in self._models} + subgraph.add_model(model, internal_deps) + + return subgraph.topological_sort() + + +def build_dependency_graph(compiled_models: Dict[str, CompiledModel]) -> DependencyGraph: + """Build dependency graph from compiled models. + + Args: + compiled_models: Dictionary mapping model names to CompiledModel + + Returns: + DependencyGraph with all dependencies + """ + graph = DependencyGraph() + + for model_name, compiled in compiled_models.items(): + # Get internal dependencies (models referenced via ref()) + # External dependencies are in compiled.dependencies + internal_deps = set() + + # Find ref() calls in raw SQL that reference other models + # We need to check which refs are internal vs external + for ref_name in compiled.dependencies.keys(): + # If ref_name is in compiled_models, it's an internal dependency + if ref_name in compiled_models: + internal_deps.add(ref_name) + + graph.add_model(model_name, internal_deps) + + return graph + diff --git a/src/amp/dbt/exceptions.py b/src/amp/dbt/exceptions.py new file mode 100644 index 0000000..c0b7bff --- /dev/null +++ b/src/amp/dbt/exceptions.py @@ -0,0 +1,32 @@ +"""Exception classes for Amp DBT.""" + + +class AmpDbtError(Exception): + """Base exception for all Amp DBT errors.""" + + pass + + +class ProjectNotFoundError(AmpDbtError): + """Raised when a project directory is not found or invalid.""" + + pass + + +class ConfigError(AmpDbtError): + """Raised when configuration is invalid or missing.""" + + pass + + +class CompilationError(AmpDbtError): + """Raised when SQL compilation fails.""" + + pass + + +class DependencyError(AmpDbtError): + """Raised when dependency resolution fails.""" + + pass + diff --git a/src/amp/dbt/models.py b/src/amp/dbt/models.py new file mode 100644 index 0000000..dc227a0 --- /dev/null +++ b/src/amp/dbt/models.py @@ -0,0 +1,29 @@ +"""Data models for Amp DBT.""" + +from dataclasses import dataclass, field +from typing import Dict, Optional + + +@dataclass +class ModelConfig: + """Configuration for a single model.""" + + dependencies: Dict[str, str] = field(default_factory=dict) + track_progress: bool = False + track_column: Optional[str] = None + register: bool = False + deploy: bool = False + description: Optional[str] = None + incremental_strategy: Optional[str] = None + + +@dataclass +class CompiledModel: + """Result of compiling a model.""" + + name: str + sql: str + config: ModelConfig + dependencies: Dict[str, str] # Maps ref name to dataset reference + raw_sql: str # Original SQL before compilation + diff --git a/src/amp/dbt/monitor.py b/src/amp/dbt/monitor.py new file mode 100644 index 0000000..7fb81e0 --- /dev/null +++ b/src/amp/dbt/monitor.py @@ -0,0 +1,121 @@ +"""Job monitoring for Amp DBT.""" + +from datetime import datetime +from typing import Dict, List, Optional + +from amp.admin.jobs import JobsClient +from amp.admin.models import JobInfo +from amp.dbt.state import JobHistory, ModelState, StateDatabase +from amp.dbt.tracker import ModelTracker + + +class JobMonitor: + """Monitors job status and updates model state.""" + + def __init__(self, jobs_client: JobsClient, tracker: ModelTracker): + """Initialize job monitor. + + Args: + jobs_client: JobsClient instance for querying job status + tracker: ModelTracker instance for updating state + """ + self.jobs_client = jobs_client + self.tracker = tracker + + def monitor_job(self, job_id: int, model_name: str) -> JobInfo: + """Monitor a job and update model state. + + Args: + job_id: Job ID to monitor + model_name: Name of the model associated with this job + + Returns: + JobInfo with current job status + """ + job = self.jobs_client.get(job_id) + + # Update model state based on job status + if job.status == 'Running': + # Job is running - mark as in progress + self.tracker.update_progress(model_name, job_id=job_id, status='running') + elif job.status == 'Completed': + # Job completed - update state + # Note: In a real implementation, we'd extract latest_block from job metadata + self.tracker.update_progress(model_name, job_id=job_id, status='fresh') + self._record_job_completion(job_id, model_name, job.status) + elif job.status == 'Failed': + # Job failed - mark as error + self.tracker.update_progress(model_name, job_id=job_id, status='error') + self._record_job_completion(job_id, model_name, job.status) + elif job.status == 'Stopped': + # Job stopped - mark appropriately + self.tracker.update_progress(model_name, job_id=job_id, status='stale') + self._record_job_completion(job_id, model_name, job.status) + + return job + + def _record_job_completion(self, job_id: int, model_name: str, status: str): + """Record job completion in history. + + Args: + job_id: Job ID + model_name: Model name + status: Final job status + """ + job = self.jobs_client.get(job_id) + + # Extract job metadata (this would need to be enhanced based on actual JobInfo structure) + job_history = JobHistory( + job_id=job_id, + model_name=model_name, + status=status, + completed_at=datetime.now(), + # Note: In real implementation, extract started_at, final_block, rows_processed + # from job.descriptor or job metadata + ) + + self.tracker.state_db.add_job_history(job_history) + + def monitor_all(self, model_job_map: Dict[str, int]) -> Dict[str, JobInfo]: + """Monitor all jobs for given models. + + Args: + model_job_map: Dictionary mapping model names to job IDs + + Returns: + Dictionary mapping model names to JobInfo + """ + results = {} + + for model_name, job_id in model_job_map.items(): + try: + job_info = self.monitor_job(job_id, model_name) + results[model_name] = job_info + except Exception: + # Skip models with errors + pass + + return results + + def get_active_jobs(self, model_job_map: Dict[str, int]) -> Dict[str, JobInfo]: + """Get all active (running/pending) jobs. + + Args: + model_job_map: Dictionary mapping model names to job IDs + + Returns: + Dictionary mapping model names to JobInfo for active jobs + """ + active_statuses = {'Running', 'Pending', 'Scheduled'} + results = {} + + for model_name, job_id in model_job_map.items(): + try: + job = self.jobs_client.get(job_id) + if job.status in active_statuses: + results[model_name] = job + except Exception: + pass + + return results + diff --git a/src/amp/dbt/project.py b/src/amp/dbt/project.py new file mode 100644 index 0000000..ffb295a --- /dev/null +++ b/src/amp/dbt/project.py @@ -0,0 +1,205 @@ +"""Project management for Amp DBT.""" + +from pathlib import Path +from typing import Dict, List, Optional + +from amp.dbt.compiler import Compiler +from amp.dbt.config import load_project_config, parse_config_block +from amp.dbt.dependencies import DependencyGraph, build_dependency_graph +from amp.dbt.exceptions import DependencyError, ProjectNotFoundError +from amp.dbt.models import CompiledModel, ModelConfig + + +class AmpDbtProject: + """Main project class for Amp DBT.""" + + def __init__(self, project_root: Optional[Path] = None): + """Initialize project. + + Args: + project_root: Root directory of the DBT project (default: current directory) + + Raises: + ProjectNotFoundError: If project directory is invalid + """ + if project_root is None: + project_root = Path.cwd() + + project_root = Path(project_root).resolve() + + # Check for dbt_project.yml (optional for Phase 1) + self.project_root = project_root + self.models_dir = project_root / 'models' + self.macros_dir = project_root / 'macros' + + # Initialize compiler + self.compiler = Compiler(self.project_root, self.macros_dir) + + # Load project config + self.config = load_project_config(self.project_root) + + def find_models(self, select: Optional[str] = None) -> List[Path]: + """Find all model files in the project. + + Args: + select: Optional glob pattern to filter models (e.g., 'stg_*') + + Returns: + List of model file paths + """ + if not self.models_dir.exists(): + return [] + + pattern = select if select else '*.sql' + models = list(self.models_dir.rglob(pattern)) + return sorted(models) + + def load_model(self, model_path: Path) -> tuple[str, ModelConfig]: + """Load a model file and parse its config. + + Args: + model_path: Path to model SQL file + + Returns: + Tuple of (sql_content, ModelConfig) + + Raises: + ProjectNotFoundError: If model file doesn't exist + """ + if not model_path.exists(): + raise ProjectNotFoundError(f'Model file not found: {model_path}') + + sql = model_path.read_text() + sql_without_config, config = parse_config_block(sql) + + return sql_without_config, config + + def compile_model(self, model_path: Path) -> CompiledModel: + """Compile a single model. + + Args: + model_path: Path to model SQL file + + Returns: + CompiledModel with compiled SQL + + Raises: + ProjectNotFoundError: If model file doesn't exist + """ + # Load model and parse config (config block is removed from SQL) + sql, config = self.load_model(model_path) + + # Get model name from file path (without extension) + model_name = model_path.stem + + # Merge with project-level config if available + if 'models' in self.config: + # Apply project-level config overrides (future enhancement) + pass + + return self.compiler.compile(sql, model_name, config) + + def compile_all(self, select: Optional[str] = None) -> Dict[str, CompiledModel]: + """Compile all models in the project with dependency resolution. + + Args: + select: Optional glob pattern to filter models + + Returns: + Dictionary mapping model names to CompiledModel + + Raises: + DependencyError: If circular dependencies are detected + """ + # Find models matching select pattern + selected_model_paths = self.find_models(select) + if not selected_model_paths: + return {} + + # To resolve dependencies properly, we need to compile ALL models first + # to discover the dependency graph, then filter to selected + dependencies + all_model_paths = self.find_models(None) # Get all models + + # First pass: compile all models without CTEs to get dependencies + # IMPORTANT: Add all model names to available_models FIRST so references can be resolved + # even if dependencies haven't been compiled yet (we're just discovering dependencies here) + available_models = {p.stem for p in all_model_paths} + initial_compiled = {} + + for model_path in all_model_paths: + try: + sql, config = self.load_model(model_path) + model_name = model_path.stem + + # Compile with all models available (for reference resolution) + compiled = self.compiler.compile(sql, model_name, config, available_models) + initial_compiled[model_name] = compiled + except Exception as e: + # Log error but continue with other models + print(f'Warning: Failed to compile {model_path.name}: {e}') + + # Build dependency graph from all models + graph = build_dependency_graph(initial_compiled) + + # Detect cycles + cycles = graph.detect_cycles() + if cycles: + cycle_str = ' -> '.join(cycles[0]) + raise DependencyError(f'Circular dependency detected: {cycle_str}') + + # If select is specified, get execution order for selected models + dependencies + if select is not None: + selected_model_names = {p.stem for p in selected_model_paths} + execution_order = graph.get_execution_order(selected_model_names) + else: + # No select - compile everything + execution_order = graph.topological_sort() + + # Second pass: compile with CTE inlining in dependency order + final_compiled = {} + for model_name in execution_order: + model_path = next(p for p in all_model_paths if p.stem == model_name) + sql, config = self.load_model(model_path) + + # Get internal dependencies for this model + internal_deps = {} + for dep_name in graph.get_dependencies(model_name): + if dep_name in final_compiled: + internal_deps[dep_name] = final_compiled[dep_name].sql + + # Compile with CTEs + compiled = self.compiler.compile_with_ctes( + sql, model_name, config, internal_deps, available_models + ) + final_compiled[model_name] = compiled + + return final_compiled + + def get_execution_order(self, select: Optional[str] = None) -> List[str]: + """Get execution order for models. + + Args: + select: Optional glob pattern to filter models + + Returns: + List of model names in execution order + """ + model_paths = self.find_models(select) + if not model_paths: + return [] + + # Compile to get dependencies + compiled = {} + for model_path in model_paths: + try: + sql, config = self.load_model(model_path) + model_name = model_path.stem + compiled[model_name] = self.compiler.compile( + sql, model_name, config, {p.stem for p in model_paths} + ) + except Exception: + pass # Skip models that fail to compile + + graph = build_dependency_graph(compiled) + return graph.topological_sort() + diff --git a/src/amp/dbt/state.py b/src/amp/dbt/state.py new file mode 100644 index 0000000..9183d75 --- /dev/null +++ b/src/amp/dbt/state.py @@ -0,0 +1,318 @@ +"""State management for Amp DBT monitoring and tracking.""" + +import sqlite3 +from dataclasses import dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Optional + +from amp.dbt.exceptions import ConfigError + + +@dataclass +class ModelState: + """State information for a model.""" + + model_name: str + connection_name: Optional[str] = None + latest_block: Optional[int] = None + latest_timestamp: Optional[datetime] = None + last_updated: Optional[datetime] = None + job_id: Optional[int] = None + status: str = 'unknown' # 'fresh', 'stale', 'error', 'unknown' + + +@dataclass +class JobHistory: + """History record for a job.""" + + job_id: int + model_name: str + status: str + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + final_block: Optional[int] = None + rows_processed: Optional[int] = None + + +@dataclass +class FreshnessResult: + """Result of a freshness check.""" + + stale: bool + age: Optional[timedelta] = None + latest_block: Optional[int] = None + latest_timestamp: Optional[datetime] = None + reason: Optional[str] = None + + +class StateDatabase: + """SQLite database for storing model state and job history.""" + + def __init__(self, db_path: Path): + """Initialize state database. + + Args: + db_path: Path to SQLite database file + """ + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_schema() + + def _init_schema(self): + """Initialize database schema.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + # Create model_state table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS model_state ( + model_name TEXT PRIMARY KEY, + connection_name TEXT, + latest_block INTEGER, + latest_timestamp TIMESTAMP, + last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + job_id INTEGER, + status TEXT + ) + ''') + + # Create job_history table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS job_history ( + job_id INTEGER PRIMARY KEY, + model_name TEXT, + status TEXT, + started_at TIMESTAMP, + completed_at TIMESTAMP, + final_block INTEGER, + rows_processed INTEGER + ) + ''') + + # Create indexes + cursor.execute('CREATE INDEX IF NOT EXISTS idx_model_state_status ON model_state(status)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_job_history_model ON job_history(model_name)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_job_history_status ON job_history(status)') + + conn.commit() + conn.close() + + def get_model_state(self, model_name: str) -> Optional[ModelState]: + """Get state for a model. + + Args: + model_name: Name of the model + + Returns: + ModelState if found, None otherwise + """ + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + cursor.execute( + 'SELECT * FROM model_state WHERE model_name = ?', (model_name,) + ) + row = cursor.fetchone() + conn.close() + + if not row: + return None + + return ModelState( + model_name=row['model_name'], + connection_name=row['connection_name'], + latest_block=row['latest_block'], + latest_timestamp=datetime.fromisoformat(row['latest_timestamp']) + if row['latest_timestamp'] + else None, + last_updated=datetime.fromisoformat(row['last_updated']) + if row['last_updated'] + else None, + job_id=row['job_id'], + status=row['status'] or 'unknown', + ) + + def update_model_state( + self, + model_name: str, + latest_block: Optional[int] = None, + latest_timestamp: Optional[datetime] = None, + job_id: Optional[int] = None, + status: Optional[str] = None, + connection_name: Optional[str] = None, + ): + """Update state for a model. + + Args: + model_name: Name of the model + latest_block: Latest block number processed + latest_timestamp: Latest timestamp processed + job_id: Current job ID + status: Status ('fresh', 'stale', 'error') + connection_name: Connection name + """ + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + # Check if record exists + cursor.execute('SELECT model_name FROM model_state WHERE model_name = ?', (model_name,)) + exists = cursor.fetchone() is not None + + if exists: + # Update existing record + updates = [] + params = [] + + if latest_block is not None: + updates.append('latest_block = ?') + params.append(latest_block) + if latest_timestamp is not None: + updates.append('latest_timestamp = ?') + params.append(latest_timestamp.isoformat()) + if job_id is not None: + updates.append('job_id = ?') + params.append(job_id) + if status is not None: + updates.append('status = ?') + params.append(status) + if connection_name is not None: + updates.append('connection_name = ?') + params.append(connection_name) + + updates.append('last_updated = CURRENT_TIMESTAMP') + params.append(model_name) + + cursor.execute( + f'UPDATE model_state SET {", ".join(updates)} WHERE model_name = ?', params + ) + else: + # Insert new record + cursor.execute( + ''' + INSERT INTO model_state + (model_name, connection_name, latest_block, latest_timestamp, job_id, status) + VALUES (?, ?, ?, ?, ?, ?) + ''', + ( + model_name, + connection_name, + latest_block, + latest_timestamp.isoformat() if latest_timestamp else None, + job_id, + status or 'unknown', + ), + ) + + conn.commit() + conn.close() + + def add_job_history(self, job_history: JobHistory): + """Add a job history record. + + Args: + job_history: JobHistory record to add + """ + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute( + ''' + INSERT OR REPLACE INTO job_history + (job_id, model_name, status, started_at, completed_at, final_block, rows_processed) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', + ( + job_history.job_id, + job_history.model_name, + job_history.status, + job_history.started_at.isoformat() if job_history.started_at else None, + job_history.completed_at.isoformat() if job_history.completed_at else None, + job_history.final_block, + job_history.rows_processed, + ), + ) + + conn.commit() + conn.close() + + def get_all_model_states(self) -> List[ModelState]: + """Get all model states. + + Returns: + List of ModelState objects + """ + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + cursor.execute('SELECT * FROM model_state ORDER BY model_name') + rows = cursor.fetchall() + conn.close() + + states = [] + for row in rows: + states.append( + ModelState( + model_name=row['model_name'], + connection_name=row['connection_name'], + latest_block=row['latest_block'], + latest_timestamp=datetime.fromisoformat(row['latest_timestamp']) + if row['latest_timestamp'] + else None, + last_updated=datetime.fromisoformat(row['last_updated']) + if row['last_updated'] + else None, + job_id=row['job_id'], + status=row['status'] or 'unknown', + ) + ) + + return states + + def get_job_history(self, model_name: Optional[str] = None, limit: int = 50) -> List[JobHistory]: + """Get job history. + + Args: + model_name: Optional filter by model name + limit: Maximum number of records to return + + Returns: + List of JobHistory objects + """ + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + if model_name: + cursor.execute( + 'SELECT * FROM job_history WHERE model_name = ? ORDER BY started_at DESC LIMIT ?', + (model_name, limit), + ) + else: + cursor.execute('SELECT * FROM job_history ORDER BY started_at DESC LIMIT ?', (limit,)) + + rows = cursor.fetchall() + conn.close() + + history = [] + for row in rows: + history.append( + JobHistory( + job_id=row['job_id'], + model_name=row['model_name'], + status=row['status'], + started_at=datetime.fromisoformat(row['started_at']) + if row['started_at'] + else None, + completed_at=datetime.fromisoformat(row['completed_at']) + if row['completed_at'] + else None, + final_block=row['final_block'], + rows_processed=row['rows_processed'], + ) + ) + + return history + diff --git a/src/amp/dbt/tracker.py b/src/amp/dbt/tracker.py new file mode 100644 index 0000000..6e67fee --- /dev/null +++ b/src/amp/dbt/tracker.py @@ -0,0 +1,178 @@ +"""Model tracking and freshness monitoring for Amp DBT.""" + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, Optional + +from amp.dbt.config import load_project_config +from amp.dbt.exceptions import ConfigError +from amp.dbt.state import FreshnessResult, ModelState, StateDatabase + + +class ModelTracker: + """Tracks data progress for models.""" + + def __init__(self, project_root: Path): + """Initialize model tracker. + + Args: + project_root: Root directory of the DBT project + """ + self.project_root = project_root + self.state_dir = project_root / '.amp-dbt' + self.state_db = StateDatabase(self.state_dir / 'state.db') + self.config = load_project_config(project_root) + + def get_latest_block(self, model_name: str) -> Optional[int]: + """Get latest block number for a model. + + Args: + model_name: Name of the model + + Returns: + Latest block number, or None if not tracked + """ + state = self.state_db.get_model_state(model_name) + return state.latest_block if state else None + + def get_latest_timestamp(self, model_name: str) -> Optional[datetime]: + """Get latest timestamp for a model. + + Args: + model_name: Name of the model + + Returns: + Latest timestamp, or None if not tracked + """ + state = self.state_db.get_model_state(model_name) + return state.latest_timestamp if state else None + + def update_progress( + self, + model_name: str, + latest_block: Optional[int] = None, + latest_timestamp: Optional[datetime] = None, + job_id: Optional[int] = None, + status: Optional[str] = None, + ): + """Update progress for a model. + + Args: + model_name: Name of the model + latest_block: Latest block number + latest_timestamp: Latest timestamp + job_id: Current job ID + status: Status ('fresh', 'stale', 'error') + """ + self.state_db.update_model_state( + model_name=model_name, + latest_block=latest_block, + latest_timestamp=latest_timestamp or datetime.now(), + job_id=job_id, + status=status, + ) + + def get_model_state(self, model_name: str) -> Optional[ModelState]: + """Get full state for a model. + + Args: + model_name: Name of the model + + Returns: + ModelState if found, None otherwise + """ + return self.state_db.get_model_state(model_name) + + def get_all_states(self) -> Dict[str, ModelState]: + """Get all model states. + + Returns: + Dictionary mapping model names to ModelState + """ + states = self.state_db.get_all_model_states() + return {state.model_name: state for state in states} + + +class FreshnessMonitor: + """Monitors data freshness for models.""" + + def __init__(self, tracker: ModelTracker): + """Initialize freshness monitor. + + Args: + tracker: ModelTracker instance + """ + self.tracker = tracker + self.config = tracker.config + + def get_alert_threshold(self, model_name: str) -> timedelta: + """Get alert threshold for a model. + + Args: + model_name: Name of the model + + Returns: + Alert threshold as timedelta + """ + # Check model-specific config + if 'models' in self.config: + # Try to find model-specific threshold + # For now, use default + pass + + # Get default from monitoring config + default_minutes = 30 + if 'monitoring' in self.config: + default_minutes = self.config['monitoring'].get('alert_threshold_minutes', 30) + + return timedelta(minutes=default_minutes) + + def check_freshness(self, model_name: str) -> FreshnessResult: + """Check freshness of a model. + + Args: + model_name: Name of the model + + Returns: + FreshnessResult with freshness information + """ + state = self.tracker.get_model_state(model_name) + + if not state or not state.latest_timestamp: + return FreshnessResult( + stale=True, reason='No data tracked', latest_block=state.latest_block if state else None + ) + + now = datetime.now() + age = now - state.latest_timestamp + threshold = self.get_alert_threshold(model_name) + + is_stale = age > threshold + + # Update status in database + new_status = 'stale' if is_stale else 'fresh' + if state.status != new_status: + self.tracker.state_db.update_model_state(model_name, status=new_status) + + return FreshnessResult( + stale=is_stale, + age=age, + latest_block=state.latest_block, + latest_timestamp=state.latest_timestamp, + reason='Data is stale' if is_stale else None, + ) + + def check_all_freshness(self) -> Dict[str, FreshnessResult]: + """Check freshness of all models. + + Returns: + Dictionary mapping model names to FreshnessResult + """ + states = self.tracker.get_all_states() + results = {} + + for model_name in states.keys(): + results[model_name] = self.check_freshness(model_name) + + return results + From 21c796e1d6f83ebe30b156f6bf015535312e99ea Mon Sep 17 00:00:00 2001 From: Vivian Peng Date: Fri, 12 Dec 2025 07:36:30 +0900 Subject: [PATCH 2/3] Revised the commit --- apps/execute_query.py | 24 -------- models/example_model.sql | 12 ---- models/staging/stg_test.sql | 6 -- src/amp/dbt/README.md | 113 ------------------------------------ 4 files changed, 155 deletions(-) delete mode 100644 apps/execute_query.py delete mode 100644 models/example_model.sql delete mode 100644 models/staging/stg_test.sql delete mode 100644 src/amp/dbt/README.md diff --git a/apps/execute_query.py b/apps/execute_query.py deleted file mode 100644 index f23ed90..0000000 --- a/apps/execute_query.py +++ /dev/null @@ -1,24 +0,0 @@ -from rich import print - -from amp.client import Client - -# Replace with your remote server URL -# Format: grpc://hostname:port or grpc+tls://hostname:port for TLS -SERVER_URL = "grpc://34.27.238.174:80" - -# Option 1: No authentication (if server doesn't require it) -# client = Client(url=SERVER_URL) - -# Option 2: Use explicit auth token -# client = Client(url=SERVER_URL, auth_token='your-token-here') - -# Option 3: Use environment variable AMP_AUTH_TOKEN -# export AMP_AUTH_TOKEN="your-token-here" -# client = Client(url=SERVER_URL) - -# Option 4: Use auto-refreshing auth from shared auth file (recommended) -# Uses ~/.amp/cache/amp_cli_auth (shared with TypeScript CLI) -client = Client(url=SERVER_URL, auth=True) - -df = client.get_sql('select * from eth_firehose.logs limit 1', read_all=True).to_pandas() -print(df) diff --git a/models/example_model.sql b/models/example_model.sql deleted file mode 100644 index 76f35c3..0000000 --- a/models/example_model.sql +++ /dev/null @@ -1,12 +0,0 @@ --- Example model -{{ config( - dependencies={'eth': '_/eth_firehose@1.0.0'}, - description='Example model showing how to use ref()' -) }} - -SELECT - block_num, - block_hash, - timestamp -FROM {{ ref('eth') }}.blocks -LIMIT 10 diff --git a/models/staging/stg_test.sql b/models/staging/stg_test.sql deleted file mode 100644 index 255546d..0000000 --- a/models/staging/stg_test.sql +++ /dev/null @@ -1,6 +0,0 @@ -{{ config( - dependencies={'eth': '_/eth_firehose@1.0.0'}, - track_progress=True, - track_column='block_num' -) }} -SELECT block_num, tx_hash FROM {{ ref('eth') }}.logs LIMIT 10 diff --git a/src/amp/dbt/README.md b/src/amp/dbt/README.md deleted file mode 100644 index 0580cb4..0000000 --- a/src/amp/dbt/README.md +++ /dev/null @@ -1,113 +0,0 @@ -# Amp DBT - Phase 1 Implementation - -## Overview - -Phase 1 implements the core compilation engine for Amp DBT, providing basic query composition with Jinja templating and external dataset reference resolution. - -## Features Implemented - -### ✅ Project Initialization -- `amp-dbt init` command creates a new DBT project structure -- Generates `dbt_project.yml`, directory structure, and example model - -### ✅ Model Loading and Parsing -- Loads SQL model files from `models/` directory -- Parses `{{ config() }}` blocks from model SQL -- Extracts configuration (dependencies, track_progress, etc.) - -### ✅ Jinja Templating Support -- Full Jinja2 template rendering -- Custom `ref()` function for dependency resolution -- Support for variables and macros (basic) - -### ✅ ref() Resolution (External Datasets Only) -- Resolves `{{ ref('eth') }}` to dataset references like `_/eth_firehose@1.0.0` -- Validates dependencies are defined in config -- Replaces ref() calls in compiled SQL - -### ✅ Basic Config Parsing -- Parses `{{ config() }}` blocks from model SQL -- Loads `dbt_project.yml` (optional) -- Supports dependencies, track_progress, register, deploy flags - -### ✅ CLI Commands -- `amp-dbt init` - Initialize new project -- `amp-dbt compile` - Compile models -- `amp-dbt list` - List all models - -## Usage - -### Initialize a Project - -```bash -amp-dbt init my-project -cd my-project -``` - -### Create a Model - -Create `models/staging/stg_erc20_transfers.sql`: - -```sql -{{ config( - dependencies={'eth': '_/eth_firehose@1.0.0'}, - track_progress=true, - track_column='block_num', - description='Decoded ERC20 Transfer events' -) }} - -SELECT - l.block_num, - l.block_hash, - l.timestamp, - l.tx_hash, - l.address as token_address -FROM {{ ref('eth') }}.logs l -WHERE - l.topic0 = evm_topic('Transfer(address indexed from, address indexed to, uint256 value)') - AND l.topic3 IS NULL -``` - -### Compile Models - -```bash -# Compile all models -amp-dbt compile - -# Compile specific models -amp-dbt compile --select stg_* - -# Show compiled SQL -amp-dbt compile --show-sql -``` - -## Project Structure - -``` -my-project/ -├── dbt_project.yml # Project configuration -├── models/ # SQL model files -│ ├── staging/ -│ │ └── stg_erc20_transfers.sql -│ └── marts/ -│ └── token_analytics.sql -├── macros/ # Reusable SQL macros (future) -├── tests/ # Data quality tests (future) -└── docs/ # Documentation (future) -``` - -## Limitations (Phase 1) - -- ❌ Internal model references (model-to-model dependencies) not supported -- ❌ Macros system not fully implemented -- ❌ No execution (`amp-dbt run`) - compilation only -- ❌ No monitoring or tracking -- ❌ No testing framework - -## Next Steps (Phase 2) - -- Internal model dependency resolution (CTE inlining) -- Dependency graph building -- Topological sort for execution order -- `amp-dbt run` command for execution - From 15efdf17d599adec063c1f0150e488c8e814c5ec Mon Sep 17 00:00:00 2001 From: Vivian Peng Date: Thu, 18 Dec 2025 10:06:01 +0900 Subject: [PATCH 3/3] Added integration tests --- src/amp/admin/datasets.py | 4 +- src/amp/registry/datasets.py | 4 +- src/amp/streaming/state.py | 4 +- src/amp/utils/manifest_inspector.py | 4 +- tests/integration/test_dbt.py | 712 ++++++++++++++++++++++++++++ tests/integration/test_dbt_cli.py | 550 +++++++++++++++++++++ 6 files changed, 1270 insertions(+), 8 deletions(-) create mode 100644 tests/integration/test_dbt.py create mode 100644 tests/integration/test_dbt_cli.py diff --git a/src/amp/admin/datasets.py b/src/amp/admin/datasets.py index b47f107..1eb9bae 100644 --- a/src/amp/admin/datasets.py +++ b/src/amp/admin/datasets.py @@ -4,7 +4,7 @@ including registration, deployment, versioning, and manifest operations. """ -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from amp.utils.manifest_inspector import describe_manifest, print_schema @@ -200,7 +200,7 @@ def get_manifest(self, namespace: str, name: str, revision: str) -> dict: response = self._admin._request('GET', path) return response.json() - def describe(self, namespace: str, name: str, revision: str = 'latest') -> Dict[str, list[Dict[str, str | bool]]]: + def describe(self, namespace: str, name: str, revision: str = 'latest') -> Dict[str, List[Dict[str, Union[str, bool]]]]: """Get a structured summary of tables and columns in a dataset. Returns a dictionary mapping table names to lists of column information, diff --git a/src/amp/registry/datasets.py b/src/amp/registry/datasets.py index 9bf94ef..e54e349 100644 --- a/src/amp/registry/datasets.py +++ b/src/amp/registry/datasets.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from amp.utils.manifest_inspector import describe_manifest, print_schema @@ -199,7 +199,7 @@ def get_manifest(self, namespace: str, name: str, version: str) -> dict: response = self._registry._request('GET', path) return response.json() - def describe(self, namespace: str, name: str, version: str = 'latest') -> Dict[str, list[Dict[str, str | bool]]]: + def describe(self, namespace: str, name: str, version: str = 'latest') -> Dict[str, List[Dict[str, Union[str, bool]]]]: """Get a structured summary of tables and columns in a dataset. Returns a dictionary mapping table names to lists of column information, diff --git a/src/amp/streaming/state.py b/src/amp/streaming/state.py index d0e7936..21e5e6c 100644 --- a/src/amp/streaming/state.py +++ b/src/amp/streaming/state.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import UTC, datetime -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union from amp.streaming.types import BlockRange, ResumeWatermark @@ -235,7 +235,7 @@ def __init__(self): def _get_key( self, connection_name: str, table_name: str, network: Optional[str] = None - ) -> Tuple[str, str, str] | List[Tuple[str, str, str]]: + ) -> Union[Tuple[str, str, str], List[Tuple[str, str, str]]]: """Get storage key(s) for the given parameters.""" if network: return (connection_name, table_name, network) diff --git a/src/amp/utils/manifest_inspector.py b/src/amp/utils/manifest_inspector.py index 0828fee..b362c23 100644 --- a/src/amp/utils/manifest_inspector.py +++ b/src/amp/utils/manifest_inspector.py @@ -4,10 +4,10 @@ from manifest files in a human-readable format. """ -from typing import Any, Dict +from typing import Any, Dict, List, Union -def describe_manifest(manifest: dict) -> Dict[str, list[Dict[str, str | bool]]]: +def describe_manifest(manifest: dict) -> Dict[str, List[Dict[str, Union[str, bool]]]]: """Extract structured schema information from a manifest. Args: diff --git a/tests/integration/test_dbt.py b/tests/integration/test_dbt.py new file mode 100644 index 0000000..5a2fe31 --- /dev/null +++ b/tests/integration/test_dbt.py @@ -0,0 +1,712 @@ +"""Integration tests for Amp DBT module.""" + +import tempfile +from datetime import datetime, timedelta +from pathlib import Path + +import pytest +import yaml + +from amp.dbt.compiler import Compiler +from amp.dbt.config import load_project_config, parse_config_block +from amp.dbt.dependencies import DependencyGraph, build_dependency_graph +from amp.dbt.exceptions import ( + CompilationError, + ConfigError, + DependencyError, + ProjectNotFoundError, +) +from amp.dbt.models import CompiledModel, ModelConfig +from amp.dbt.project import AmpDbtProject +from amp.dbt.state import FreshnessResult, ModelState, StateDatabase +from amp.dbt.tracker import FreshnessMonitor, ModelTracker + + +@pytest.mark.integration +class TestDbtProject: + """Integration tests for AmpDbtProject.""" + + def test_project_initialization(self, tmp_path): + """Test project initialization with valid directory.""" + project = AmpDbtProject(tmp_path) + assert project.project_root == tmp_path + assert project.models_dir == tmp_path / 'models' + assert project.macros_dir == tmp_path / 'macros' + + def test_project_initialization_defaults_to_cwd(self, monkeypatch): + """Test project initialization defaults to current directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.chdir(tmpdir) + project = AmpDbtProject() + assert project.project_root == Path(tmpdir).resolve() + + def test_find_models_empty_directory(self, tmp_path): + """Test finding models in empty directory.""" + project = AmpDbtProject(tmp_path) + models = project.find_models() + assert models == [] + + def test_find_models_with_files(self, tmp_path): + """Test finding models in directory with SQL files.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create some model files + (models_dir / 'model1.sql').write_text('SELECT 1') + (models_dir / 'model2.sql').write_text('SELECT 2') + (models_dir / 'not_a_model.txt').write_text('not sql') + + project = AmpDbtProject(tmp_path) + models = project.find_models() + assert len(models) == 2 + assert all(m.suffix == '.sql' for m in models) + assert any('model1' in str(m) for m in models) + assert any('model2' in str(m) for m in models) + + def test_find_models_with_select_pattern(self, tmp_path): + """Test finding models with select pattern.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + (models_dir / 'stg_model1.sql').write_text('SELECT 1') + (models_dir / 'stg_model2.sql').write_text('SELECT 2') + (models_dir / 'final_model.sql').write_text('SELECT 3') + + project = AmpDbtProject(tmp_path) + models = project.find_models('stg_*') + assert len(models) == 2 + assert all('stg_' in m.stem for m in models) + + def test_load_model(self, tmp_path): + """Test loading a model file.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + model_sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT * FROM {{ ref('eth') }}.blocks +''' + (models_dir / 'test_model.sql').write_text(model_sql) + + project = AmpDbtProject(tmp_path) + sql, config = project.load_model(models_dir / 'test_model.sql') + + assert 'config' not in sql + assert 'SELECT * FROM' in sql + assert config.dependencies == {'eth': '_/eth_firehose@1.0.0'} + + def test_load_model_not_found(self, tmp_path): + """Test loading non-existent model raises error.""" + project = AmpDbtProject(tmp_path) + with pytest.raises(ProjectNotFoundError): + project.load_model(tmp_path / 'models' / 'missing.sql') + + def test_compile_single_model(self, tmp_path): + """Test compiling a single model.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + model_sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks +LIMIT 10 +''' + (models_dir / 'test_model.sql').write_text(model_sql) + + project = AmpDbtProject(tmp_path) + compiled = project.compile_model(models_dir / 'test_model.sql') + + assert compiled.name == 'test_model' + assert 'SELECT block_num, block_hash FROM' in compiled.sql + assert compiled.config.dependencies == {'eth': '_/eth_firehose@1.0.0'} + assert 'eth' in compiled.dependencies + + def test_compile_all_models(self, tmp_path): + """Test compiling all models in project.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create multiple models + (models_dir / 'model1.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT * FROM {{ ref('eth') }}.blocks LIMIT 10 +''' + ) + (models_dir / 'model2.sql').write_text( + '''{{ config(dependencies={'arb': '_/arb_firehose@1.0.0'}) }} +SELECT * FROM {{ ref('arb') }}.blocks LIMIT 10 +''' + ) + + project = AmpDbtProject(tmp_path) + compiled = project.compile_all() + + assert len(compiled) == 2 + assert 'model1' in compiled + assert 'model2' in compiled + assert compiled['model1'].dependencies['eth'] == '_/eth_firehose@1.0.0' + assert compiled['model2'].dependencies['arb'] == '_/arb_firehose@1.0.0' + + def test_compile_models_with_internal_dependencies(self, tmp_path): + """Test compiling models with internal dependencies.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create base model + (models_dir / 'base_model.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks +''' + ) + + # Create dependent model + (models_dir / 'dependent_model.sql').write_text( + '''SELECT * FROM {{ ref('base_model') }} +WHERE block_num > 1000 +''' + ) + + project = AmpDbtProject(tmp_path) + compiled = project.compile_all() + + assert len(compiled) == 2 + assert 'base_model' in compiled + assert 'dependent_model' in compiled + + # Check dependencies + assert 'eth' in compiled['base_model'].dependencies + assert 'base_model' in compiled['dependent_model'].dependencies + + def test_execution_order(self, tmp_path): + """Test getting execution order for models.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create models with dependencies + (models_dir / 'base.sql').write_text('SELECT 1 as id') + (models_dir / 'intermediate.sql').write_text('SELECT * FROM {{ ref("base") }}') + (models_dir / 'final.sql').write_text('SELECT * FROM {{ ref("intermediate") }}') + + project = AmpDbtProject(tmp_path) + order = project.get_execution_order() + + assert len(order) == 3 + assert order[0] == 'base' + assert order[1] == 'intermediate' + assert order[2] == 'final' + + def test_execution_order_with_select(self, tmp_path): + """Test execution order with select pattern.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + (models_dir / 'base.sql').write_text('SELECT 1 as id') + (models_dir / 'intermediate.sql').write_text('SELECT * FROM {{ ref("base") }}') + (models_dir / 'final.sql').write_text('SELECT * FROM {{ ref("intermediate") }}') + + project = AmpDbtProject(tmp_path) + order = project.get_execution_order('final') + + # Should include final and its dependencies + assert 'final' in order + assert 'intermediate' in order + assert 'base' in order + assert order.index('base') < order.index('intermediate') + assert order.index('intermediate') < order.index('final') + + def test_circular_dependency_detection(self, tmp_path): + """Test that circular dependencies are detected.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + (models_dir / 'model1.sql').write_text('SELECT * FROM {{ ref("model2") }}') + (models_dir / 'model2.sql').write_text('SELECT * FROM {{ ref("model1") }}') + + project = AmpDbtProject(tmp_path) + with pytest.raises(DependencyError, match='Circular dependency'): + project.compile_all() + + +@pytest.mark.integration +class TestDbtCompiler: + """Integration tests for Compiler.""" + + def test_compile_simple_model(self, tmp_path): + """Test compiling a simple model.""" + compiler = Compiler(tmp_path) + sql = 'SELECT 1 as value' + config = ModelConfig() + + compiled = compiler.compile(sql, 'test_model', config) + + assert compiled.name == 'test_model' + assert compiled.sql == sql + assert compiled.config == config + + def test_compile_with_jinja(self, tmp_path): + """Test compiling with Jinja templating.""" + compiler = Compiler(tmp_path) + sql = 'SELECT {{ model_name }} as model' + config = ModelConfig() + + compiled = compiler.compile(sql, 'test_model', config) + + assert 'test_model' in compiled.sql + + def test_compile_with_ref_external(self, tmp_path): + """Test compiling with external ref().""" + compiler = Compiler(tmp_path) + sql = 'SELECT * FROM {{ ref("eth") }}.blocks' + config = ModelConfig(dependencies={'eth': '_/eth_firehose@1.0.0'}) + + compiled = compiler.compile(sql, 'test_model', config, available_models=set()) + + assert 'eth' in compiled.dependencies + assert compiled.dependencies['eth'] == '_/eth_firehose@1.0.0' + assert '__REF__eth__' not in compiled.sql + + def test_compile_with_ref_internal(self, tmp_path): + """Test compiling with internal ref().""" + compiler = Compiler(tmp_path) + sql = 'SELECT * FROM {{ ref("base_model") }}' + config = ModelConfig() + available_models = {'base_model'} + + compiled = compiler.compile(sql, 'test_model', config, available_models) + + assert 'base_model' in compiled.dependencies + assert compiled.dependencies['base_model'] == 'base_model' + + def test_compile_with_ref_unknown(self, tmp_path): + """Test that unknown ref() raises error.""" + compiler = Compiler(tmp_path) + sql = 'SELECT * FROM {{ ref("unknown") }}' + config = ModelConfig() + + with pytest.raises(DependencyError, match='Unknown reference'): + compiler.compile(sql, 'test_model', config, available_models=set()) + + def test_compile_with_ctes(self, tmp_path): + """Test compiling with CTE inlining.""" + compiler = Compiler(tmp_path) + sql = 'SELECT * FROM {{ ref("base") }}' + config = ModelConfig() + internal_deps = {'base': 'SELECT 1 as id'} + + compiled = compiler.compile_with_ctes(sql, 'test_model', config, internal_deps) + + assert 'WITH' in compiled.sql.upper() + assert 'base' in compiled.sql + + def test_compile_with_nested_ctes(self, tmp_path): + """Test compiling with nested CTEs.""" + compiler = Compiler(tmp_path) + sql = 'SELECT * FROM {{ ref("intermediate") }}' + config = ModelConfig() + internal_deps = { + 'intermediate': 'WITH base AS (SELECT 1 as id) SELECT * FROM base', + } + + compiled = compiler.compile_with_ctes(sql, 'test_model', config, internal_deps) + + assert 'WITH' in compiled.sql.upper() + assert 'intermediate' in compiled.sql + + +@pytest.mark.integration +class TestDbtConfig: + """Integration tests for configuration parsing.""" + + def test_parse_config_block_simple(self): + """Test parsing simple config block.""" + sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT * FROM blocks +''' + sql_without_config, config = parse_config_block(sql) + + assert 'config' not in sql_without_config + assert 'SELECT * FROM blocks' in sql_without_config + assert config.dependencies == {'eth': '_/eth_firehose@1.0.0'} + + def test_parse_config_block_multiple_deps(self): + """Test parsing config block with multiple dependencies.""" + sql = '''{{ config( + dependencies={ + 'eth': '_/eth_firehose@1.0.0', + 'arb': '_/arb_firehose@1.0.0' + } +) }} +SELECT * FROM blocks +''' + sql_without_config, config = parse_config_block(sql) + + assert len(config.dependencies) == 2 + assert config.dependencies['eth'] == '_/eth_firehose@1.0.0' + assert config.dependencies['arb'] == '_/arb_firehose@1.0.0' + + def test_parse_config_block_with_flags(self): + """Test parsing config block with boolean flags.""" + sql = '''{{ config(track_progress=True, register=True) }} +SELECT * FROM blocks +''' + sql_without_config, config = parse_config_block(sql) + + assert config.track_progress is True + assert config.register is True + + def test_parse_config_block_no_config(self): + """Test parsing SQL without config block.""" + sql = 'SELECT * FROM blocks' + sql_without_config, config = parse_config_block(sql) + + assert sql_without_config == sql + assert config.dependencies == {} + + def test_load_project_config_exists(self, tmp_path): + """Test loading existing project config.""" + config_data = { + 'name': 'test_project', + 'version': '1.0.0', + 'monitoring': {'alert_threshold_minutes': 60}, + } + config_path = tmp_path / 'dbt_project.yml' + with open(config_path, 'w') as f: + yaml.dump(config_data, f) + + config = load_project_config(tmp_path) + + assert config['name'] == 'test_project' + assert config['monitoring']['alert_threshold_minutes'] == 60 + + def test_load_project_config_missing(self, tmp_path): + """Test loading missing project config returns empty dict.""" + config = load_project_config(tmp_path) + assert config == {} + + def test_load_project_config_invalid_yaml(self, tmp_path): + """Test loading invalid YAML raises error.""" + config_path = tmp_path / 'dbt_project.yml' + config_path.write_text('invalid: yaml: content: [unclosed') + + with pytest.raises(ConfigError): + load_project_config(tmp_path) + + +@pytest.mark.integration +class TestDbtDependencies: + """Integration tests for dependency resolution.""" + + def test_build_dependency_graph(self): + """Test building dependency graph from compiled models.""" + compiled_models = { + 'base': CompiledModel( + name='base', + sql='SELECT 1', + config=ModelConfig(), + dependencies={}, + raw_sql='SELECT 1', + ), + 'dependent': CompiledModel( + name='dependent', + sql='SELECT * FROM base', + config=ModelConfig(), + dependencies={'base': 'base'}, + raw_sql='SELECT * FROM {{ ref("base") }}', + ), + } + + graph = build_dependency_graph(compiled_models) + + assert 'base' in graph.get_all_models() + assert 'dependent' in graph.get_all_models() + assert 'base' in graph.get_dependencies('dependent') + assert 'dependent' in graph.get_dependents('base') + + def test_dependency_graph_topological_sort(self): + """Test topological sort of dependency graph.""" + graph = DependencyGraph() + graph.add_model('base', set()) + graph.add_model('intermediate', {'base'}) + graph.add_model('final', {'intermediate'}) + + order = graph.topological_sort() + + assert order == ['base', 'intermediate', 'final'] + + def test_dependency_graph_cycle_detection(self): + """Test cycle detection in dependency graph.""" + graph = DependencyGraph() + graph.add_model('model1', {'model2'}) + graph.add_model('model2', {'model1'}) + + cycles = graph.detect_cycles() + + assert len(cycles) > 0 + assert 'model1' in cycles[0] + assert 'model2' in cycles[0] + + def test_dependency_graph_get_execution_order(self): + """Test getting execution order for selected models.""" + graph = DependencyGraph() + graph.add_model('base', set()) + graph.add_model('intermediate', {'base'}) + graph.add_model('final', {'intermediate'}) + graph.add_model('other', set()) + + order = graph.get_execution_order({'final'}) + + assert 'base' in order + assert 'intermediate' in order + assert 'final' in order + assert 'other' not in order + assert order.index('base') < order.index('intermediate') + assert order.index('intermediate') < order.index('final') + + +@pytest.mark.integration +class TestDbtState: + """Integration tests for state management.""" + + def test_state_database_initialization(self, tmp_path): + """Test state database initialization.""" + db_path = tmp_path / 'state.db' + db = StateDatabase(db_path) + + assert db_path.exists() + + def test_state_database_update_and_get(self, tmp_path): + """Test updating and getting model state.""" + db_path = tmp_path / 'state.db' + db = StateDatabase(db_path) + + now = datetime.now() + db.update_model_state( + 'test_model', + latest_block=1000, + latest_timestamp=now, + job_id=123, + status='fresh', + ) + + state = db.get_model_state('test_model') + + assert state is not None + assert state.model_name == 'test_model' + assert state.latest_block == 1000 + assert state.job_id == 123 + assert state.status == 'fresh' + + def test_state_database_get_all_states(self, tmp_path): + """Test getting all model states.""" + db_path = tmp_path / 'state.db' + db = StateDatabase(db_path) + + db.update_model_state('model1', latest_block=1000, status='fresh') + db.update_model_state('model2', latest_block=2000, status='stale') + + states = db.get_all_model_states() + + assert len(states) == 2 + model_names = {s.model_name for s in states} + assert 'model1' in model_names + assert 'model2' in model_names + + +@pytest.mark.integration +class TestDbtTracker: + """Integration tests for model tracking.""" + + def test_model_tracker_initialization(self, tmp_path): + """Test model tracker initialization.""" + tracker = ModelTracker(tmp_path) + + assert tracker.project_root == tmp_path + assert tracker.state_dir == tmp_path / '.amp-dbt' + + def test_model_tracker_update_progress(self, tmp_path): + """Test updating model progress.""" + tracker = ModelTracker(tmp_path) + + now = datetime.now() + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=now, job_id=123, status='fresh') + + state = tracker.get_model_state('test_model') + + assert state is not None + assert state.latest_block == 1000 + assert state.job_id == 123 + assert state.status == 'fresh' + + def test_model_tracker_get_latest_block(self, tmp_path): + """Test getting latest block for a model.""" + tracker = ModelTracker(tmp_path) + + tracker.update_progress('test_model', latest_block=5000) + + latest_block = tracker.get_latest_block('test_model') + assert latest_block == 5000 + + def test_model_tracker_get_all_states(self, tmp_path): + """Test getting all model states.""" + tracker = ModelTracker(tmp_path) + + tracker.update_progress('model1', latest_block=1000, status='fresh') + tracker.update_progress('model2', latest_block=2000, status='stale') + + states = tracker.get_all_states() + + assert len(states) == 2 + assert 'model1' in states + assert 'model2' in states + + +@pytest.mark.integration +class TestDbtFreshnessMonitor: + """Integration tests for freshness monitoring.""" + + def test_freshness_monitor_check_fresh(self, tmp_path): + """Test checking freshness for fresh data.""" + tracker = ModelTracker(tmp_path) + monitor = FreshnessMonitor(tracker) + + # Update with recent timestamp + recent_time = datetime.now() - timedelta(minutes=10) + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=recent_time, status='fresh') + + result = monitor.check_freshness('test_model') + + assert result.stale is False + assert result.latest_block == 1000 + assert result.latest_timestamp == recent_time + + def test_freshness_monitor_check_stale(self, tmp_path): + """Test checking freshness for stale data.""" + tracker = ModelTracker(tmp_path) + monitor = FreshnessMonitor(tracker) + + # Update with old timestamp + old_time = datetime.now() - timedelta(hours=2) + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=old_time, status='fresh') + + result = monitor.check_freshness('test_model') + + assert result.stale is True + assert result.reason == 'Data is stale' + + def test_freshness_monitor_check_no_data(self, tmp_path): + """Test checking freshness when no data is tracked.""" + tracker = ModelTracker(tmp_path) + monitor = FreshnessMonitor(tracker) + + result = monitor.check_freshness('missing_model') + + assert result.stale is True + assert result.reason == 'No data tracked' + + def test_freshness_monitor_check_all(self, tmp_path): + """Test checking freshness for all models.""" + tracker = ModelTracker(tmp_path) + monitor = FreshnessMonitor(tracker) + + recent_time = datetime.now() - timedelta(minutes=10) + old_time = datetime.now() - timedelta(hours=2) + + tracker.update_progress('fresh_model', latest_block=1000, latest_timestamp=recent_time) + tracker.update_progress('stale_model', latest_block=2000, latest_timestamp=old_time) + + results = monitor.check_all_freshness() + + assert len(results) == 2 + assert results['fresh_model'].stale is False + assert results['stale_model'].stale is True + + +@pytest.mark.integration +class TestDbtEndToEnd: + """End-to-end integration tests for DBT workflows.""" + + def test_full_workflow(self, tmp_path): + """Test complete workflow: init, compile, track, monitor.""" + # Create project structure + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create project config + config_data = { + 'name': 'test_project', + 'version': '1.0.0', + 'monitoring': {'alert_threshold_minutes': 30}, + } + config_path = tmp_path / 'dbt_project.yml' + with open(config_path, 'w') as f: + yaml.dump(config_data, f) + + # Create models + (models_dir / 'base.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks +''' + ) + (models_dir / 'aggregated.sql').write_text( + '''SELECT + block_num, + COUNT(*) as tx_count +FROM {{ ref('base') }} +GROUP BY block_num +''' + ) + + # Initialize project + project = AmpDbtProject(tmp_path) + + # Compile models + compiled = project.compile_all() + assert len(compiled) == 2 + assert 'base' in compiled + assert 'aggregated' in compiled + + # Get execution order + order = project.get_execution_order() + assert order == ['base', 'aggregated'] + + # Track progress + tracker = ModelTracker(tmp_path) + now = datetime.now() + tracker.update_progress('base', latest_block=1000, latest_timestamp=now, job_id=1, status='fresh') + tracker.update_progress('aggregated', latest_block=1000, latest_timestamp=now, job_id=2, status='fresh') + + # Check freshness + monitor = FreshnessMonitor(tracker) + results = monitor.check_all_freshness() + + assert len(results) == 2 + assert results['base'].stale is False + assert results['aggregated'].stale is False + + def test_workflow_with_cte_inlining(self, tmp_path): + """Test workflow with CTE inlining for internal dependencies.""" + models_dir = tmp_path / 'models' + models_dir.mkdir() + + # Create base model + (models_dir / 'base.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks +''' + ) + + # Create dependent model that should get CTE + (models_dir / 'filtered.sql').write_text( + '''SELECT * FROM {{ ref('base') }} +WHERE block_num > 1000 +''' + ) + + project = AmpDbtProject(tmp_path) + compiled = project.compile_all() + + # Check that filtered model has base as dependency + assert 'base' in compiled['filtered'].dependencies + assert compiled['filtered'].dependencies['base'] == 'base' + + # The compiled SQL should reference base (CTE will be added during execution) + assert 'base' in compiled['filtered'].sql or '__REF__base__' not in compiled['filtered'].sql + diff --git a/tests/integration/test_dbt_cli.py b/tests/integration/test_dbt_cli.py new file mode 100644 index 0000000..b992ae7 --- /dev/null +++ b/tests/integration/test_dbt_cli.py @@ -0,0 +1,550 @@ +"""Integration tests for Amp DBT CLI commands.""" + +import json +import tempfile +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import patch + +import pytest +import yaml +from typer.testing import CliRunner + +from amp.dbt.cli import app + + +@pytest.mark.integration +class TestDbtCliInit: + """Integration tests for `amp-dbt init` command.""" + + def test_init_creates_project_structure(self, tmp_path): + """Test that init creates all required directories and files.""" + runner = CliRunner() + + result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert (tmp_path / 'dbt_project.yml').exists() + assert (tmp_path / 'models').exists() + assert (tmp_path / 'macros').exists() + assert (tmp_path / 'tests').exists() + assert (tmp_path / 'docs').exists() + assert (tmp_path / '.gitignore').exists() + assert (tmp_path / 'models' / 'example_model.sql').exists() + + # Check config content + config = yaml.safe_load((tmp_path / 'dbt_project.yml').read_text()) + assert config['name'] == 'test_project' + assert config['version'] == '1.0.0' + assert 'monitoring' in config + + def test_init_with_existing_config(self, tmp_path): + """Test that init doesn't overwrite existing config.""" + runner = CliRunner() + + # Create existing config + config_path = tmp_path / 'dbt_project.yml' + existing_config = {'name': 'existing_project', 'version': '0.1.0'} + with open(config_path, 'w') as f: + yaml.dump(existing_config, f) + + result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + # Config should still be the original + loaded_config = yaml.safe_load(config_path.read_text()) + assert loaded_config['name'] == 'existing_project' + + def test_init_with_project_dir(self, tmp_path): + """Test init with custom project directory.""" + runner = CliRunner() + project_dir = tmp_path / 'custom_project' + + result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(project_dir)]) + + assert result.exit_code == 0 + assert project_dir.exists() + assert (project_dir / 'dbt_project.yml').exists() + + def test_init_defaults_project_name(self, tmp_path): + """Test that init uses directory name when project name not provided.""" + runner = CliRunner() + project_dir = tmp_path / 'my_project' + project_dir.mkdir() + + result = runner.invoke(app, ['init', '--project-dir', str(project_dir)]) + + assert result.exit_code == 0 + config = yaml.safe_load((project_dir / 'dbt_project.yml').read_text()) + assert config['name'] == 'my_project' + + def test_init_creates_gitignore(self, tmp_path): + """Test that init creates .gitignore file.""" + runner = CliRunner() + + result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + gitignore_path = tmp_path / '.gitignore' + assert gitignore_path.exists() + assert '.amp-dbt/' in gitignore_path.read_text() + + +@pytest.mark.integration +class TestDbtCliCompile: + """Integration tests for `amp-dbt compile` command.""" + + def test_compile_single_model(self, tmp_path): + """Test compiling a single model.""" + runner = CliRunner() + + # Setup project + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'test_model.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT * FROM {{ ref('eth') }}.blocks +''' + ) + + result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'test_model' in result.stdout + assert 'Compiled' in result.stdout or 'Compilation Results' in result.stdout + + def test_compile_with_select_pattern(self, tmp_path): + """Test compile with select pattern.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'stg_model1.sql').write_text('SELECT 1') + (models_dir / 'stg_model2.sql').write_text('SELECT 2') + (models_dir / 'final_model.sql').write_text('SELECT 3') + + result = runner.invoke( + app, + ['compile', '--select', 'stg_*', '--project-dir', str(tmp_path)] + ) + + assert result.exit_code == 0 + assert 'stg_model1' in result.stdout + assert 'stg_model2' in result.stdout + # final_model should not appear since it doesn't match pattern + # (unless it's in execution order output) + + def test_compile_with_show_sql(self, tmp_path): + """Test compile with --show-sql flag.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'test_model.sql').write_text('SELECT 1 as value') + + result = runner.invoke( + app, + ['compile', '--show-sql', '--project-dir', str(tmp_path)] + ) + + assert result.exit_code == 0 + assert 'Compiled SQL' in result.stdout + assert 'SELECT 1' in result.stdout + + def test_compile_no_models_found(self, tmp_path): + """Test compile when no models exist.""" + runner = CliRunner() + + result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'No models found' in result.stdout + + def test_compile_with_internal_dependencies(self, tmp_path): + """Test compile shows execution order for models with dependencies.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'base.sql').write_text('SELECT 1') + (models_dir / 'dependent.sql').write_text('SELECT * FROM {{ ref("base") }}') + + result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'base' in result.stdout + assert 'dependent' in result.stdout + + def test_compile_error_handling(self, tmp_path): + """Test compile handles errors gracefully.""" + runner = CliRunner() + + # Create invalid project directory + invalid_dir = tmp_path / 'nonexistent' / 'nested' + + result = runner.invoke(app, ['compile', '--project-dir', str(invalid_dir)]) + + # Should handle error gracefully + assert result.exit_code != 0 or 'Error' in result.stdout + + +@pytest.mark.integration +class TestDbtCliList: + """Integration tests for `amp-dbt list` command.""" + + def test_list_all_models(self, tmp_path): + """Test listing all models.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'model1.sql').write_text('SELECT 1') + (models_dir / 'model2.sql').write_text('SELECT 2') + + result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'model1' in result.stdout + assert 'model2' in result.stdout + + def test_list_with_select_pattern(self, tmp_path): + """Test list with select pattern.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'stg_model.sql').write_text('SELECT 1') + (models_dir / 'final_model.sql').write_text('SELECT 2') + + result = runner.invoke( + app, + ['list', '--select', 'stg_*', '--project-dir', str(tmp_path)] + ) + + assert result.exit_code == 0 + assert 'stg_model' in result.stdout + assert 'final_model' not in result.stdout + + def test_list_no_models(self, tmp_path): + """Test list when no models exist.""" + runner = CliRunner() + + result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'No models found' in result.stdout + + def test_list_shows_paths(self, tmp_path): + """Test that list shows model paths.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'test_model.sql').write_text('SELECT 1') + + result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'Models' in result.stdout + assert 'test_model' in result.stdout + + +@pytest.mark.integration +class TestDbtCliRun: + """Integration tests for `amp-dbt run` command.""" + + def test_run_dry_run_mode(self, tmp_path): + """Test run in dry-run mode.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'base.sql').write_text('SELECT 1') + (models_dir / 'dependent.sql').write_text('SELECT * FROM {{ ref("base") }}') + + result = runner.invoke( + app, + ['run', '--dry-run', '--project-dir', str(tmp_path)] + ) + + assert result.exit_code == 0 + assert 'Dry run mode' in result.stdout or 'Execution Plan' in result.stdout + assert 'base' in result.stdout + assert 'dependent' in result.stdout + + def test_run_updates_tracker(self, tmp_path): + """Test that run updates model tracker.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'test_model.sql').write_text('SELECT 1') + + result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + + # Check that state was updated + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + states = tracker.get_all_states() + assert 'test_model' in states + + def test_run_with_select_pattern(self, tmp_path): + """Test run with select pattern.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'stg_model.sql').write_text('SELECT 1') + (models_dir / 'final_model.sql').write_text('SELECT 2') + + result = runner.invoke( + app, + ['run', '--select', 'stg_*', '--project-dir', str(tmp_path)] + ) + + assert result.exit_code == 0 + assert 'stg_model' in result.stdout + + def test_run_no_models(self, tmp_path): + """Test run when no models exist.""" + runner = CliRunner() + + result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'No models found' in result.stdout + + def test_run_shows_execution_order(self, tmp_path): + """Test that run shows execution order.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'base.sql').write_text('SELECT 1') + (models_dir / 'intermediate.sql').write_text('SELECT * FROM {{ ref("base") }}') + (models_dir / 'final.sql').write_text('SELECT * FROM {{ ref("intermediate") }}') + + result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'Executing' in result.stdout or 'models' in result.stdout + + +@pytest.mark.integration +class TestDbtCliStatus: + """Integration tests for `amp-dbt status` command.""" + + def test_status_no_models(self, tmp_path): + """Test status when no models are tracked.""" + runner = CliRunner() + + result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'No models with tracked state' in result.stdout + + def test_status_with_tracked_models(self, tmp_path): + """Test status with tracked models.""" + runner = CliRunner() + + # Setup tracker with some state + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now()) + + result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'test_model' in result.stdout + assert 'Fresh' in result.stdout or 'Stale' in result.stdout or 'Status' in result.stdout + + def test_status_with_stale_data(self, tmp_path): + """Test status shows stale data correctly.""" + runner = CliRunner() + + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + old_time = datetime.now() - timedelta(hours=2) + tracker.update_progress('stale_model', latest_block=1000, latest_timestamp=old_time) + + result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'stale_model' in result.stdout + + def test_status_with_all_flag(self, tmp_path): + """Test status with --all flag.""" + runner = CliRunner() + + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + tracker.update_progress('model1', latest_block=1000, latest_timestamp=datetime.now()) + tracker.update_progress('model2', latest_block=2000, latest_timestamp=datetime.now()) + + result = runner.invoke(app, ['status', '--all', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'model1' in result.stdout + assert 'model2' in result.stdout + + +@pytest.mark.integration +class TestDbtCliMonitor: + """Integration tests for `amp-dbt monitor` command.""" + + def test_monitor_no_models(self, tmp_path): + """Test monitor when no models are tracked.""" + runner = CliRunner() + + result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'No tracked models' in result.stdout + + def test_monitor_with_models(self, tmp_path): + """Test monitor with tracked models.""" + runner = CliRunner() + + # Setup tracker with some state + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now(), job_id=123) + + result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'Job Monitor' in result.stdout + assert 'test_model' in result.stdout + + def test_monitor_with_job_mapping(self, tmp_path): + """Test monitor loads job mappings from jobs.json.""" + runner = CliRunner() + + # Create .amp-dbt directory and jobs.json + amp_dbt_dir = tmp_path / '.amp-dbt' + amp_dbt_dir.mkdir(parents=True) + jobs_file = amp_dbt_dir / 'jobs.json' + jobs_file.write_text(json.dumps({'model1': 123, 'model2': 456})) + + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + tracker.update_progress('model1', latest_block=1000, latest_timestamp=datetime.now(), job_id=123) + + result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)]) + + assert result.exit_code == 0 + assert 'Job Monitor' in result.stdout + + def test_monitor_with_watch_flag(self, tmp_path): + """Test monitor with --watch flag (should not hang in test).""" + runner = CliRunner() + + from amp.dbt.tracker import ModelTracker + tracker = ModelTracker(tmp_path) + tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now()) + + # Use a very short interval and expect it to exit quickly + # In watch mode, it would loop, but we'll test it exits gracefully + result = runner.invoke( + app, + ['monitor', '--watch', '--interval', '1', '--project-dir', str(tmp_path)], + # This will timeout quickly in watch mode, but we can test the initial output + ) + + # Should show monitor output + assert 'Job Monitor' in result.stdout or 'test_model' in result.stdout + + +@pytest.mark.integration +class TestDbtCliErrorHandling: + """Integration tests for CLI error handling.""" + + def test_compile_with_invalid_project_dir(self, tmp_path): + """Test compile handles invalid project directory.""" + runner = CliRunner() + + invalid_dir = tmp_path / 'nonexistent' / 'nested' + + result = runner.invoke(app, ['compile', '--project-dir', str(invalid_dir)]) + + # Should handle error gracefully + assert result.exit_code != 0 or 'Error' in result.stdout + + def test_run_with_circular_dependency(self, tmp_path): + """Test run handles circular dependencies.""" + runner = CliRunner() + + models_dir = tmp_path / 'models' + models_dir.mkdir(parents=True) + (models_dir / 'model1.sql').write_text('SELECT * FROM {{ ref("model2") }}') + (models_dir / 'model2.sql').write_text('SELECT * FROM {{ ref("model1") }}') + + result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)]) + + # Should detect circular dependency + assert result.exit_code != 0 or 'Circular' in result.stdout or 'Dependency Error' in result.stdout + + def test_list_with_invalid_project(self, tmp_path): + """Test list handles invalid project.""" + runner = CliRunner() + + invalid_dir = tmp_path / 'nonexistent' + + result = runner.invoke(app, ['list', '--project-dir', str(invalid_dir)]) + + # Should handle error gracefully + assert result.exit_code != 0 or 'Error' in result.stdout + + +@pytest.mark.integration +class TestDbtCliIntegration: + """End-to-end integration tests for CLI workflows.""" + + def test_full_cli_workflow(self, tmp_path): + """Test complete CLI workflow: init -> compile -> run -> status.""" + runner = CliRunner() + + # Step 1: Initialize project + result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + + # Step 2: Compile models + result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + + # Step 3: Run models + result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + + # Step 4: Check status + result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + assert 'example_model' in result.stdout or 'No models' in result.stdout + + def test_cli_with_custom_models(self, tmp_path): + """Test CLI with custom model files.""" + runner = CliRunner() + + # Initialize project + runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)]) + + # Add custom model + models_dir = tmp_path / 'models' + (models_dir / 'custom_model.sql').write_text( + '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }} +SELECT block_num FROM {{ ref('eth') }}.blocks LIMIT 10 +''' + ) + + # List models + result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + assert 'custom_model' in result.stdout + + # Compile models + result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)]) + assert result.exit_code == 0 + assert 'custom_model' in result.stdout +