diff --git a/.gitignore b/.gitignore
index 37b234b..155d5a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,5 @@
 **/.DS_Store
 
-# Temp
-preparation/
-
 # Binder
 .bash_logout
 .bashrc
diff --git a/integration/databricks/README.md b/integration/databricks/README.md
index eb84b5b..c735cc8 100644
--- a/integration/databricks/README.md
+++ b/integration/databricks/README.md
@@ -57,51 +57,114 @@ databricks auth profiles
 
 You should see your workspace listed.
 
+
+### 5. Configure the Databricks profile (optional)
+
+The steps above created the `DEFAULT` profile for Databricks authentication. The
+ingestion module also defaults to the `DEFAULT` profile, so authentication should work
+smoothly with a single profile.
+
+If you have multiple profiles (e.g., for different Databricks hosts), you can set the
+`DATABRICKS_CONFIG_PROFILE` environment variable in `.mise.local.toml` (gitignored) to
+pin the profile used for this project:
+
+```toml
+[env]
+DATABRICKS_CONFIG_PROFILE = "Code17"
+```
+
+In this example, the `Code17` profile is used instead of the `DEFAULT` one.
+
+
 ## Usage
 
-### Python API (Recommended)
+### Complete Workflow (Ingestion + Preparation)
 
-Use the modules directly in notebooks or scripts:
+For a complete end-to-end workflow that loads raw data and prepares it for getML:
+
+```bash
+# From repository root
+uv run --group databricks python -m integration.databricks.prepare_jaffle_shop_data_for_databricks
+```
+
+Or in Python:
 
 ```python
-from integration.databricks.data import ingestion
+from databricks.connect import DatabricksSession
+from integration.databricks.data import ingestion, preparation
 
-# Load raw data from GCS to Databricks
+# Create Spark session
+spark = DatabricksSession.builder.serverless().getOrCreate()
+
+# Step 1: Load raw data from GCS
 loaded_tables = ingestion.load_from_gcs(
+    spark=spark,
     bucket="https://static.getml.com/datasets/jaffle_shop/",
-    destination_schema="jaffle_shop"
+    destination_catalog="workspace",
+    destination_schema="raw",
 )
-print(f"Loaded {len(loaded_tables)} tables")
+
+# Step 2: Prepare weekly sales forecasting data
+population_table = preparation.create_weekly_sales_by_store_with_target(
+    spark,
+    source_catalog="workspace",
+    source_schema="raw",
+    target_catalog="workspace",
+    target_schema="prepared",
+)
+
+print(f"Population table ready: {population_table}")
 ```
 
-### Load Specific Tables
+### Python API: Ingestion Only
+
+Load raw data from GCS to Databricks:
 
 ```python
 from integration.databricks.data import ingestion
 
-# Load only the tables you need
+# Load all jaffle_shop tables
+loaded_tables = ingestion.load_from_gcs(
+    bucket="https://static.getml.com/datasets/jaffle_shop/",
+    destination_schema="raw"
+)
+print(f"Loaded {len(loaded_tables)} tables")
+
+# Or load specific tables
 ingestion.load_from_gcs(
-    destination_schema="RAW",
+    destination_schema="raw",
     tables=["raw_customers", "raw_orders", "raw_items", "raw_products"]
 )
 ```
 
-### Configure the Databricks profile (optional)
+### Python API: Preparation Only
 
-The above steps created `DEFAULT` profile for Databricks authentication. The ingestion
-module also defaults to `DEFAULT` profile. The authentication should work smoothly for
-a single profile. 
+Create weekly sales forecasting population table from existing raw data: -If you have multiple profiles (e.g for different Databricks hosts), you can set -`DATABRICKS_CONFIG_PROFILE` environment variable in `.mise.local.toml` (gitignored) to -pin a specific profile to be used for this project: +```python +from databricks.connect import DatabricksSession +from integration.databricks.data import preparation + +spark = DatabricksSession.builder.serverless().getOrCreate() + +# Create population table with target variable +population_table = preparation.create_weekly_sales_by_store_with_target( + spark, + source_catalog="workspace", + source_schema="raw", + target_catalog="workspace", + target_schema="prepared", + table_name="weekly_sales_by_store_with_target", +) -```toml -[env] -DATABRICKS_CONFIG_PROFILE = "Code17" +# Use the prepared data +df = spark.table(population_table) +df.show() ``` -In this example, `Code17` profile will be used instead of `DEFAULT` one. +This creates: +- `weekly_stores` table: Store-week combinations with reference dates (Monday week starts) +- Population view with `next_week_sales` target variable for forecasting ## Troubleshooting diff --git a/integration/databricks/data/__init__.py b/integration/databricks/data/__init__.py index e69de29..f91a552 100644 --- a/integration/databricks/data/__init__.py +++ b/integration/databricks/data/__init__.py @@ -0,0 +1,18 @@ +"""Databricks data ingestion and preparation for getML. + +This module provides utilities for loading data from GCS into Databricks +and preparing weekly sales forecasting datasets. +""" + +from integration.databricks.data import ingestion, preparation +from integration.databricks.data.preparation import ( + DEFAULT_POPULATION_TABLE_NAME, + create_weekly_sales_by_store_with_target, +) + +__all__ = [ + "DEFAULT_POPULATION_TABLE_NAME", + "create_weekly_sales_by_store_with_target", + "ingestion", + "preparation", +] diff --git a/integration/databricks/data/ingestion.py b/integration/databricks/data/ingestion.py index 4d1257e..47dbffe 100644 --- a/integration/databricks/data/ingestion.py +++ b/integration/databricks/data/ingestion.py @@ -26,13 +26,11 @@ import logging import os -from collections.abc import Sequence from io import BytesIO -from typing import Annotated, ClassVar, Final +from collections.abc import Sequence +from typing import Final import requests -from pydantic import AfterValidator, BaseModel, ConfigDict, Field -from sqlglot import exp # ruff: noqa: E402 @@ -54,6 +52,12 @@ def _suppress_vendor_logging() -> None: from databricks.sdk.service.catalog import VolumeType from pyspark.sql import SparkSession +from integration.databricks.data.models import ( + IngestionTableConfig, + SchemaLocation, + TableLocation, +) + logger = logging.getLogger(__name__) @@ -75,51 +79,6 @@ def _suppress_vendor_logging() -> None: ) -def _quote_identifier(raw_identifier: str, dialect: str = "databricks") -> str: - """Quote SQL identifier using sqlglot. - - Args: - raw_identifier: Identifier to quote. - dialect: SQL dialect (default: databricks). - - Returns: - Properly quoted identifier for the target dialect. 
- """ - return exp.to_identifier(raw_identifier).sql(dialect=dialect) # pyright: ignore[reportUnknownMemberType] - - -SqlIdentifier = Annotated[str, AfterValidator(_quote_identifier)] - - -class SchemaLocation(BaseModel): - """Location identifier for a Databricks schema.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) - - catalog: SqlIdentifier - schema_: SqlIdentifier = Field(alias="schema") - - @property - def qualified_name(self) -> str: - """Return fully qualified schema name.""" - return f"{self.catalog}.{self.schema_}" - - -class TableConfig(BaseModel): - """Configuration for a Delta table to be created.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) - - source_url: str - table_name: SqlIdentifier - location: SchemaLocation - - @property - def full_table_name(self) -> str: - """Return fully qualified table name.""" - return f"{self.location.catalog}.{self.location.schema_}.{self.table_name}" - - def _stream_from_url_to_volume( workspace: WorkspaceClient, url: str, volume_path: str ) -> None: @@ -197,7 +156,7 @@ def _create_spark_session(profile: str | None = None) -> SparkSession: def _write_to_delta( spark: SparkSession, source_path: str, - config: TableConfig, + config: IngestionTableConfig, ) -> None: """Write a parquet file from Volume to Delta Lake as a managed table.""" logger.info(f"Writing to Delta table: {config.full_table_name} from {source_path}") @@ -224,13 +183,12 @@ def _build_table_configs( bucket: str, table_names: Sequence[str], location: SchemaLocation, -) -> list[TableConfig]: - """Build TableConfig objects for all tables to be loaded.""" +) -> list[IngestionTableConfig]: + """Build IngestionTableConfig objects for all tables to be loaded.""" return [ - TableConfig( + IngestionTableConfig( source_url=f"{bucket.rstrip('/')}/{name}.parquet", - table_name=name, - location=location, + destination=TableLocation(table_name=name, location=location), ) for name in table_names ] @@ -248,7 +206,7 @@ def _cleanup_volume_file(workspace: WorkspaceClient, volume_path: str) -> None: def _process_single_table( workspace: WorkspaceClient, spark: SparkSession, - config: TableConfig, + config: IngestionTableConfig, volume_path: str, ) -> bool: """Process a single table: download, write to Delta, and cleanup. @@ -266,7 +224,7 @@ def _process_single_table( return False except Exception as e: - logger.error(f"Failed to process {config.table_name}: {e}") + logger.error(f"Failed to process {config.destination.table_name}: {e}") return False finally: @@ -275,8 +233,8 @@ def _process_single_table( def load_from_gcs( bucket: str = DEFAULT_BUCKET, - destination_schema: str = DEFAULT_SCHEMA, destination_catalog: str = DEFAULT_CATALOG, + destination_schema: str = DEFAULT_SCHEMA, tables: Sequence[str] | None = None, spark: SparkSession | None = None, profile: str = DEFAULT_PROFILE, @@ -290,8 +248,8 @@ def load_from_gcs( Args: bucket: GCS bucket URL (default: jaffle_shop dataset). - destination_schema: Target schema name in Databricks. destination_catalog: Target catalog name in Databricks. + destination_schema: Target schema name in Databricks. tables: List of table names to load. If None, loads all jaffle_shop tables. spark: Optional existing SparkSession. If None, creates a new one. profile: Databricks CLI profile name (optional). 
@@ -331,11 +289,11 @@ def load_from_gcs( for config in table_configs: volume_path = ( f"/Volumes/{location.catalog}/{location.schema_}/{STAGING_VOLUME}" - f"/{config.table_name}.parquet" + f"/{config.destination.table_name}.parquet" ) if _process_single_table(workspace, spark, config, volume_path): - loaded_tables.append(config.table_name) + loaded_tables.append(config.destination.table_name) logger.info( f"Successfully loaded {len(loaded_tables)}/{len(table_names)} tables" @@ -348,8 +306,8 @@ def load_from_gcs( def list_tables( - schema: str = DEFAULT_SCHEMA, catalog: str = DEFAULT_CATALOG, + schema: str = DEFAULT_SCHEMA, spark: SparkSession | None = None, profile: str | None = None, ) -> list[str]: diff --git a/integration/databricks/data/models.py b/integration/databricks/data/models.py new file mode 100644 index 0000000..934997f --- /dev/null +++ b/integration/databricks/data/models.py @@ -0,0 +1,70 @@ +"""Data models for Databricks integration. + +This module provides validated data models for working with Databricks schemas and tables. +All models use Pydantic for validation and are frozen (immutable) by default. +""" + +from __future__ import annotations + +from typing import Annotated, ClassVar + +from pydantic import AfterValidator, BaseModel, ConfigDict, Field +from sqlglot import exp + + +def _quote_identifier(raw_identifier: str, dialect: str = "databricks") -> str: + """Quote SQL identifier using sqlglot. + + Args: + raw_identifier: Identifier to quote. + dialect: SQL dialect (default: databricks). + + Returns: + Properly quoted identifier for the target dialect. + """ + return exp.to_identifier(raw_identifier).sql(dialect=dialect) # pyright: ignore[reportUnknownMemberType] + + +SqlIdentifier = Annotated[str, AfterValidator(_quote_identifier)] + + +class SchemaLocation(BaseModel): + """Location identifier for a Databricks schema.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) + + catalog: SqlIdentifier + schema_: SqlIdentifier = Field(alias="schema") + + @property + def qualified_name(self) -> str: + """Return fully qualified schema name.""" + return f"{self.catalog}.{self.schema_}" + + +class TableLocation(BaseModel): + """Location identifier for a Delta table in Databricks.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) + + table_name: SqlIdentifier + location: SchemaLocation + + @property + def qualified_name(self) -> str: + """Return fully qualified table name (catalog.schema.table).""" + return f"{self.location.qualified_name}.{self.table_name}" + + +class IngestionTableConfig(BaseModel): + """Configuration for ingesting a table from URL to Delta Lake.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True) + + source_url: str + destination: TableLocation + + @property + def full_table_name(self) -> str: + """Return fully qualified table name.""" + return self.destination.qualified_name diff --git a/integration/databricks/data/preparation.py b/integration/databricks/data/preparation.py new file mode 100644 index 0000000..86353ee --- /dev/null +++ b/integration/databricks/data/preparation.py @@ -0,0 +1,168 @@ +"""Prepare weekly sales forecasting data for getML - by store. 
+
+This module creates:
+- weekly_stores table: Store-week combinations with reference_date (Monday week start)
+- Population view with target (next week's sales)
+"""
+
+
+# pyright: reportAny=none
+# pyright: reportUnknownMemberType=none
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Final
+
+from integration.databricks.data.models import (
+    SchemaLocation,
+    TableLocation,
+)
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
+logger = logging.getLogger(__name__)
+_SQL_DIR = Path(__file__).parent / "sql"
+
+
+REQUIRED_SOURCE_TABLES: Final[list[str]] = ["raw_stores", "raw_orders"]
+DEFAULT_POPULATION_TABLE_NAME: Final[str] = "weekly_sales_by_store_with_target"
+
+
+def create_weekly_sales_by_store_with_target(
+    spark: SparkSession,
+    source_catalog: str = "workspace",
+    source_schema: str = "raw",
+    target_catalog: str = "workspace",
+    target_schema: str = "prepared",
+    table_name: str = DEFAULT_POPULATION_TABLE_NAME,
+) -> str:
+    """Create weekly sales forecasting data for getML, grouped by store.
+
+    Creates:
+    - weekly_stores: Table with store-week combinations (reference_date = Monday)
+    - Population view with target column (configurable name via table_name)
+    - Target: Sum of order_total for the 7-day window starting at reference_date
+
+    Args:
+        spark: Active Spark session.
+        source_catalog: Source catalog name.
+        source_schema: Source schema name.
+        target_catalog: Target catalog name.
+        target_schema: Target schema name.
+        table_name: Name of the population view to create.
+
+    Returns:
+        Fully qualified table name (e.g., "workspace.prepared.weekly_sales_by_store_with_target").
+    """
+    source_location = SchemaLocation(catalog=source_catalog, schema=source_schema)
+    target_location = SchemaLocation(catalog=target_catalog, schema=target_schema)
+    population_table_location = TableLocation(
+        table_name=table_name, location=target_location
+    )
+
+    _validate_source_tables(spark, source_location)
+    _ensure_target_schema(spark, target_location.qualified_name)
+
+    _create_weekly_stores_table(
+        spark, source_location.qualified_name, target_location.qualified_name
+    )
+    _create_target_view(
+        spark,
+        source_location.qualified_name,
+        target_location.qualified_name,
+        population_table_location.table_name,
+    )
+
+    logger.info(f"""
+Objects created in '{target_location.qualified_name}' schema:
+- weekly_stores (Table)
+- {population_table_location.table_name} (View)
+""")
+
+    # Return the fully qualified name so callers can pass it straight to spark.table(),
+    # matching the docstring and the README/script usage.
+    return population_table_location.qualified_name
+
+
+def _validate_source_tables(
+    spark: SparkSession,
+    schema_location: SchemaLocation,
+    required_tables: list[str] = REQUIRED_SOURCE_TABLES,
+) -> None:
+    """Validate that source schema contains required tables."""
+    logger.info(
+        f"1. Validating '{schema_location.qualified_name}' schema and required tables..."
+    )
+
+    for table_name in required_tables:
+        table_location = TableLocation(
+            table_name=table_name,
+            location=schema_location,
+        )
+        _ = spark.sql(
+            "SELECT 1 FROM IDENTIFIER(:table_qualified_name) LIMIT 1",
+            args={"table_qualified_name": table_location.qualified_name},
+        ).collect()
+
+    logger.info(
+        f"✓ '{schema_location.qualified_name}' schema validated. It contains "
+        f"required tables: {required_tables}."
+ ) + + +def _ensure_target_schema(spark: SparkSession, target_schema: str) -> None: + """Create target schema if needed.""" + logger.info(f"Creating '{target_schema}' schema if not exists...") + + sql: str = "CREATE SCHEMA IF NOT EXISTS IDENTIFIER(:full_schema_name)" + _ = spark.sql(sql, args={"full_schema_name": target_schema}).collect() + + logger.info(f"✓ '{target_schema}' schema ready") + + +def _create_weekly_stores_table( + spark: SparkSession, + source_schema: str, + target_schema: str, +) -> None: + """Create weekly_stores table with store-week combinations.""" + logger.info("\n2. Creating weekly_stores table (store-week combinations)...") + logger.info(" reference_date is Monday (week start) from date_trunc('week', ...)") + + _ = spark.sql( + "DROP TABLE IF EXISTS IDENTIFIER(:table_qualified_name)", + args={"table_qualified_name": f"{target_schema}.weekly_stores"}, + ).collect() + _ = spark.sql( + (_SQL_DIR / "create_stores_per_week.sql").read_text(), + args={ + "weekly_stores_table": f"{target_schema}.weekly_stores", + "stores_table": f"{source_schema}.raw_stores", + "orders_table": f"{source_schema}.raw_orders", + }, + ).collect() + + +def _create_target_view( + spark: SparkSession, + source_schema: str, + target_schema: str, + table_name: str, +) -> None: + """Create view with target variable (next week's sales).""" + logger.info("\n3. Creating target view: next week's total sales per store...") + + # Use python string formatting instead of Spark SQL parameters because + # CREATE VIEW does not support parameter markers for identifiers. + sql = ( + (_SQL_DIR / "create_weekly_total_sales_per_store.sql") + .read_text() + .format( + orders_table=f"{source_schema}.raw_orders", + weekly_stores_table=f"{target_schema}.weekly_stores", + population_table=f"{target_schema}.{table_name}", + ) + ) + + _ = spark.sql(sql).collect() diff --git a/integration/databricks/data/sql/create_stores_per_week.sql b/integration/databricks/data/sql/create_stores_per_week.sql new file mode 100644 index 0000000..cddb89d --- /dev/null +++ b/integration/databricks/data/sql/create_stores_per_week.sql @@ -0,0 +1,67 @@ +-- Create store-week combinations for weekly sales forecasting +-- +-- This table creates the base data for getML: one row per store per week. +-- reference_date is the Monday (week start) derived from date_trunc('week', ordered_at). 
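+-- For example, an order placed on Wednesday 2024-01-03 falls in the week whose
+-- reference_date is Monday 2024-01-01.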
+-- +-- Filtering logic: +-- - Week must be >= store's opened_at date (store existed) +-- - Week must be < last_order_week (exclude incomplete final week) +-- +-- Boolean flags for data quality filtering: +-- - is_full_week_after_opening: Store had a full week of operation before this week +-- - has_order_activity: Store has order data spanning this week +-- - has_min_history: At least 7 days since store opened +CREATE OR REPLACE TABLE IDENTIFIER(:weekly_stores_table) AS +WITH store_activity AS ( + SELECT + s.id as store_id, + s.name as store_name, + CAST(s.opened_at AS TIMESTAMP) as opened_at, + date_add(date_trunc('week', CAST(s.opened_at AS TIMESTAMP)), 7) as first_full_week, + MIN(CAST(o.ordered_at AS TIMESTAMP)) as first_order_date, + MAX(CAST(o.ordered_at AS TIMESTAMP)) as last_order_date, + date_trunc('week', MIN(CAST(o.ordered_at AS TIMESTAMP))) as first_order_week, + date_trunc('week', MAX(CAST(o.ordered_at AS TIMESTAMP))) as last_order_week + FROM IDENTIFIER(:stores_table) s + LEFT JOIN IDENTIFIER(:orders_table) o ON o.store_id = s.id + GROUP BY s.id, s.name, s.opened_at +), + +all_weeks AS ( + SELECT DISTINCT + date_trunc('week', CAST(ordered_at AS TIMESTAMP)) as reference_date + FROM IDENTIFIER(:orders_table) + WHERE ordered_at IS NOT NULL +), + +store_weeks AS ( + SELECT + sa.store_id, + sa.store_name, + w.reference_date, + sa.opened_at, + sa.first_full_week, + sa.first_order_week, + sa.last_order_week + FROM store_activity sa + CROSS JOIN all_weeks w + WHERE w.reference_date >= sa.opened_at + AND w.reference_date < sa.last_order_week +) + +SELECT + ROW_NUMBER() OVER (ORDER BY reference_date, store_id) as snapshot_id, + store_id, + store_name, + reference_date, + YEAR(reference_date) as year, + MONTH(reference_date) as month, + WEEKOFYEAR(reference_date) as week_number, + DATEDIFF(reference_date, opened_at) as days_since_open, + CASE WHEN reference_date >= first_full_week THEN true ELSE false END as is_full_week_after_opening, + CASE WHEN first_order_week IS NOT NULL + AND reference_date >= first_order_week + AND reference_date < last_order_week THEN true ELSE false END as has_order_activity, + CASE WHEN DATEDIFF(reference_date, opened_at) >= 7 THEN true ELSE false END as has_min_history +FROM store_weeks +ORDER BY reference_date, store_id diff --git a/integration/databricks/data/sql/create_weekly_total_sales_per_store.sql b/integration/databricks/data/sql/create_weekly_total_sales_per_store.sql new file mode 100644 index 0000000..d22a764 --- /dev/null +++ b/integration/databricks/data/sql/create_weekly_total_sales_per_store.sql @@ -0,0 +1,41 @@ +-- Create view with target: next week's sales per store +-- +-- This view joins weekly_stores with pre-aggregated order totals. +-- Target is the sum of order_total for the 7-day window starting at reference_date. +-- +-- Window: [reference_date, reference_date + 7 days) +-- - reference_date is Monday 00:00:00 (week start) +-- - Target covers Monday through Sunday of that week +-- +-- Optimized: Single aggregation pass instead of correlated subqueries +-- Note: CREATE VIEW does not support parameter markers for identifiers. +-- Provide valid table names for safety. 
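+--
+-- Example: a row with reference_date 2024-01-01 (a Monday) gets
+-- next_week_sales = SUM(order_total) / 100.0 over orders whose ordered_at
+-- falls in [2024-01-01 00:00, 2024-01-08 00:00).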
+CREATE OR REPLACE VIEW {population_table} AS +WITH weekly_order_totals AS ( + SELECT + store_id, + date_trunc('week', CAST(ordered_at AS TIMESTAMP)) as week_start, + SUM(order_total) / 100.0 as week_sales, + COUNT(*) as week_orders + FROM {orders_table} + WHERE ordered_at IS NOT NULL + GROUP BY store_id, date_trunc('week', CAST(ordered_at AS TIMESTAMP)) +) +SELECT + ws.snapshot_id, + ws.store_id, + ws.store_name, + ws.reference_date, + ws.year, + ws.month, + ws.week_number, + ws.days_since_open, + ws.is_full_week_after_opening, + ws.has_order_activity, + ws.has_min_history, + COALESCE(wot.week_sales, 0) as next_week_sales, + COALESCE(wot.week_orders, 0) as next_week_orders +FROM {weekly_stores_table} ws +LEFT JOIN weekly_order_totals wot + ON wot.store_id = ws.store_id + AND wot.week_start = ws.reference_date diff --git a/integration/databricks/prepare_jaffle_shop_data_for_databricks.py b/integration/databricks/prepare_jaffle_shop_data_for_databricks.py new file mode 100755 index 0000000..84ef487 --- /dev/null +++ b/integration/databricks/prepare_jaffle_shop_data_for_databricks.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +"""Prepare Jaffle Shop data in Databricks for getML Feature Store integration. + +Loads raw data from GCS and creates the weekly sales forecasting population table. +Infrastructure (catalog, schemas) is created as needed. +""" + +from __future__ import annotations + +import logging + +from databricks.connect import DatabricksSession + +from .data import ingestion, preparation + +logger: logging.Logger = logging.getLogger(__name__) + + +def main() -> None: + """Load and prepare Jaffle Shop data for getML. + + Note: + Set basicConfig.level to logging.DEBUG for more verbose output. + """ + logging.basicConfig( + level=logging.INFO, + format="%(message)s", + ) + + logger.info("Starting Jaffle Shop data preparation for Databricks...") + + spark = DatabricksSession.builder.serverless().getOrCreate() + + logger.info("\nStep 1: Loading raw data from GCS...") + _ = ingestion.load_from_gcs( + spark=spark, + bucket="https://static.getml.com/datasets/jaffle_shop/", + destination_catalog="workspace", + destination_schema="raw", + ) + + logger.info("\nStep 2: Preparing weekly sales forecasting data...") + population_table = preparation.create_weekly_sales_by_store_with_target( + spark, + source_catalog="workspace", + source_schema="raw", + target_catalog="workspace", + target_schema="prepared", + table_name="weekly_sales_by_store_with_target", + ) + + logger.info(f""" +Data preparation complete. Use '{population_table}' table in your getML notebook: df = spark.table("{population_table}") +""") + + +if __name__ == "__main__": + main()
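After the script finishes, the prepared view can be inspected directly over Databricks Connect. A minimal sketch, assuming the default `workspace.prepared.weekly_sales_by_store_with_target` name used above:

```python
from databricks.connect import DatabricksSession

spark = DatabricksSession.builder.serverless().getOrCreate()

# One row per store per week; next_week_sales is the target for the week
# starting at reference_date.
df = spark.table("workspace.prepared.weekly_sales_by_store_with_target")

(
    df.where(df.has_order_activity & df.has_min_history)
    .select("store_name", "reference_date", "next_week_sales", "next_week_orders")
    .orderBy("reference_date", "store_name")
    .show(10, truncate=False)
)
```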