3 changes: 0 additions & 3 deletions .gitignore
@@ -1,8 +1,5 @@
**/.DS_Store

# Temp
preparation/

# Binder
.bash_logout
.bashrc
103 changes: 83 additions & 20 deletions integration/databricks/README.md
@@ -57,51 +57,114 @@ databricks auth profiles

You should see your workspace listed.


### 5. Configure the Databricks profile (optional)

The steps above created the `DEFAULT` profile for Databricks authentication, and the
ingestion module also defaults to the `DEFAULT` profile, so authentication should work
smoothly with a single profile.

If you have multiple profiles (e.g. for different Databricks hosts), you can set the
`DATABRICKS_CONFIG_PROFILE` environment variable in `.mise.local.toml` (gitignored) to
pin the profile used for this project:

```toml
[env]
DATABRICKS_CONFIG_PROFILE = "Code17"
```

In this example, the `Code17` profile will be used instead of the `DEFAULT` one.

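For a quick sanity check that the intended profile is picked up, a minimal sketch along these lines can help (it assumes the standard `DATABRICKS_CONFIG_PROFILE` environment variable and the `DatabricksSession` builder's `profile()` method; the ingestion module resolves the profile for you, so this is purely illustrative):

```python
import os

from databricks.connect import DatabricksSession

# Resolve the profile the same way the tooling does: env var first, then DEFAULT.
profile = os.environ.get("DATABRICKS_CONFIG_PROFILE", "DEFAULT")

# Open a serverless session pinned to that profile (illustrative sketch).
spark = DatabricksSession.builder.profile(profile).serverless().getOrCreate()
print(f"Connected with profile: {profile}")
```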

## Usage

### Complete Workflow (Ingestion + Preparation)

For a complete end-to-end workflow that loads the raw data and prepares it for getML:

```bash
# From repository root
uv run --group databricks python -m integration.databricks.prepare_jaffle_shop_data_for_databricks
```

Or in Python:

```python
from databricks.connect import DatabricksSession
from integration.databricks.data import ingestion, preparation

# Create Spark session
spark = DatabricksSession.builder.serverless().getOrCreate()

# Step 1: Load raw data from GCS
loaded_tables = ingestion.load_from_gcs(
    spark=spark,
    bucket="https://static.getml.com/datasets/jaffle_shop/",
    destination_catalog="workspace",
    destination_schema="raw",
)
print(f"Loaded {len(loaded_tables)} tables")

# Step 2: Prepare weekly sales forecasting data
population_table = preparation.create_weekly_sales_by_store_with_target(
    spark,
    source_catalog="workspace",
    source_schema="raw",
    target_catalog="workspace",
    target_schema="prepared",
)

print(f"Population table ready: {population_table}")
```
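
If you then want to pull the prepared population table into getML, one possible handoff is via pandas (a rough sketch, not part of this integration; the `getml.engine.launch`, `set_project`, and `DataFrame.from_pandas` usage below is an assumption about your getML setup):

```python
import getml

getml.engine.launch()  # assumes a local getML Engine is available
getml.engine.set_project("jaffle_shop")

# Materialize the Spark table and load it into getML.
population_pdf = spark.table(population_table).toPandas()
population_df = getml.DataFrame.from_pandas(population_pdf, name="population")
print(population_df)
```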

### Python API: Ingestion Only

Load raw data from GCS to Databricks:

```python
from integration.databricks.data import ingestion

# Load all jaffle_shop tables
loaded_tables = ingestion.load_from_gcs(
    bucket="https://static.getml.com/datasets/jaffle_shop/",
    destination_schema="raw",
)
print(f"Loaded {len(loaded_tables)} tables")

# Or load specific tables
ingestion.load_from_gcs(
    destination_schema="raw",
    tables=["raw_customers", "raw_orders", "raw_items", "raw_products"],
)
```
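
To verify what actually landed, the module also exposes a `list_tables` helper (see `ingestion.py` in this PR); for example:

```python
from integration.databricks.data import ingestion

# List the tables now present in the target schema.
tables = ingestion.list_tables(catalog="workspace", schema="raw")
print(tables)
```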

### Python API: Preparation Only

Create the weekly sales forecasting population table from existing raw data:

```python
from databricks.connect import DatabricksSession
from integration.databricks.data import preparation

spark = DatabricksSession.builder.serverless().getOrCreate()

# Create population table with target variable
population_table = preparation.create_weekly_sales_by_store_with_target(
    spark,
    source_catalog="workspace",
    source_schema="raw",
    target_catalog="workspace",
    target_schema="prepared",
    table_name="weekly_sales_by_store_with_target",
)

# Use the prepared data
df = spark.table(population_table)
df.show()
```

This creates:
- A `weekly_stores` table: store-week combinations with reference dates (weeks start on Monday)
- A population view with the `next_week_sales` target variable for forecasting (see the sketch below)

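For intuition, the target can be thought of as a one-week lead of each store's weekly sales. A rough PySpark sketch (illustrative only; the intermediate table and the column names `weekly_sales`, `store_id`, and `week_start` are assumptions, not the exact objects created by `preparation`):

```python
from pyspark.sql import Window
from pyspark.sql import functions as F

weekly = spark.table("workspace.prepared.weekly_sales")  # assumed intermediate table
w = Window.partitionBy("store_id").orderBy("week_start")

population = (
    weekly
    .withColumn("next_week_sales", F.lead("weekly_sales", 1).over(w))
    .dropna(subset=["next_week_sales"])  # the last observed week has no target
)
```
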
## Troubleshooting

18 changes: 18 additions & 0 deletions integration/databricks/data/__init__.py
@@ -0,0 +1,18 @@
"""Databricks data ingestion and preparation for getML.

This module provides utilities for loading data from GCS into Databricks
and preparing weekly sales forecasting datasets.
"""

from integration.databricks.data import ingestion, preparation
from integration.databricks.data.preparation import (
    DEFAULT_POPULATION_TABLE_NAME,
    create_weekly_sales_by_store_with_target,
)

__all__ = [
    "DEFAULT_POPULATION_TABLE_NAME",
    "create_weekly_sales_by_store_with_target",
    "ingestion",
    "preparation",
]
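
A short usage sketch of the re-exports defined above (import paths follow this `__init__.py`):

```python
from integration.databricks.data import (
    DEFAULT_POPULATION_TABLE_NAME,
    create_weekly_sales_by_store_with_target,
)

print(DEFAULT_POPULATION_TABLE_NAME)
```
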
82 changes: 20 additions & 62 deletions integration/databricks/data/ingestion.py
@@ -26,13 +26,11 @@

import logging
import os
from collections.abc import Sequence
from io import BytesIO
from typing import Annotated, ClassVar, Final
from collections.abc import Sequence
from typing import Final

import requests
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
from sqlglot import exp

# ruff: noqa: E402

@@ -54,6 +52,12 @@ def _suppress_vendor_logging() -> None:
from databricks.sdk.service.catalog import VolumeType
from pyspark.sql import SparkSession

from integration.databricks.data.models import (
    IngestionTableConfig,
    SchemaLocation,
    TableLocation,
)

logger = logging.getLogger(__name__)


@@ -75,51 +79,6 @@ def _suppress_vendor_logging() -> None:
)


def _quote_identifier(raw_identifier: str, dialect: str = "databricks") -> str:
    """Quote SQL identifier using sqlglot.

    Args:
        raw_identifier: Identifier to quote.
        dialect: SQL dialect (default: databricks).

    Returns:
        Properly quoted identifier for the target dialect.
    """
    return exp.to_identifier(raw_identifier).sql(dialect=dialect)  # pyright: ignore[reportUnknownMemberType]


SqlIdentifier = Annotated[str, AfterValidator(_quote_identifier)]


class SchemaLocation(BaseModel):
    """Location identifier for a Databricks schema."""

    model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True)

    catalog: SqlIdentifier
    schema_: SqlIdentifier = Field(alias="schema")

    @property
    def qualified_name(self) -> str:
        """Return fully qualified schema name."""
        return f"{self.catalog}.{self.schema_}"


class TableConfig(BaseModel):
    """Configuration for a Delta table to be created."""

    model_config: ClassVar[ConfigDict] = ConfigDict(frozen=True)

    source_url: str
    table_name: SqlIdentifier
    location: SchemaLocation

    @property
    def full_table_name(self) -> str:
        """Return fully qualified table name."""
        return f"{self.location.catalog}.{self.location.schema_}.{self.table_name}"


def _stream_from_url_to_volume(
    workspace: WorkspaceClient, url: str, volume_path: str
) -> None:
@@ -197,7 +156,7 @@ def _create_spark_session(profile: str | None = None) -> SparkSession:
def _write_to_delta(
    spark: SparkSession,
    source_path: str,
    config: TableConfig,
    config: IngestionTableConfig,
) -> None:
    """Write a parquet file from Volume to Delta Lake as a managed table."""
    logger.info(f"Writing to Delta table: {config.full_table_name} from {source_path}")
@@ -224,13 +183,12 @@ def _build_table_configs(
    bucket: str,
    table_names: Sequence[str],
    location: SchemaLocation,
) -> list[TableConfig]:
    """Build TableConfig objects for all tables to be loaded."""
) -> list[IngestionTableConfig]:
    """Build IngestionTableConfig objects for all tables to be loaded."""
    return [
        TableConfig(
        IngestionTableConfig(
            source_url=f"{bucket.rstrip('/')}/{name}.parquet",
            table_name=name,
            location=location,
            destination=TableLocation(table_name=name, location=location),
        )
        for name in table_names
    ]
@@ -248,7 +206,7 @@ def _process_single_table(
def _process_single_table(
    workspace: WorkspaceClient,
    spark: SparkSession,
    config: TableConfig,
    config: IngestionTableConfig,
    volume_path: str,
) -> bool:
    """Process a single table: download, write to Delta, and cleanup.
@@ -266,7 +224,7 @@ def _process_single_table(
        return False

    except Exception as e:
        logger.error(f"Failed to process {config.table_name}: {e}")
        logger.error(f"Failed to process {config.destination.table_name}: {e}")
        return False

    finally:
@@ -275,8 +233,8 @@ def load_from_gcs(

def load_from_gcs(
    bucket: str = DEFAULT_BUCKET,
    destination_schema: str = DEFAULT_SCHEMA,
    destination_catalog: str = DEFAULT_CATALOG,
    destination_schema: str = DEFAULT_SCHEMA,
    tables: Sequence[str] | None = None,
    spark: SparkSession | None = None,
    profile: str = DEFAULT_PROFILE,
@@ -290,8 +248,8 @@ def load_from_gcs(

    Args:
        bucket: GCS bucket URL (default: jaffle_shop dataset).
        destination_schema: Target schema name in Databricks.
        destination_catalog: Target catalog name in Databricks.
        destination_schema: Target schema name in Databricks.
        tables: List of table names to load. If None, loads all jaffle_shop tables.
        spark: Optional existing SparkSession. If None, creates a new one.
        profile: Databricks CLI profile name (optional).
@@ -331,11 +289,11 @@ for config in table_configs:
    for config in table_configs:
        volume_path = (
            f"/Volumes/{location.catalog}/{location.schema_}/{STAGING_VOLUME}"
            f"/{config.table_name}.parquet"
            f"/{config.destination.table_name}.parquet"
        )

        if _process_single_table(workspace, spark, config, volume_path):
            loaded_tables.append(config.table_name)
            loaded_tables.append(config.destination.table_name)

    logger.info(
        f"Successfully loaded {len(loaded_tables)}/{len(table_names)} tables"
@@ -348,8 +306,8 @@ def list_tables(


def list_tables(
    schema: str = DEFAULT_SCHEMA,
    catalog: str = DEFAULT_CATALOG,
    schema: str = DEFAULT_SCHEMA,
    spark: SparkSession | None = None,
    profile: str | None = None,
) -> list[str]: