-
Notifications
You must be signed in to change notification settings - Fork 1
feat: add dbt integration (Phase 1) #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Implement core dbt module for query composition with Jinja templating, external dataset reference resolution, and project initialization. Includes CLI commands (`init`, `compile`, `run`), model compilation engine, and comprehensive documentation.
🎨 Ruff Formatting & Linting Report

Run the following commands locally to fix issues: ruff format .
ruff check . --fix

Formatting changes needed:

--- src/amp/admin/datasets.py
+++ src/amp/admin/datasets.py
@@ -200,7 +200,9 @@
response = self._admin._request('GET', path)
return response.json()
- def describe(self, namespace: str, name: str, revision: str = 'latest') -> Dict[str, List[Dict[str, Union[str, bool]]]]:
+ def describe(
+ self, namespace: str, name: str, revision: str = 'latest'
+ ) -> Dict[str, List[Dict[str, Union[str, bool]]]]:
"""Get a structured summary of tables and columns in a dataset.
Returns a dictionary mapping table names to lists of column information,
--- src/amp/dbt/__init__.py
+++ src/amp/dbt/__init__.py
@@ -22,4 +22,3 @@
'DependencyError',
'ProjectNotFoundError',
]
-
--- src/amp/dbt/cli.py
+++ src/amp/dbt/cli.py
@@ -20,7 +20,9 @@
@app.command()
def init(
project_name: Optional[str] = typer.Argument(None, help='Project name (default: current directory name)'),
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
):
"""Initialize a new Amp DBT project."""
if project_dir is None:
@@ -41,7 +43,7 @@
# Create parent directories if they don't exist
project_dir.mkdir(parents=True, exist_ok=True)
-
+
models_dir.mkdir(parents=True, exist_ok=True)
macros_dir.mkdir(parents=True, exist_ok=True)
tests_dir.mkdir(parents=True, exist_ok=True)
@@ -80,7 +82,7 @@
# Create example model
example_model_path = models_dir / 'example_model.sql'
if not example_model_path.exists():
- example_model = '''-- Example model
+ example_model = """-- Example model
{{ config(
dependencies={'eth': '_/eth_firehose@1.0.0'},
description='Example model showing how to use ref()'
@@ -92,7 +94,7 @@
timestamp
FROM {{ ref('eth') }}.blocks
LIMIT 10
-'''
+"""
example_model_path.write_text(example_model)
console.print(f' [green]✓[/green] Created example model: {example_model_path}')
@@ -107,7 +109,9 @@
def compile(
select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to select models (e.g., stg_*)'),
show_sql: bool = typer.Option(False, '--show-sql', help='Show compiled SQL'),
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
):
"""Compile models."""
try:
@@ -131,22 +135,21 @@
# Separate internal and external dependencies
internal_deps = [k for k, v in compiled.dependencies.items() if k == v]
external_deps = [f'{k}: {v}' for k, v in compiled.dependencies.items() if k != v]
-
+
deps_parts = []
if internal_deps:
deps_parts.append(f'Internal: {", ".join(internal_deps)}')
if external_deps:
deps_parts.append(f'External: {", ".join(external_deps)}')
-
+
deps_str = ' | '.join(deps_parts) if deps_parts else 'None'
table.add_row(name, '✓ Compiled', deps_str)
console.print(table)
-
+
# Show execution order if there are internal dependencies
has_internal = any(
- any(k == v for k, v in compiled.dependencies.items())
- for compiled in compiled_models.values()
+ any(k == v for k, v in compiled.dependencies.items()) for compiled in compiled_models.values()
)
if has_internal:
try:
@@ -176,7 +179,9 @@
@app.command()
def list(
select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to filter models'),
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
):
"""List all models in the project."""
try:
@@ -210,7 +215,9 @@
@app.command()
def run(
select: Optional[str] = typer.Option(None, '--select', '-s', help='Glob pattern to select models (e.g., stg_*)'),
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
dry_run: bool = typer.Option(False, '--dry-run', help='Show what would be executed without running'),
):
"""Execute models in dependency order."""
@@ -280,13 +287,16 @@
except Exception as e:
console.print(f'[bold red]Unexpected error:[/bold red] {e}')
import traceback
+
traceback.print_exc()
sys.exit(1)
@app.command()
def status(
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
all: bool = typer.Option(False, '--all', help='Show status for all models'),
):
"""Check data freshness status."""
@@ -342,13 +352,16 @@
except Exception as e:
console.print(f'[bold red]Error:[/bold red] {e}')
import traceback
+
traceback.print_exc()
sys.exit(1)
@app.command()
def monitor(
- project_dir: Optional[Path] = typer.Option(None, '--project-dir', help='Project directory (default: current directory)'),
+ project_dir: Optional[Path] = typer.Option(
+ None, '--project-dir', help='Project directory (default: current directory)'
+ ),
watch: bool = typer.Option(False, '--watch', help='Auto-refresh dashboard'),
interval: int = typer.Option(5, '--interval', help='Refresh interval in seconds (default: 5)'),
):
@@ -399,11 +412,7 @@
for model_name, state in states.items():
job_id_str = str(state.job_id) if state.job_id else '-'
block_str = str(state.latest_block) if state.latest_block else '-'
- updated_str = (
- state.last_updated.strftime('%Y-%m-%d %H:%M:%S')
- if state.last_updated
- else '-'
- )
+ updated_str = state.last_updated.strftime('%Y-%m-%d %H:%M:%S') if state.last_updated else '-'
# Status styling
if state.status == 'fresh':
@@ -447,6 +456,7 @@
except Exception as e:
console.print(f'[bold red]Error:[/bold red] {e}')
import traceback
+
traceback.print_exc()
sys.exit(1)
@@ -458,4 +468,3 @@
if __name__ == '__main__':
main()
-
--- src/amp/dbt/compiler.py
+++ src/amp/dbt/compiler.py
@@ -215,7 +215,7 @@
# Flatten all CTEs from dependencies into a single WITH clause
# Also check if final_sql already has a WITH clause and merge it
all_ctes = {} # Map of CTE name -> SQL
-
+
if internal_deps:
for dep_name, dep_sql in internal_deps.items():
# Extract CTEs and the final SELECT from dependency SQL
@@ -224,7 +224,7 @@
all_ctes.update(ctes_from_dep)
# Add this dependency as a CTE using just its SELECT part
all_ctes[dep_name] = select_part
-
+
# Check if final_sql already has a WITH clause (from the model's own SQL)
final_sql_upper = final_sql.upper().strip()
if final_sql_upper.startswith('WITH '):
@@ -234,7 +234,7 @@
all_ctes.update(model_ctes)
# Use the model's SELECT as the final SQL
final_sql = model_select
-
+
# Build final CTE section if we have any CTEs
if all_ctes:
cte_parts = []
@@ -256,41 +256,41 @@
def _extract_ctes_and_select(self, sql: str) -> tuple[Dict[str, str], str]:
"""Extract CTEs and final SELECT from SQL that may contain WITH clauses.
-
+
Returns:
Tuple of (dict of CTE name -> SQL, final SELECT statement)
"""
import re
-
+
sql_upper = sql.upper().strip()
-
+
# If no WITH clause, return empty CTEs and the SQL as SELECT
if 'WITH ' not in sql_upper:
return {}, sql.strip()
-
+
ctes = {}
-
+
# Use regex to find all CTE definitions: name AS (content)
# Pattern: word AS ( ... ) optionally followed by comma
# We need to handle nested parentheses correctly
with_match = re.search(r'^WITH\s+', sql_upper)
if not with_match:
return {}, sql.strip()
-
+
after_with_start = with_match.end()
after_with = sql[after_with_start:]
-
+
# Find where WITH clause ends (the SELECT that's not inside parentheses)
# Count parentheses to find the SELECT after all CTEs
paren_count = 0
in_string = False
string_char = None
i = 0
-
+
while i < len(after_with):
char = after_with[i]
- prev_char = after_with[i-1] if i > 0 else ''
-
+ prev_char = after_with[i - 1] if i > 0 else ''
+
# Track string literals
if char in ("'", '"') and prev_char != '\\':
if not in_string:
@@ -301,11 +301,11 @@
string_char = None
i += 1
continue
-
+
if in_string:
i += 1
continue
-
+
# Track parentheses
if char == '(':
paren_count += 1
@@ -313,23 +313,23 @@
paren_count -= 1
# When we're back to 0 or negative, check if SELECT follows
if paren_count <= 0:
- remaining = after_with[i+1:].strip()
+ remaining = after_with[i + 1 :].strip()
if remaining.upper().startswith('SELECT'):
# Found the final SELECT - extract CTEs from before this point
- with_clause_part = after_with[:i+1]
+ with_clause_part = after_with[: i + 1]
ctes = self._parse_ctes_from_with('WITH ' + with_clause_part)
select_part = remaining
return ctes, select_part
-
+
i += 1
-
+
# Fallback: use regex-based extraction
return self._simple_extract_ctes(sql)
-
+
def _simple_extract_ctes(self, sql: str) -> tuple[Dict[str, str], str]:
"""Simple fallback to extract CTEs using regex."""
import re
-
+
# Find WHERE WITH ends and SELECT begins
# Pattern: WITH ... ) SELECT (the SELECT after closing paren of last CTE)
# We want to find the SELECT that comes after the WITH clause ends
@@ -338,30 +338,31 @@
# Find the position of SELECT (group 1 start)
select_start = match.start(1)
return {}, sql[select_start:].strip()
-
+
# Alternative: look for SELECT after WITH
match2 = re.search(r'WITH\s+.*?(SELECT\s+)', sql, re.IGNORECASE | re.DOTALL)
if match2:
select_start = match2.start(1)
return {}, sql[select_start:].strip()
-
+
return {}, sql.strip()
-
+
def _parse_ctes_from_with(self, with_clause: str) -> Dict[str, str]:
"""Parse CTE definitions from a WITH clause.
-
+
Args:
with_clause: SQL starting with WITH and ending with closing paren
-
+
Returns:
Dictionary mapping CTE names to their SQL
"""
import re
+
ctes = {}
-
- # Remove the leading "WITH "
+
+ # Remove the leading "WITH "
content = with_clause[5:].strip() if with_clause.upper().startswith('WITH ') else with_clause
-
+
# Use regex to find CTE patterns: name AS (content)
# Handle nested parentheses by matching balanced parens
# Pattern: (\w+) AS \( ... \)
@@ -371,20 +372,20 @@
as_match = re.search(r'\s+AS\s+\(', content[i:], re.IGNORECASE)
if not as_match:
break
-
+
as_pos = i + as_match.start()
-
+
# Find CTE name before " AS ("
name_end = as_pos
name_start = name_end
- while name_start > 0 and (content[name_start-1].isalnum() or content[name_start-1] == '_'):
+ while name_start > 0 and (content[name_start - 1].isalnum() or content[name_start - 1] == '_'):
name_start -= 1
cte_name = content[name_start:name_end].strip()
-
+
if not cte_name:
i = as_pos + as_match.end()
continue
-
+
# Find matching closing paren
# as_match positions are relative to content[i:], so add i to get absolute position
paren_start = i + as_match.end() - 1 # Position of (
@@ -392,11 +393,11 @@
in_string = False
string_char = None
j = paren_start + 1
-
+
while j < len(content) and paren_count > 0:
char = content[j]
- prev_char = content[j-1] if j > 0 else ''
-
+ prev_char = content[j - 1] if j > 0 else ''
+
# Track string literals
if char in ("'", '"') and prev_char != '\\':
if not in_string:
@@ -407,60 +408,60 @@
string_char = None
j += 1
continue
-
+
if in_string:
j += 1
continue
-
+
if char == '(':
paren_count += 1
elif char == ')':
paren_count -= 1
-
+
j += 1
-
+
if paren_count == 0:
# Found matching closing paren
- cte_sql = content[paren_start + 1:j - 1].strip()
+ cte_sql = content[paren_start + 1 : j - 1].strip()
ctes[cte_name] = cte_sql
i = j # Continue after this CTE
else:
break
-
+
return ctes
-
+
def _extract_select_from_sql(self, sql: str) -> str:
"""Extract the final SELECT statement from SQL that may contain CTEs.
-
+
If the SQL has WITH clauses, returns just the SELECT part after all WITHs.
Otherwise returns the SQL as-is.
-
+
Args:
sql: SQL that may contain WITH clauses
-
+
Returns:
SQL with only the final SELECT statement (CTEs removed)
"""
sql_upper = sql.upper().strip()
-
+
# If no WITH clause, return as-is
if 'WITH ' not in sql_upper:
return sql.strip()
-
+
# Find the final SELECT that comes after all CTE definitions
# Pattern: WITH cte1 AS (...), cte2 AS (...) SELECT ...
# We need to find the SELECT that's not inside a CTE definition
-
+
# Simple approach: find the last closing paren that's followed by SELECT
# This indicates the end of the CTE list
paren_count = 0
in_string = False
string_char = None
-
+
for i in range(len(sql)):
char = sql[i]
- prev_char = sql[i-1] if i > 0 else ''
-
+ prev_char = sql[i - 1] if i > 0 else ''
+
# Track string literals
if char in ("'", '"') and prev_char != '\\':
if not in_string:
@@ -470,10 +471,10 @@
in_string = False
string_char = None
continue
-
+
if in_string:
continue
-
+
# Track parentheses
if char == '(':
paren_count += 1
@@ -482,21 +483,21 @@
# Check if this closes the last CTE (paren_count back to 0 or negative)
if paren_count <= 0:
# Look ahead for SELECT (skip whitespace and newlines)
- remaining = sql[i+1:].strip()
+ remaining = sql[i + 1 :].strip()
if remaining.upper().startswith('SELECT'):
return remaining
# Also check if there's a comma then more CTEs, then SELECT
if remaining.startswith(',') or remaining.startswith('\n'):
# Continue searching
continue
-
+
# Fallback: if parsing failed, try regex to find SELECT after WITH
import re
+
# Match: WITH ... SELECT (capture the SELECT part)
match = re.search(r'WITH\s+.*?\)\s+SELECT\s+', sql, re.IGNORECASE | re.DOTALL)
if match:
- return sql[match.end() - len('SELECT '):].strip()
-
+ return sql[match.end() - len('SELECT ') :].strip()
+
# Last resort: return as-is (better than failing)
return sql.strip()
-
--- src/amp/dbt/config.py
+++ src/amp/dbt/config.py
@@ -39,7 +39,7 @@
config = ModelConfig()
# Extract dependencies
- deps_match = re.search(r"dependencies\s*=\s*\{([^}]+)\}", config_str, re.DOTALL)
+ deps_match = re.search(r'dependencies\s*=\s*\{([^}]+)\}', config_str, re.DOTALL)
if deps_match:
deps_str = deps_match.group(1)
dependencies = {}
@@ -123,4 +123,3 @@
raise ConfigError(f'Invalid YAML in profiles.yml: {e}') from e
except Exception as e:
raise ConfigError(f'Failed to load profiles.yml: {e}') from e
-
--- src/amp/dbt/dependencies.py
+++ src/amp/dbt/dependencies.py
@@ -212,4 +212,3 @@
graph.add_model(model_name, internal_deps)
return graph
-
--- src/amp/dbt/exceptions.py
+++ src/amp/dbt/exceptions.py
@@ -29,4 +29,3 @@
"""Raised when dependency resolution fails."""
pass
-
--- src/amp/dbt/models.py
+++ src/amp/dbt/models.py
@@ -26,4 +26,3 @@
config: ModelConfig
dependencies: Dict[str, str] # Maps ref name to dataset reference
raw_sql: str # Original SQL before compilation
-
--- src/amp/dbt/monitor.py
+++ src/amp/dbt/monitor.py
@@ -118,4 +118,3 @@
pass
return results
-
--- src/amp/dbt/project.py
+++ src/amp/dbt/project.py
@@ -168,9 +168,7 @@
internal_deps[dep_name] = final_compiled[dep_name].sql
# Compile with CTEs
- compiled = self.compiler.compile_with_ctes(
- sql, model_name, config, internal_deps, available_models
- )
+ compiled = self.compiler.compile_with_ctes(sql, model_name, config, internal_deps, available_models)
final_compiled[model_name] = compiled
return final_compiled
@@ -194,12 +192,9 @@
try:
sql, config = self.load_model(model_path)
model_name = model_path.stem
- compiled[model_name] = self.compiler.compile(
- sql, model_name, config, {p.stem for p in model_paths}
- )
+ compiled[model_name] = self.compiler.compile(sql, model_name, config, {p.stem for p in model_paths})
except Exception:
pass # Skip models that fail to compile
graph = build_dependency_graph(compiled)
return graph.topological_sort()
-
--- src/amp/dbt/state.py
+++ src/amp/dbt/state.py
@@ -65,7 +65,7 @@
cursor = conn.cursor()
# Create model_state table
- cursor.execute('''
+ cursor.execute("""
CREATE TABLE IF NOT EXISTS model_state (
model_name TEXT PRIMARY KEY,
connection_name TEXT,
@@ -75,10 +75,10 @@
job_id INTEGER,
status TEXT
)
- ''')
+ """)
# Create job_history table
- cursor.execute('''
+ cursor.execute("""
CREATE TABLE IF NOT EXISTS job_history (
job_id INTEGER PRIMARY KEY,
model_name TEXT,
@@ -88,7 +88,7 @@
final_block INTEGER,
rows_processed INTEGER
)
- ''')
+ """)
# Create indexes
cursor.execute('CREATE INDEX IF NOT EXISTS idx_model_state_status ON model_state(status)')
@@ -111,9 +111,7 @@
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
- cursor.execute(
- 'SELECT * FROM model_state WHERE model_name = ?', (model_name,)
- )
+ cursor.execute('SELECT * FROM model_state WHERE model_name = ?', (model_name,))
row = cursor.fetchone()
conn.close()
@@ -124,12 +122,8 @@
model_name=row['model_name'],
connection_name=row['connection_name'],
latest_block=row['latest_block'],
- latest_timestamp=datetime.fromisoformat(row['latest_timestamp'])
- if row['latest_timestamp']
- else None,
- last_updated=datetime.fromisoformat(row['last_updated'])
- if row['last_updated']
- else None,
+ latest_timestamp=datetime.fromisoformat(row['latest_timestamp']) if row['latest_timestamp'] else None,
+ last_updated=datetime.fromisoformat(row['last_updated']) if row['last_updated'] else None,
job_id=row['job_id'],
status=row['status'] or 'unknown',
)
@@ -184,17 +178,15 @@
updates.append('last_updated = CURRENT_TIMESTAMP')
params.append(model_name)
- cursor.execute(
- f'UPDATE model_state SET {", ".join(updates)} WHERE model_name = ?', params
- )
+ cursor.execute(f'UPDATE model_state SET {", ".join(updates)} WHERE model_name = ?', params)
else:
# Insert new record
cursor.execute(
- '''
+ """
INSERT INTO model_state
(model_name, connection_name, latest_block, latest_timestamp, job_id, status)
VALUES (?, ?, ?, ?, ?, ?)
- ''',
+ """,
(
model_name,
connection_name,
@@ -218,11 +210,11 @@
cursor = conn.cursor()
cursor.execute(
- '''
+ """
INSERT OR REPLACE INTO job_history
(job_id, model_name, status, started_at, completed_at, final_block, rows_processed)
VALUES (?, ?, ?, ?, ?, ?, ?)
- ''',
+ """,
(
job_history.job_id,
job_history.model_name,
@@ -261,9 +253,7 @@
latest_timestamp=datetime.fromisoformat(row['latest_timestamp'])
if row['latest_timestamp']
else None,
- last_updated=datetime.fromisoformat(row['last_updated'])
- if row['last_updated']
- else None,
+ last_updated=datetime.fromisoformat(row['last_updated']) if row['last_updated'] else None,
job_id=row['job_id'],
status=row['status'] or 'unknown',
)
@@ -303,16 +293,11 @@
job_id=row['job_id'],
model_name=row['model_name'],
status=row['status'],
- started_at=datetime.fromisoformat(row['started_at'])
- if row['started_at']
- else None,
- completed_at=datetime.fromisoformat(row['completed_at'])
- if row['completed_at']
- else None,
+ started_at=datetime.fromisoformat(row['started_at']) if row['started_at'] else None,
+ completed_at=datetime.fromisoformat(row['completed_at']) if row['completed_at'] else None,
final_block=row['final_block'],
rows_processed=row['rows_processed'],
)
)
return history
-
--- src/amp/dbt/tracker.py
+++ src/amp/dbt/tracker.py
@@ -175,4 +175,3 @@
results[model_name] = self.check_freshness(model_name)
return results
-
--- src/amp/registry/datasets.py
+++ src/amp/registry/datasets.py
@@ -199,7 +199,9 @@
response = self._registry._request('GET', path)
return response.json()
- def describe(self, namespace: str, name: str, version: str = 'latest') -> Dict[str, List[Dict[str, Union[str, bool]]]]:
+ def describe(
+ self, namespace: str, name: str, version: str = 'latest'
+ ) -> Dict[str, List[Dict[str, Union[str, bool]]]]:
"""Get a structured summary of tables and columns in a dataset.
Returns a dictionary mapping table names to lists of column information,
--- tests/integration/test_dbt.py
+++ tests/integration/test_dbt.py
@@ -82,9 +82,9 @@
models_dir = tmp_path / 'models'
models_dir.mkdir()
- model_sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ model_sql = """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT * FROM {{ ref('eth') }}.blocks
-'''
+"""
(models_dir / 'test_model.sql').write_text(model_sql)
project = AmpDbtProject(tmp_path)
@@ -105,10 +105,10 @@
models_dir = tmp_path / 'models'
models_dir.mkdir()
- model_sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ model_sql = """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks
LIMIT 10
-'''
+"""
(models_dir / 'test_model.sql').write_text(model_sql)
project = AmpDbtProject(tmp_path)
@@ -126,14 +126,14 @@
# Create multiple models
(models_dir / 'model1.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT * FROM {{ ref('eth') }}.blocks LIMIT 10
-'''
+"""
)
(models_dir / 'model2.sql').write_text(
- '''{{ config(dependencies={'arb': '_/arb_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'arb': '_/arb_firehose@1.0.0'}) }}
SELECT * FROM {{ ref('arb') }}.blocks LIMIT 10
-'''
+"""
)
project = AmpDbtProject(tmp_path)
@@ -152,16 +152,16 @@
# Create base model
(models_dir / 'base_model.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks
-'''
+"""
)
# Create dependent model
(models_dir / 'dependent_model.sql').write_text(
- '''SELECT * FROM {{ ref('base_model') }}
+ """SELECT * FROM {{ ref('base_model') }}
WHERE block_num > 1000
-'''
+"""
)
project = AmpDbtProject(tmp_path)
@@ -317,9 +317,9 @@
def test_parse_config_block_simple(self):
"""Test parsing simple config block."""
- sql = '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ sql = """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT * FROM blocks
-'''
+"""
sql_without_config, config = parse_config_block(sql)
assert 'config' not in sql_without_config
@@ -328,14 +328,14 @@
def test_parse_config_block_multiple_deps(self):
"""Test parsing config block with multiple dependencies."""
- sql = '''{{ config(
+ sql = """{{ config(
dependencies={
'eth': '_/eth_firehose@1.0.0',
'arb': '_/arb_firehose@1.0.0'
}
) }}
SELECT * FROM blocks
-'''
+"""
sql_without_config, config = parse_config_block(sql)
assert len(config.dependencies) == 2
@@ -344,9 +344,9 @@
def test_parse_config_block_with_flags(self):
"""Test parsing config block with boolean flags."""
- sql = '''{{ config(track_progress=True, register=True) }}
+ sql = """{{ config(track_progress=True, register=True) }}
SELECT * FROM blocks
-'''
+"""
sql_without_config, config = parse_config_block(sql)
assert config.track_progress is True
@@ -641,17 +641,17 @@
# Create models
(models_dir / 'base.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks
-'''
+"""
)
(models_dir / 'aggregated.sql').write_text(
- '''SELECT
+ """SELECT
block_num,
COUNT(*) as tx_count
FROM {{ ref('base') }}
GROUP BY block_num
-'''
+"""
)
# Initialize project
@@ -688,16 +688,16 @@
# Create base model
(models_dir / 'base.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT block_num, block_hash FROM {{ ref('eth') }}.blocks
-'''
+"""
)
# Create dependent model that should get CTE
(models_dir / 'filtered.sql').write_text(
- '''SELECT * FROM {{ ref('base') }}
+ """SELECT * FROM {{ ref('base') }}
WHERE block_num > 1000
-'''
+"""
)
project = AmpDbtProject(tmp_path)
@@ -709,4 +709,3 @@
# The compiled SQL should reference base (CTE will be added during execution)
assert 'base' in compiled['filtered'].sql or '__REF__base__' not in compiled['filtered'].sql
-
--- tests/integration/test_dbt_cli.py
+++ tests/integration/test_dbt_cli.py
@@ -20,9 +20,9 @@
def test_init_creates_project_structure(self, tmp_path):
"""Test that init creates all required directories and files."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert (tmp_path / 'dbt_project.yml').exists()
assert (tmp_path / 'models').exists()
@@ -31,7 +31,7 @@
assert (tmp_path / 'docs').exists()
assert (tmp_path / '.gitignore').exists()
assert (tmp_path / 'models' / 'example_model.sql').exists()
-
+
# Check config content
config = yaml.safe_load((tmp_path / 'dbt_project.yml').read_text())
assert config['name'] == 'test_project'
@@ -41,15 +41,15 @@
def test_init_with_existing_config(self, tmp_path):
"""Test that init doesn't overwrite existing config."""
runner = CliRunner()
-
+
# Create existing config
config_path = tmp_path / 'dbt_project.yml'
existing_config = {'name': 'existing_project', 'version': '0.1.0'}
with open(config_path, 'w') as f:
yaml.dump(existing_config, f)
-
+
result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
# Config should still be the original
loaded_config = yaml.safe_load(config_path.read_text())
@@ -59,9 +59,9 @@
"""Test init with custom project directory."""
runner = CliRunner()
project_dir = tmp_path / 'custom_project'
-
+
result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(project_dir)])
-
+
assert result.exit_code == 0
assert project_dir.exists()
assert (project_dir / 'dbt_project.yml').exists()
@@ -71,9 +71,9 @@
runner = CliRunner()
project_dir = tmp_path / 'my_project'
project_dir.mkdir()
-
+
result = runner.invoke(app, ['init', '--project-dir', str(project_dir)])
-
+
assert result.exit_code == 0
config = yaml.safe_load((project_dir / 'dbt_project.yml').read_text())
assert config['name'] == 'my_project'
@@ -81,9 +81,9 @@
def test_init_creates_gitignore(self, tmp_path):
"""Test that init creates .gitignore file."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
gitignore_path = tmp_path / '.gitignore'
assert gitignore_path.exists()
@@ -97,18 +97,18 @@
def test_compile_single_model(self, tmp_path):
"""Test compiling a single model."""
runner = CliRunner()
-
+
# Setup project
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'test_model.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT * FROM {{ ref('eth') }}.blocks
-'''
+"""
)
-
+
result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'test_model' in result.stdout
assert 'Compiled' in result.stdout or 'Compilation Results' in result.stdout
@@ -116,18 +116,15 @@
def test_compile_with_select_pattern(self, tmp_path):
"""Test compile with select pattern."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'stg_model1.sql').write_text('SELECT 1')
(models_dir / 'stg_model2.sql').write_text('SELECT 2')
(models_dir / 'final_model.sql').write_text('SELECT 3')
-
- result = runner.invoke(
- app,
- ['compile', '--select', 'stg_*', '--project-dir', str(tmp_path)]
- )
-
+
+ result = runner.invoke(app, ['compile', '--select', 'stg_*', '--project-dir', str(tmp_path)])
+
assert result.exit_code == 0
assert 'stg_model1' in result.stdout
assert 'stg_model2' in result.stdout
@@ -137,16 +134,13 @@
def test_compile_with_show_sql(self, tmp_path):
"""Test compile with --show-sql flag."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'test_model.sql').write_text('SELECT 1 as value')
-
- result = runner.invoke(
- app,
- ['compile', '--show-sql', '--project-dir', str(tmp_path)]
- )
-
+
+ result = runner.invoke(app, ['compile', '--show-sql', '--project-dir', str(tmp_path)])
+
assert result.exit_code == 0
assert 'Compiled SQL' in result.stdout
assert 'SELECT 1' in result.stdout
@@ -154,23 +148,23 @@
def test_compile_no_models_found(self, tmp_path):
"""Test compile when no models exist."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'No models found' in result.stdout
def test_compile_with_internal_dependencies(self, tmp_path):
"""Test compile shows execution order for models with dependencies."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'base.sql').write_text('SELECT 1')
(models_dir / 'dependent.sql').write_text('SELECT * FROM {{ ref("base") }}')
-
+
result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'base' in result.stdout
assert 'dependent' in result.stdout
@@ -178,12 +172,12 @@
def test_compile_error_handling(self, tmp_path):
"""Test compile handles errors gracefully."""
runner = CliRunner()
-
+
# Create invalid project directory
invalid_dir = tmp_path / 'nonexistent' / 'nested'
-
+
result = runner.invoke(app, ['compile', '--project-dir', str(invalid_dir)])
-
+
# Should handle error gracefully
assert result.exit_code != 0 or 'Error' in result.stdout
@@ -195,14 +189,14 @@
def test_list_all_models(self, tmp_path):
"""Test listing all models."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'model1.sql').write_text('SELECT 1')
(models_dir / 'model2.sql').write_text('SELECT 2')
-
+
result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'model1' in result.stdout
assert 'model2' in result.stdout
@@ -210,17 +204,14 @@
def test_list_with_select_pattern(self, tmp_path):
"""Test list with select pattern."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'stg_model.sql').write_text('SELECT 1')
(models_dir / 'final_model.sql').write_text('SELECT 2')
-
- result = runner.invoke(
- app,
- ['list', '--select', 'stg_*', '--project-dir', str(tmp_path)]
- )
-
+
+ result = runner.invoke(app, ['list', '--select', 'stg_*', '--project-dir', str(tmp_path)])
+
assert result.exit_code == 0
assert 'stg_model' in result.stdout
assert 'final_model' not in result.stdout
@@ -228,22 +219,22 @@
def test_list_no_models(self, tmp_path):
"""Test list when no models exist."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'No models found' in result.stdout
def test_list_shows_paths(self, tmp_path):
"""Test that list shows model paths."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'test_model.sql').write_text('SELECT 1')
-
+
result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'Models' in result.stdout
assert 'test_model' in result.stdout
@@ -256,17 +247,14 @@
def test_run_dry_run_mode(self, tmp_path):
"""Test run in dry-run mode."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'base.sql').write_text('SELECT 1')
(models_dir / 'dependent.sql').write_text('SELECT * FROM {{ ref("base") }}')
-
- result = runner.invoke(
- app,
- ['run', '--dry-run', '--project-dir', str(tmp_path)]
- )
-
+
+ result = runner.invoke(app, ['run', '--dry-run', '--project-dir', str(tmp_path)])
+
assert result.exit_code == 0
assert 'Dry run mode' in result.stdout or 'Execution Plan' in result.stdout
assert 'base' in result.stdout
@@ -275,17 +263,18 @@
def test_run_updates_tracker(self, tmp_path):
"""Test that run updates model tracker."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'test_model.sql').write_text('SELECT 1')
-
+
result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
-
+
# Check that state was updated
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
states = tracker.get_all_states()
assert 'test_model' in states
@@ -293,41 +282,38 @@
def test_run_with_select_pattern(self, tmp_path):
"""Test run with select pattern."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'stg_model.sql').write_text('SELECT 1')
(models_dir / 'final_model.sql').write_text('SELECT 2')
-
- result = runner.invoke(
- app,
- ['run', '--select', 'stg_*', '--project-dir', str(tmp_path)]
- )
-
+
+ result = runner.invoke(app, ['run', '--select', 'stg_*', '--project-dir', str(tmp_path)])
+
assert result.exit_code == 0
assert 'stg_model' in result.stdout
def test_run_no_models(self, tmp_path):
"""Test run when no models exist."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'No models found' in result.stdout
def test_run_shows_execution_order(self, tmp_path):
"""Test that run shows execution order."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'base.sql').write_text('SELECT 1')
(models_dir / 'intermediate.sql').write_text('SELECT * FROM {{ ref("base") }}')
(models_dir / 'final.sql').write_text('SELECT * FROM {{ ref("intermediate") }}')
-
+
result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'Executing' in result.stdout or 'models' in result.stdout
@@ -339,23 +325,24 @@
def test_status_no_models(self, tmp_path):
"""Test status when no models are tracked."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'No models with tracked state' in result.stdout
def test_status_with_tracked_models(self, tmp_path):
"""Test status with tracked models."""
runner = CliRunner()
-
+
# Setup tracker with some state
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now())
-
+
result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'test_model' in result.stdout
assert 'Fresh' in result.stdout or 'Stale' in result.stdout or 'Status' in result.stdout
@@ -363,28 +350,30 @@
def test_status_with_stale_data(self, tmp_path):
"""Test status shows stale data correctly."""
runner = CliRunner()
-
+
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
old_time = datetime.now() - timedelta(hours=2)
tracker.update_progress('stale_model', latest_block=1000, latest_timestamp=old_time)
-
+
result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'stale_model' in result.stdout
def test_status_with_all_flag(self, tmp_path):
"""Test status with --all flag."""
runner = CliRunner()
-
+
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
tracker.update_progress('model1', latest_block=1000, latest_timestamp=datetime.now())
tracker.update_progress('model2', latest_block=2000, latest_timestamp=datetime.now())
-
+
result = runner.invoke(app, ['status', '--all', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'model1' in result.stdout
assert 'model2' in result.stdout
@@ -397,23 +386,24 @@
def test_monitor_no_models(self, tmp_path):
"""Test monitor when no models are tracked."""
runner = CliRunner()
-
+
result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'No tracked models' in result.stdout
def test_monitor_with_models(self, tmp_path):
"""Test monitor with tracked models."""
runner = CliRunner()
-
+
# Setup tracker with some state
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now(), job_id=123)
-
+
result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'Job Monitor' in result.stdout
assert 'test_model' in result.stdout
@@ -421,30 +411,32 @@
def test_monitor_with_job_mapping(self, tmp_path):
"""Test monitor loads job mappings from jobs.json."""
runner = CliRunner()
-
+
# Create .amp-dbt directory and jobs.json
amp_dbt_dir = tmp_path / '.amp-dbt'
amp_dbt_dir.mkdir(parents=True)
jobs_file = amp_dbt_dir / 'jobs.json'
jobs_file.write_text(json.dumps({'model1': 123, 'model2': 456}))
-
+
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
tracker.update_progress('model1', latest_block=1000, latest_timestamp=datetime.now(), job_id=123)
-
+
result = runner.invoke(app, ['monitor', '--project-dir', str(tmp_path)])
-
+
assert result.exit_code == 0
assert 'Job Monitor' in result.stdout
def test_monitor_with_watch_flag(self, tmp_path):
"""Test monitor with --watch flag (should not hang in test)."""
runner = CliRunner()
-
+
from amp.dbt.tracker import ModelTracker
+
tracker = ModelTracker(tmp_path)
tracker.update_progress('test_model', latest_block=1000, latest_timestamp=datetime.now())
-
+
# Use a very short interval and expect it to exit quickly
# In watch mode, it would loop, but we'll test it exits gracefully
result = runner.invoke(
@@ -452,7 +444,7 @@
['monitor', '--watch', '--interval', '1', '--project-dir', str(tmp_path)],
# This will timeout quickly in watch mode, but we can test the initial output
)
-
+
# Should show monitor output
assert 'Job Monitor' in result.stdout or 'test_model' in result.stdout
@@ -464,36 +456,36 @@
def test_compile_with_invalid_project_dir(self, tmp_path):
"""Test compile handles invalid project directory."""
runner = CliRunner()
-
+
invalid_dir = tmp_path / 'nonexistent' / 'nested'
-
+
result = runner.invoke(app, ['compile', '--project-dir', str(invalid_dir)])
-
+
# Should handle error gracefully
assert result.exit_code != 0 or 'Error' in result.stdout
def test_run_with_circular_dependency(self, tmp_path):
"""Test run handles circular dependencies."""
runner = CliRunner()
-
+
models_dir = tmp_path / 'models'
models_dir.mkdir(parents=True)
(models_dir / 'model1.sql').write_text('SELECT * FROM {{ ref("model2") }}')
(models_dir / 'model2.sql').write_text('SELECT * FROM {{ ref("model1") }}')
-
+
result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)])
-
+
# Should detect circular dependency
assert result.exit_code != 0 or 'Circular' in result.stdout or 'Dependency Error' in result.stdout
def test_list_with_invalid_project(self, tmp_path):
"""Test list handles invalid project."""
runner = CliRunner()
-
+
invalid_dir = tmp_path / 'nonexistent'
-
+
result = runner.invoke(app, ['list', '--project-dir', str(invalid_dir)])
-
+
# Should handle error gracefully
assert result.exit_code != 0 or 'Error' in result.stdout
@@ -505,19 +497,19 @@
def test_full_cli_workflow(self, tmp_path):
"""Test complete CLI workflow: init -> compile -> run -> status."""
runner = CliRunner()
-
+
# Step 1: Initialize project
result = runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
-
+
# Step 2: Compile models
result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
-
+
# Step 3: Run models
result = runner.invoke(app, ['run', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
-
+
# Step 4: Check status
result = runner.invoke(app, ['status', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
@@ -526,25 +518,24 @@
def test_cli_with_custom_models(self, tmp_path):
"""Test CLI with custom model files."""
runner = CliRunner()
-
+
# Initialize project
runner.invoke(app, ['init', 'test_project', '--project-dir', str(tmp_path)])
-
+
# Add custom model
models_dir = tmp_path / 'models'
(models_dir / 'custom_model.sql').write_text(
- '''{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
+ """{{ config(dependencies={'eth': '_/eth_firehose@1.0.0'}) }}
SELECT block_num FROM {{ ref('eth') }}.blocks LIMIT 10
-'''
+"""
)
-
+
# List models
result = runner.invoke(app, ['list', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
assert 'custom_model' in result.stdout
-
+
# Compile models
result = runner.invoke(app, ['compile', '--project-dir', str(tmp_path)])
assert result.exit_code == 0
assert 'custom_model' in result.stdout
-
15 files would be reformatted, 94 files already formatted.

Linting issues:
craigtutterow
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is dope. A few requests/questions:
- Could you write some integration tests for the dbt and CLI commands?
- What's the reason for recreating dbt instead of using their existing python packages?
- What's the plan for scheduled jobs, monitoring schedules, etc.? (doesn't have to be part of this PR, just curious what you are thinking)
|
For your second question — "What's the reason for recreating dbt instead of using their existing Python packages?"
Implement core dbt module for query composition with Jinja templating, external dataset reference resolution, and project initialization. Includes CLI commands (`init`, `compile`, `run`), a model compilation engine, and comprehensive documentation.