diff --git a/samples/ml/ml_jobs/README.md b/samples/ml/ml_jobs/README.md index aebc81ec..9f586207 100644 --- a/samples/ml/ml_jobs/README.md +++ b/samples/ml/ml_jobs/README.md @@ -139,6 +139,62 @@ job4 = submit_from_stage( `job1`, `job2` and `job3` are job handles, see [Function Dispatch](#function-dispatch) for usage examples. +### Job definition + +A job definition captures the reusable parts of an ML Job—payload location, compute pool, and other configuration—while keeping +arguments separate. This lets you create multiple jobs from the same payload with different arguments, without re-uploading the +payload. Defining a job is very similar to creating a job. + +```python +from snowflake.ml.jobs import remote + +compute_pool = "MY_COMPUTE_POOL" +@remote(compute_pool, stage_name="payload_stage") +def hello_world(name: str = "world"): + from datetime import datetime + + print(f"{datetime.now()} Hello {name}!") + +# this is a definition handle +definition = hello_world + +job1 = hello_world() + +job2 = hello_world(name="ML Job Definition") +``` + +```python +from snowflake.ml.jobs import MLJobDefinition + +job_definition = MLJobDefinition.register( + "/path/to/repo/my_script.py", + # If you register a source directory, provide the entrypoint file: + # entrypoint="/path/to/repo/my_script.py", + compute_pool="MY_COMPUTE_POOL", + stage_name="payload_stage", + session=session,  # an active Snowpark Session +) +# Arguments follow the same format used in file dispatch +job1 = job_definition("arg1", "--arg2_key", "arg2_value") + +job2 = job_definition("arg3", "--arg4_key", "arg4_value") + +``` + +### Task Integration + +ML Job definitions integrate directly with the Task SDK. Use a definition as the task definition when creating a DAG task. +For a detailed example, see `e2e_task_graph/README.md`. + +```python +@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2) +def train_model(input_data: DataSource) -> Optional[str]: ... 
+ +train_model_task = DAGTask("TRAIN_MODEL", definition=train_model) +``` + + ### Supporting Additional Payloads in Submissions When submitting a file, directory, or from a stage, additional payloads are supported for use during job execution. diff --git a/samples/ml/ml_jobs/e2e_task_graph/README.md b/samples/ml/ml_jobs/e2e_task_graph/README.md index a784e3b0..a20ba989 100644 --- a/samples/ml/ml_jobs/e2e_task_graph/README.md +++ b/samples/ml/ml_jobs/e2e_task_graph/README.md @@ -119,8 +119,7 @@ for downstream consumption. Run the ML pipeline locally without task graph orchestration: ```bash -python src/pipeline_local.py -python src/pipeline_local.py --no-register # Skip model registration for faster experimentation +python src/pipeline_local.py --no-register # Skip model registration for faster experimentation ``` You can monitor the corresponding ML Job for model training via the [Job UI in Snowsight](../README.md#job-ui-in-snowsight). @@ -187,16 +186,17 @@ This visual interface makes it easy to: - **Branching Logic**: Using `DAGTaskBranch` for conditional execution paths - **Finalizer Tasks**: Ensuring cleanup always runs regardless of success/failure -### Model Training on SPCS using ML Jobs - -The `train_model` function uses the `@remote` decorator to run multi-node training on Snowpark Container Services: +### Model Training on SPCS Using ML Jobs +The `train_model` function is decorated with `@remote` to execute multi-node training on Snowpark Container Services (SPCS): ```python @remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2) -def train_model(session: Session, input_data: DataSource) -> XGBClassifier: +def train_model() -> None: # Training logic runs on distributed compute ``` +When running as a DAG task, the dataset information is retrieved from the previous task (PREPARE_DATA) via `TaskContext`. The model is trained and evaluated, and the results (model path and metrics) are saved and passed to the next task. 
The Task SDK lets you use that ML Job definition directly when creating a DAG task. For additional ML Job definition examples, see `../README.md`. + ### Conditional Model Promotion The task graph includes branching logic that only promotes models meeting quality thresholds: diff --git a/samples/ml/ml_jobs/e2e_task_graph/images/task-graph-overview.png b/samples/ml/ml_jobs/e2e_task_graph/images/task-graph-overview.png index 1f5d6955..7bdad6ec 100644 Binary files a/samples/ml/ml_jobs/e2e_task_graph/images/task-graph-overview.png and b/samples/ml/ml_jobs/e2e_task_graph/images/task-graph-overview.png differ diff --git a/samples/ml/ml_jobs/e2e_task_graph/src/modeling.py b/samples/ml/ml_jobs/e2e_task_graph/src/modeling.py index 293366d4..17992af9 100644 --- a/samples/ml/ml_jobs/e2e_task_graph/src/modeling.py +++ b/samples/ml/ml_jobs/e2e_task_graph/src/modeling.py @@ -1,13 +1,13 @@ import os import logging from datetime import datetime, timedelta, timezone -from typing import Any, Dict, Optional, Union +from typing import Optional import cloudpickle as cp +from xgboost import Booster, DMatrix import data import ops from constants import ( - COMPUTE_POOL, DAG_STAGE, DB_NAME, JOB_STAGE, @@ -18,7 +18,6 @@ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score from snowflake.ml.data import DataConnector, DatasetInfo, DataSource from snowflake.ml.dataset import Dataset, load_dataset -from snowflake.ml.jobs import remote from snowflake.ml.model import ModelVersion from snowflake.snowpark import Session from snowflake.snowpark.exceptions import SnowparkSQLException @@ -144,10 +143,7 @@ def prepare_datasets( return (ds, train_ds, test_ds) -# NOTE: Remove `target_instances=2` to run training on a single node -# See https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/distributed-ml-jobs -@remote(COMPUTE_POOL, stage_name=JOB_STAGE, target_instances=2) -def train_model(session: Session, input_data: DataSource) -> XGBClassifier: +def 
train_model(session: Session, input_data: Optional[DataSource] = None) -> XGBClassifier: """ Train a model on the training dataset. @@ -162,18 +158,15 @@ def train_model(session: Session, input_data: DataSource) -> XGBClassifier: Returns: XGBClassifier: Trained XGBoost classifier model """ - input_data_df = DataConnector.from_sources(session, [input_data]).to_pandas() - + input_data_dc = DataConnector.from_sources(session, [input_data]) + assert isinstance(input_data, DatasetInfo), "Input data must be a DatasetInfo" exclude_cols = input_data.exclude_cols label_col = exclude_cols[0] - X_train = input_data_df.drop(exclude_cols, axis=1) - y_train = input_data_df[label_col].squeeze() - model_params = dict( - max_depth=50, n_estimators=3, + max_depth=50, learning_rate=0.75, objective="binary:logistic", booster="gbtree", @@ -186,18 +179,21 @@ def train_model(session: Session, input_data: DataSource) -> XGBClassifier: XGBEstimator, XGBScalingConfig, ) + all_cols = input_data_dc.to_pandas(limit=1).columns.tolist() + input_cols = [c for c in all_cols if c not in exclude_cols] estimator = XGBEstimator( params=model_params, scaling_config=XGBScalingConfig(), ) + model = estimator.fit(input_data_dc,input_cols = input_cols, label_col = label_col) + return model else: - # Single node training - can use standard XGBClassifier - estimator = XGBClassifier(**model_params) - - estimator.fit(X_train, y_train) - - # Convert distributed estimator to standard XGBClassifier if needed - return getattr(estimator, '_sklearn_estimator', estimator) + df = input_data_dc.to_pandas() + X_train = df.drop(exclude_cols, axis=1) + y_train = df[label_col].squeeze() + estimator = XGBClassifier(**model_params) + model = estimator.fit(X_train, y_train) + return model def evaluate_model( @@ -205,7 +201,7 @@ def evaluate_model( model: XGBClassifier, input_data: DataSource, *, - prefix: str = None, + prefix: Optional[str] = None, ) -> dict: """ Evaluate a model on the training and test datasets. 
@@ -232,7 +228,12 @@ def evaluate_model( X_test = input_data_df.drop(exclude_cols, axis=1) expected = input_data_df[label_col].squeeze() - actual = model.predict(X_test) + # inside evaluate_model + if isinstance(model, Booster): + dmatrix = DMatrix(X_test) + actual = (model.predict(dmatrix) > 0.5).astype(int) + else: + actual = model.predict(X_test) metric_types = [ f1_score, diff --git a/samples/ml/ml_jobs/e2e_task_graph/src/ops.py b/samples/ml/ml_jobs/e2e_task_graph/src/ops.py index 324a47fe..b8ec62f1 100644 --- a/samples/ml/ml_jobs/e2e_task_graph/src/ops.py +++ b/samples/ml/ml_jobs/e2e_task_graph/src/ops.py @@ -166,3 +166,7 @@ def promote_model( # Set model as default base_model = registry.get_model(model.model_name) base_model.default = model + +def get_model(session: Session, model_name: str, version_name: str) -> ModelVersion: + registry = get_model_registry(session) + return registry.get_model(model_name).version(version_name) \ No newline at end of file diff --git a/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py b/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py index 892d8e10..290d0fe2 100644 --- a/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py +++ b/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_dag.py @@ -2,8 +2,8 @@ import json import os import time -from dataclasses import asdict, dataclass -from datetime import datetime, timedelta +from dataclasses import asdict +from datetime import timedelta from typing import Any, Optional import cloudpickle as cp @@ -12,18 +12,19 @@ from snowflake.core.task.dagv1 import DAG, DAGOperation, DAGTask, DAGTaskBranch from snowflake.ml.data import DatasetInfo from snowflake.ml.dataset import load_dataset -from snowflake.ml.jobs import MLJob from snowflake.snowpark import Session +from snowflake.ml.jobs import remote +import modeling +import data +import ops +from dataclasses import dataclass +from datetime import datetime import cli_utils -import data -import modeling -from constants import (DAG_STAGE, 
DATA_TABLE_NAME, DB_NAME, SCHEMA_NAME, +from constants import (COMPUTE_POOL, DAG_STAGE, DATA_TABLE_NAME, DB_NAME, JOB_STAGE, SCHEMA_NAME, WAREHOUSE) ARTIFACT_DIR = "run_artifacts" - - def _ensure_environment(session: Session): """ Ensure the environment is properly set up for DAG execution. @@ -36,14 +37,11 @@ def _ensure_environment(session: Session): session (Session): Snowflake session object """ modeling.ensure_environment(session) + cp.register_pickle_by_value(modeling) # Ensure the raw data table exists _ = data.get_raw_data(session, DATA_TABLE_NAME, create_if_not_exists=True) - # Register local modules for inclusion in ML Job payloads - cp.register_pickle_by_value(modeling) - - def _wait_for_run_to_complete(session: Session, dag: DAG) -> str: """ Wait for a DAG run to complete and return the final status. @@ -108,7 +106,6 @@ def _wait_for_run_to_complete(session: Session, dag: DAG) -> str: return dag_result - @dataclass(frozen=True) class RunConfig: run_id: str @@ -117,6 +114,7 @@ class RunConfig: metric_name: str metric_threshold: float + @property def artifact_dir(self) -> str: return os.path.join(DAG_STAGE, ARTIFACT_DIR, self.run_id) @@ -148,8 +146,7 @@ def from_task_context(cls, ctx: TaskContext, **kwargs: Any) -> "RunConfig": @classmethod def from_session(cls, session: Session) -> "RunConfig": ctx = TaskContext(session) - return cls.from_task_context(ctx) - + return cls.from_task_context(ctx) def prepare_datasets(session: Session) -> str: """ @@ -179,41 +176,27 @@ def prepare_datasets(session: Session) -> str: } return json.dumps(dataset_info) +@remote(COMPUTE_POOL, stage_name=JOB_STAGE, database=DB_NAME, schema=SCHEMA_NAME, target_instances=2) +def train_model() -> None: + ''' + ML Job to train a model on the training dataset and evaluate it. The model is saved to the stage and the metrics are returned. -def train_model(session: Session) -> str: - """ - DAG task to train a machine learning model. 
- - This function is executed as part of the DAG workflow to train a model using the prepared datasets. - It retrieves dataset information from the previous task, trains the model, evaluates it on both - training and test sets, and saves the model to a stage for later use. - - Args: - session (Session): Snowflake session object + ''' + session = Session.builder.getOrCreate() - Returns: - str: JSON string containing model path and evaluation metrics - """ ctx = TaskContext(session) config = RunConfig.from_task_context(ctx) - - # Load the datasets serialized = json.loads(ctx.get_predecessor_return_value("PREPARE_DATA")) - dataset_info = { + datasets = { key: DatasetInfo(**obj_dict) for key, obj_dict in serialized.items() } - # Train the model - model = modeling.train_model(session, dataset_info["train"]) - if isinstance(model, MLJob): - model = model.result() - - # Evaluate the model + model = modeling.train_model(session, datasets["train"]) train_metrics = modeling.evaluate_model( - session, model, dataset_info["train"], prefix="train" + session, model, datasets["train"], prefix="train" ) test_metrics = modeling.evaluate_model( - session, model, dataset_info["test"], prefix="test" + session, model, datasets["test"], prefix="test" ) metrics = {**train_metrics, **test_metrics} @@ -228,8 +211,9 @@ def train_model(session: Session) -> str: "model_path": os.path.join(config.artifact_dir, put_result.target), "metrics": metrics, } - return json.dumps(result_dict) - + # set the return value to the task context as a JSON string + ctx.set_return_value(json.dumps(result_dict)) + def check_model_quality(session: Session) -> str: """ @@ -250,7 +234,6 @@ def check_model_quality(session: Session) -> str: metrics = json.loads(ctx.get_predecessor_return_value("TRAIN_MODEL"))["metrics"] - # If model is good, promote model threshold = config.metric_threshold if metrics[config.metric_name] >= threshold: return "promote_model" @@ -314,7 +297,6 @@ def cleanup(session: Session) -> None: 
ctx = TaskContext(session) config = RunConfig.from_task_context(ctx) - session.sql(f"REMOVE {config.artifact_dir}").collect() modeling.clean_up(session, config.dataset_name, config.model_name) @@ -341,7 +323,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) - schedule=schedule, use_func_return_value=True, stage_location=DAG_STAGE, - packages=["snowflake-snowpark-python", "snowflake-ml-python<1.9.0", "xgboost"], # NOTE: Temporarily pinning to <1.9.0 due to compatibility issues + packages=["snowflake-snowpark-python", "snowflake-ml-python", "xgboost"], config={ "dataset_name": "mortgage_dataset", "model_name": "mortgage_model", @@ -352,7 +334,8 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) - ) as dag: # Need to wrap first function in a DAGTask to make >> operator work properly prepare_data = DAGTask("prepare_data", definition=prepare_datasets) - evaluate_model = DAGTaskBranch( + train_model_task = DAGTask("TRAIN_MODEL", definition=train_model) + check_model_quality_task = DAGTaskBranch( "check_model_quality", definition=check_model_quality ) promote_model_task = DAGTask("promote_model", definition=promote_model) @@ -372,7 +355,7 @@ def create_dag(name: str, schedule: Optional[timedelta] = None, **config: Any) - cleanup_task = DAGTask("cleanup_task", definition=cleanup, is_finalizer=True) # Build the DAG - prepare_data >> train_model >> evaluate_model >> [promote_model_task, alert_task] + prepare_data >> train_model_task >> check_model_quality_task >> [promote_model_task, alert_task] return dag diff --git a/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_local.py b/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_local.py index 72c62795..b4f80fc0 100644 --- a/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_local.py +++ b/samples/ml/ml_jobs/e2e_task_graph/src/pipeline_local.py @@ -1,12 +1,19 @@ import logging from datetime import datetime +from typing import Any +from snowflake.ml.data import DataSource 
+from snowflake.ml.jobs import remote from snowflake.snowpark import Session - +import cloudpickle as cp import modeling -from constants import DATA_TABLE_NAME +from constants import COMPUTE_POOL, DATA_TABLE_NAME, DB_NAME, JOB_STAGE, SCHEMA_NAME logging.getLogger().setLevel(logging.ERROR) +@remote(COMPUTE_POOL, stage_name=JOB_STAGE, database=DB_NAME, schema=SCHEMA_NAME, target_instances=2) +def train_model(input_data: DataSource) -> Any: + session = Session.builder.getOrCreate() + return modeling.train_model(session, input_data) def run_pipeline( session: Session, @@ -41,7 +48,7 @@ def run_pipeline( ) print("Training model...") - model_obj = modeling.train_model(session, train_ds.read.data_sources[0]).result() + model_obj = train_model(train_ds.read.data_sources[0]).result() print("Evaluating model...") train_metrics = modeling.evaluate_model( @@ -119,6 +126,7 @@ def run_pipeline( session_builder = session_builder.config("connection_name", args.connection) session = session_builder.getOrCreate() modeling.ensure_environment(session) + cp.register_pickle_by_value(modeling) run_pipeline( session, @@ -127,4 +135,4 @@ def run_pipeline( model_name=args.model_name or args.source_table, force_refresh=args.force_refresh, no_register=args.no_register, - ) + ) \ No newline at end of file