From 9c207659c8cb0f62ab4fdcb0a854291a18f4986e Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Sun, 15 Feb 2026 17:01:41 -0500 Subject: [PATCH 1/4] chore(spark): add kubeflow-spark-api dependency Signed-off-by: tariq-hasan --- hack/Dockerfile.spark-e2e-runner | 2 +- pyproject.toml | 2 ++ uv.lock | 12 ++++++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/hack/Dockerfile.spark-e2e-runner b/hack/Dockerfile.spark-e2e-runner index 65a3569b0..0613a8468 100644 --- a/hack/Dockerfile.spark-e2e-runner +++ b/hack/Dockerfile.spark-e2e-runner @@ -9,7 +9,7 @@ COPY pyproject.toml README.md LICENSE ./ COPY kubeflow/ kubeflow/ COPY examples/ examples/ -RUN pip install --no-cache-dir .[spark] +RUN pip install --no-cache-dir --pre .[spark] ENV SPARK_TEST_NAMESPACE=spark-test ENV PYTHONUNBUFFERED=1 diff --git a/pyproject.toml b/pyproject.toml index 8f0aadcdd..965189bd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "pydantic>=2.10.0", "kubeflow-trainer-api>=2.0.0", "kubeflow-katib-api>=0.19.0", + "kubeflow-spark-api>=2.3.0", ] [project.optional-dependencies] @@ -64,6 +65,7 @@ dev = [ "coverage>=7.0", "kubeflow_trainer_api@git+https://github.com/kubeflow/trainer.git@master#subdirectory=api/python_api", "kubeflow_katib_api@git+https://github.com/kubeflow/katib.git@master#subdirectory=api/python_api", + "kubeflow_spark_api@git+https://github.com/kubeflow/spark-operator.git@master#subdirectory=api/python_api", "ruff>=0.12.2", "pre-commit>=4.2.0", "PyGithub>=2.7.0", diff --git a/uv.lock b/uv.lock index 632af8922..960b8b657 100644 --- a/uv.lock +++ b/uv.lock @@ -883,6 +883,7 @@ name = "kubeflow" source = { editable = "." } dependencies = [ { name = "kubeflow-katib-api" }, + { name = "kubeflow-spark-api" }, { name = "kubeflow-trainer-api" }, { name = "kubernetes" }, { name = "pydantic" }, @@ -908,6 +909,7 @@ dev = [ { name = "git-cliff" }, { name = "kubeflow", extra = ["hub", "spark"] }, { name = "kubeflow-katib-api" }, + { name = "kubeflow-spark-api" }, { name = "kubeflow-trainer-api" }, { name = "pre-commit" }, { name = "pygithub" }, @@ -932,6 +934,7 @@ docs = [ requires-dist = [ { name = "docker", marker = "extra == 'docker'", specifier = ">=6.1.3" }, { name = "kubeflow-katib-api", specifier = ">=0.19.0" }, + { name = "kubeflow-spark-api", specifier = ">=2.3.0" }, { name = "kubeflow-trainer-api", specifier = ">=2.0.0" }, { name = "kubernetes", specifier = ">=27.2.0" }, { name = "model-registry", marker = "extra == 'hub'", specifier = ">=0.3.6" }, @@ -947,6 +950,7 @@ dev = [ { name = "git-cliff", specifier = ">=2.11.0" }, { name = "kubeflow", extras = ["hub", "spark"] }, { name = "kubeflow-katib-api", git = "https://github.com/kubeflow/katib.git?subdirectory=api%2Fpython_api&rev=master" }, + { name = "kubeflow-spark-api", git = "https://github.com/kubeflow/spark-operator.git?subdirectory=api%2Fpython_api&rev=master" }, { name = "kubeflow-trainer-api", git = "https://github.com/kubeflow/trainer.git?subdirectory=api%2Fpython_api&rev=master" }, { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pygithub", specifier = ">=2.7.0" }, @@ -973,6 +977,14 @@ dependencies = [ { name = "pydantic" }, ] +[[package]] +name = "kubeflow-spark-api" +version = "2.3.0" +source = { git = "https://github.com/kubeflow/spark-operator.git?subdirectory=api%2Fpython_api&rev=master#2ed985ba8ae8956f0aa1da5e9ef760694412d355" } +dependencies = [ + { name = "pydantic" }, +] + [[package]] name = "kubeflow-trainer-api" version = "2.1.0" From 0fabaa0601cb1716ac02313bec16716ebdfa6388 Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Sun, 15 Feb 2026 17:02:11 -0500 Subject: [PATCH 2/4] chore(spark): migrate options to typed Pydantic models Signed-off-by: tariq-hasan --- kubeflow/spark/types/options.py | 145 ++++++++++++++--------- kubeflow/spark/types/options_test.py | 170 ++++++++++++++------------- 2 files changed, 179 insertions(+), 136 deletions(-) diff --git a/kubeflow/spark/types/options.py b/kubeflow/spark/types/options.py index dd6022e8c..2cdf9ab63 100644 --- a/kubeflow/spark/types/options.py +++ b/kubeflow/spark/types/options.py @@ -23,6 +23,8 @@ from dataclasses import dataclass from typing import Any +from kubeflow_spark_api import models + from kubeflow.spark.backends.base import RuntimeBackend @@ -46,11 +48,13 @@ class Labels: labels: dict[str, str] - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply labels to the CRD specification. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply labels to the SparkConnect model. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -64,9 +68,9 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: f"Supported backends: KubernetesBackend" ) - metadata = crd.setdefault("metadata", {}) - labels = metadata.setdefault("labels", {}) - labels.update(self.labels) + if spark_connect.metadata.labels is None: + spark_connect.metadata.labels = {} + spark_connect.metadata.labels.update(self.labels) @dataclass @@ -94,11 +98,13 @@ class Annotations: annotations: dict[str, str] - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply annotations to the CRD specification. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply annotations to the SparkConnect model. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -112,9 +118,9 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: f"Supported backends: KubernetesBackend" ) - metadata = crd.setdefault("metadata", {}) - annotations = metadata.setdefault("annotations", {}) - annotations.update(self.annotations) + if spark_connect.metadata.annotations is None: + spark_connect.metadata.annotations = {} + spark_connect.metadata.annotations.update(self.annotations) @dataclass @@ -155,11 +161,13 @@ class PodTemplateOverride: role: str # "driver" or "executor" template: dict[str, Any] - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply pod template override to the CRD specification. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply pod template override to the SparkConnect model. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -174,18 +182,32 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: ) if self.role == "driver": - spec_key = "server" + role_spec = spark_connect.spec.server elif self.role == "executor": - spec_key = "executor" + role_spec = spark_connect.spec.executor else: raise ValueError(f"Invalid role '{self.role}'. Must be 'driver' or 'executor'.") - spec = crd.setdefault("spec", {}) - role_spec = spec.setdefault(spec_key, {}) - template = role_spec.setdefault("template", {}) + # Get or create template + if role_spec.template is None: + role_spec.template = models.IoK8sApiCoreV1PodTemplateSpec() + + # Convert existing template to dict, merge, and convert back + existing_dict = role_spec.template.to_dict() if role_spec.template else {} + self._deep_merge(existing_dict, self.template) + + # Ensure spec.containers exists (required by PodSpec validation) + if ( + "spec" in existing_dict + and existing_dict["spec"] is not None + and ( + "containers" not in existing_dict["spec"] + or existing_dict["spec"]["containers"] is None + ) + ): + existing_dict["spec"]["containers"] = [] - # Deep merge template - self._deep_merge(template, self.template) + role_spec.template = models.IoK8sApiCoreV1PodTemplateSpec.from_dict(existing_dict) @staticmethod def _deep_merge(target: dict[str, Any], source: dict[str, Any]) -> None: @@ -219,11 +241,13 @@ class NodeSelector: selectors: dict[str, str] - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply node selector constraints to the CRD specification. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply node selector constraints to the SparkConnect model. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -237,14 +261,16 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: f"Supported backends: KubernetesBackend" ) - spec = crd.setdefault("spec", {}) - - for role in ["server", "executor"]: - role_spec = spec.setdefault(role, {}) - template = role_spec.setdefault("template", {}) - pod_spec = template.setdefault("spec", {}) - node_selector = pod_spec.setdefault("nodeSelector", {}) - node_selector.update(self.selectors) + # Apply to both server and executor + for role_spec in [spark_connect.spec.server, spark_connect.spec.executor]: + if role_spec.template is None: + role_spec.template = models.IoK8sApiCoreV1PodTemplateSpec() + if role_spec.template.spec is None: + # PodSpec requires containers field (can be empty list) + role_spec.template.spec = models.IoK8sApiCoreV1PodSpec(containers=[]) + if role_spec.template.spec.node_selector is None: + role_spec.template.spec.node_selector = {} + role_spec.template.spec.node_selector.update(self.selectors) @dataclass @@ -280,11 +306,13 @@ class Toleration: value: str = "" effect: str = "NoSchedule" - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply toleration to the CRD specification. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply toleration to the SparkConnect model. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -298,22 +326,24 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: f"Supported backends: KubernetesBackend" ) - toleration = { - "key": self.key, - "operator": self.operator, - "effect": self.effect, - } - if self.value: - toleration["value"] = self.value - - spec = crd.setdefault("spec", {}) - - for role in ["server", "executor"]: - role_spec = spec.setdefault(role, {}) - template = role_spec.setdefault("template", {}) - pod_spec = template.setdefault("spec", {}) - tolerations = pod_spec.setdefault("tolerations", []) - tolerations.append(toleration) + # Create toleration model + toleration = models.IoK8sApiCoreV1Toleration( + key=self.key, + operator=self.operator, + effect=self.effect, + value=self.value if self.value else None, + ) + + # Apply to both server and executor + for role_spec in [spark_connect.spec.server, spark_connect.spec.executor]: + if role_spec.template is None: + role_spec.template = models.IoK8sApiCoreV1PodTemplateSpec() + if role_spec.template.spec is None: + # PodSpec requires containers field (can be empty list) + role_spec.template.spec = models.IoK8sApiCoreV1PodSpec(containers=[]) + if role_spec.template.spec.tolerations is None: + role_spec.template.spec.tolerations = [] + role_spec.template.spec.tolerations.append(toleration) @dataclass @@ -355,14 +385,16 @@ class Name: name: str - def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: - """Apply custom name to CRD metadata. + def __call__( + self, spark_connect: models.SparkV1alpha1SparkConnect, backend: RuntimeBackend + ) -> None: + """Apply custom name to SparkConnect metadata. Note: This method exists for interface consistency but is not typically called, as the name is extracted earlier in the backend flow. Args: - crd: CRD specification dictionary to modify. + spark_connect: SparkConnect model to modify. backend: Backend instance for validation. Raises: @@ -376,5 +408,4 @@ def __call__(self, crd: dict[str, Any], backend: RuntimeBackend) -> None: f"Supported backends: KubernetesBackend" ) - metadata = crd.setdefault("metadata", {}) - metadata["name"] = self.name + spark_connect.metadata.name = self.name diff --git a/kubeflow/spark/types/options_test.py b/kubeflow/spark/types/options_test.py index 1ca8dbb2b..03790d8be 100644 --- a/kubeflow/spark/types/options_test.py +++ b/kubeflow/spark/types/options_test.py @@ -16,8 +16,10 @@ from unittest.mock import MagicMock +from kubeflow_spark_api import models import pytest +from kubeflow.spark.backends.kubernetes import constants from kubeflow.spark.backends.kubernetes.backend import KubernetesBackend from kubeflow.spark.types.options import ( Annotations, @@ -46,88 +48,112 @@ def mock_non_k8s_backend(): return backend +@pytest.fixture +def spark_connect_model(): + """Create a minimal SparkConnect model for testing.""" + return models.SparkV1alpha1SparkConnect( + api_version=f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + kind=constants.SPARK_CONNECT_KIND, + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="test-session", + namespace="default", + ), + spec=models.SparkV1alpha1SparkConnectSpec( + spark_version=constants.DEFAULT_SPARK_VERSION, + image=constants.DEFAULT_SPARK_IMAGE, + server=models.SparkV1alpha1ServerSpec( + cores=constants.DEFAULT_DRIVER_CPU, + memory=constants.DEFAULT_DRIVER_MEMORY, + ), + executor=models.SparkV1alpha1ExecutorSpec( + instances=2, + cores=constants.DEFAULT_EXECUTOR_CPU, + memory=constants.DEFAULT_EXECUTOR_MEMORY, + ), + ), + ) + + class TestLabels: """Tests for Labels option.""" - def test_labels_apply_to_crd(self, mock_k8s_backend): + def test_labels_apply_to_crd(self, mock_k8s_backend, spark_connect_model): """Labels option adds labels to CRD metadata.""" option = Labels({"app": "spark", "team": "data-eng"}) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - assert crd["metadata"]["labels"]["app"] == "spark" - assert crd["metadata"]["labels"]["team"] == "data-eng" + assert spark_connect_model.metadata.labels["app"] == "spark" + assert spark_connect_model.metadata.labels["team"] == "data-eng" - def test_labels_merge_with_existing(self, mock_k8s_backend): + def test_labels_merge_with_existing(self, mock_k8s_backend, spark_connect_model): """Labels option merges with existing labels.""" + spark_connect_model.metadata.labels = {"existing": "label"} option = Labels({"new-label": "value"}) - crd = {"metadata": {"labels": {"existing": "label"}}} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - assert crd["metadata"]["labels"]["existing"] == "label" - assert crd["metadata"]["labels"]["new-label"] == "value" + assert spark_connect_model.metadata.labels["existing"] == "label" + assert spark_connect_model.metadata.labels["new-label"] == "value" - def test_labels_incompatible_backend(self, mock_non_k8s_backend): + def test_labels_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """Labels option raises error for incompatible backend.""" option = Labels({"app": "spark"}) - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) class TestAnnotations: """Tests for Annotations option.""" - def test_annotations_apply_to_crd(self, mock_k8s_backend): + def test_annotations_apply_to_crd(self, mock_k8s_backend, spark_connect_model): """Annotations option adds annotations to CRD metadata.""" option = Annotations({"description": "ETL pipeline", "owner": "data-team"}) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - assert crd["metadata"]["annotations"]["description"] == "ETL pipeline" - assert crd["metadata"]["annotations"]["owner"] == "data-team" + assert spark_connect_model.metadata.annotations["description"] == "ETL pipeline" + assert spark_connect_model.metadata.annotations["owner"] == "data-team" - def test_annotations_incompatible_backend(self, mock_non_k8s_backend): + def test_annotations_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """Annotations option raises error for incompatible backend.""" option = Annotations({"description": "test"}) - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) class TestNodeSelector: """Tests for NodeSelector option.""" - def test_node_selector_applies_to_both_roles(self, mock_k8s_backend): + def test_node_selector_applies_to_both_roles(self, mock_k8s_backend, spark_connect_model): """NodeSelector option adds selectors to both driver and executor.""" option = NodeSelector({"node-type": "spark", "gpu": "true"}) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - assert crd["spec"]["server"]["template"]["spec"]["nodeSelector"]["node-type"] == "spark" - assert crd["spec"]["server"]["template"]["spec"]["nodeSelector"]["gpu"] == "true" - assert crd["spec"]["executor"]["template"]["spec"]["nodeSelector"]["node-type"] == "spark" - assert crd["spec"]["executor"]["template"]["spec"]["nodeSelector"]["gpu"] == "true" + # Check server (driver) + server_node_selector = spark_connect_model.spec.server.template.spec.node_selector + assert server_node_selector["node-type"] == "spark" + assert server_node_selector["gpu"] == "true" + # Check executor + executor_node_selector = spark_connect_model.spec.executor.template.spec.node_selector + assert executor_node_selector["node-type"] == "spark" + assert executor_node_selector["gpu"] == "true" - def test_node_selector_incompatible_backend(self, mock_non_k8s_backend): + def test_node_selector_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """NodeSelector option raises error for incompatible backend.""" option = NodeSelector({"node-type": "spark"}) - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) class TestToleration: """Tests for Toleration option.""" - def test_toleration_with_value(self, mock_k8s_backend): + def test_toleration_with_value(self, mock_k8s_backend, spark_connect_model): """Toleration option with value.""" option = Toleration( key="spark-workload", @@ -135,48 +161,45 @@ def test_toleration_with_value(self, mock_k8s_backend): value="true", effect="NoSchedule", ) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - tolerations = crd["spec"]["server"]["template"]["spec"]["tolerations"] + tolerations = spark_connect_model.spec.server.template.spec.tolerations assert len(tolerations) == 1 - assert tolerations[0]["key"] == "spark-workload" - assert tolerations[0]["operator"] == "Equal" - assert tolerations[0]["value"] == "true" - assert tolerations[0]["effect"] == "NoSchedule" + assert tolerations[0].key == "spark-workload" + assert tolerations[0].operator == "Equal" + assert tolerations[0].value == "true" + assert tolerations[0].effect == "NoSchedule" - def test_toleration_without_value(self, mock_k8s_backend): + def test_toleration_without_value(self, mock_k8s_backend, spark_connect_model): """Toleration option without value (operator=Exists).""" option = Toleration( key="dedicated", operator="Exists", effect="NoSchedule", ) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - tolerations = crd["spec"]["server"]["template"]["spec"]["tolerations"] + tolerations = spark_connect_model.spec.server.template.spec.tolerations assert len(tolerations) == 1 - assert tolerations[0]["key"] == "dedicated" - assert tolerations[0]["operator"] == "Exists" - assert "value" not in tolerations[0] # Value not included when empty - assert tolerations[0]["effect"] == "NoSchedule" + assert tolerations[0].key == "dedicated" + assert tolerations[0].operator == "Exists" + assert tolerations[0].value is None # Value is None when empty + assert tolerations[0].effect == "NoSchedule" - def test_toleration_incompatible_backend(self, mock_non_k8s_backend): + def test_toleration_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """Toleration option raises error for incompatible backend.""" option = Toleration(key="test", operator="Exists") - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) class TestPodTemplateOverride: """Tests for PodTemplateOverride option.""" - def test_pod_template_driver(self, mock_k8s_backend): + def test_pod_template_driver(self, mock_k8s_backend, spark_connect_model): """PodTemplateOverride applies to driver.""" option = PodTemplateOverride( role="driver", @@ -189,14 +212,15 @@ def test_pod_template_driver(self, mock_k8s_backend): } }, ) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) + # Convert to dict to verify merged template + crd = spark_connect_model.to_dict() assert crd["spec"]["server"]["template"]["spec"]["securityContext"]["runAsUser"] == 1000 assert crd["spec"]["server"]["template"]["spec"]["securityContext"]["fsGroup"] == 1000 - def test_pod_template_executor(self, mock_k8s_backend): + def test_pod_template_executor(self, mock_k8s_backend, spark_connect_model): """PodTemplateOverride applies to executor.""" option = PodTemplateOverride( role="executor", @@ -208,27 +232,26 @@ def test_pod_template_executor(self, mock_k8s_backend): } }, ) - crd = {} - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) + # Convert to dict to verify merged template + crd = spark_connect_model.to_dict() assert crd["spec"]["executor"]["template"]["spec"]["securityContext"]["runAsUser"] == 1000 - def test_pod_template_invalid_role(self, mock_k8s_backend): + def test_pod_template_invalid_role(self, mock_k8s_backend, spark_connect_model): """PodTemplateOverride raises error for invalid role.""" option = PodTemplateOverride(role="invalid", template={"spec": {}}) - crd = {} with pytest.raises(ValueError, match="Invalid role"): - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - def test_pod_template_incompatible_backend(self, mock_non_k8s_backend): + def test_pod_template_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """PodTemplateOverride option raises error for incompatible backend.""" option = PodTemplateOverride(role="driver", template={"spec": {}}) - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) class TestNameOption: @@ -239,28 +262,17 @@ def test_name_option_basic(self): option = Name("my-custom-session") assert option.name == "my-custom-session" - def test_name_option_apply_to_crd(self, mock_k8s_backend): + def test_name_option_apply_to_crd(self, mock_k8s_backend, spark_connect_model): """Apply Name option to CRD.""" - option = Name("test-session") - crd = {"metadata": {"name": "old-name", "namespace": "default"}} - - option(crd, mock_k8s_backend) - - assert crd["metadata"]["name"] == "test-session" - - def test_name_option_creates_metadata(self, mock_k8s_backend): - """Name option creates metadata if missing.""" - option = Name("test-session") - crd = {} + option = Name("new-session-name") - option(crd, mock_k8s_backend) + option(spark_connect_model, mock_k8s_backend) - assert crd["metadata"]["name"] == "test-session" + assert spark_connect_model.metadata.name == "new-session-name" - def test_name_option_incompatible_backend(self, mock_non_k8s_backend): + def test_name_option_incompatible_backend(self, mock_non_k8s_backend, spark_connect_model): """Name option raises error for incompatible backend.""" option = Name("test-session") - crd = {} with pytest.raises(ValueError, match="not compatible"): - option(crd, mock_non_k8s_backend) + option(spark_connect_model, mock_non_k8s_backend) From e1cc4ac71e9ddd7a851fb5e9b869da49d060d328 Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Sun, 15 Feb 2026 17:02:33 -0500 Subject: [PATCH 3/4] chore(spark): migrate utils to typed Pydantic models Signed-off-by: tariq-hasan --- kubeflow/spark/backends/kubernetes/utils.py | 247 +++++++++++------- .../spark/backends/kubernetes/utils_test.py | 175 ++++++++----- 2 files changed, 264 insertions(+), 158 deletions(-) diff --git a/kubeflow/spark/backends/kubernetes/utils.py b/kubeflow/spark/backends/kubernetes/utils.py index 9d7be81ab..bffb1274a 100644 --- a/kubeflow/spark/backends/kubernetes/utils.py +++ b/kubeflow/spark/backends/kubernetes/utils.py @@ -14,20 +14,16 @@ """Utility functions for Kubernetes Spark backend.""" -import contextlib -from datetime import datetime import re from typing import Any, Optional from urllib.parse import urlparse import uuid +from kubeflow_spark_api import models + from kubeflow.spark.backends.kubernetes import constants from kubeflow.spark.types.types import Driver, Executor, SparkConnectInfo, SparkConnectState -# Type alias for backend to avoid circular imports -if False: # TYPE_CHECKING equivalent without import - pass - def generate_session_name() -> str: """Generate a unique session name. @@ -88,6 +84,95 @@ def build_service_url(info: SparkConnectInfo) -> str: return f"sc://{service}.{info.namespace}.svc.cluster.local:{constants.SPARK_CONNECT_PORT}" +def get_server_spec_from_driver( + driver: Optional[Driver] = None, +) -> models.SparkV1alpha1ServerSpec: + """Convert SDK Driver to API ServerSpec. + + Args: + driver: SDK Driver configuration. + + Returns: + API ServerSpec model. + """ + cores = constants.DEFAULT_DRIVER_CPU + memory = _memory_kubernetes_to_spark(constants.DEFAULT_DRIVER_MEMORY) + template = None + + if driver: + if driver.resources: + if "cpu" in driver.resources: + cores = int(driver.resources["cpu"]) + if "memory" in driver.resources: + memory = _memory_kubernetes_to_spark(driver.resources["memory"]) + + if driver.service_account: + # PodSpec requires containers field (can be empty list) + template = models.IoK8sApiCoreV1PodTemplateSpec( + spec=models.IoK8sApiCoreV1PodSpec( + containers=[], + service_account_name=driver.service_account, + ) + ) + + return models.SparkV1alpha1ServerSpec( + cores=cores, + memory=memory, + template=template, + ) + + +def get_executor_spec_from_executor( + executor: Optional[Executor] = None, + num_executors: Optional[int] = None, + resources_per_executor: Optional[dict[str, str]] = None, +) -> models.SparkV1alpha1ExecutorSpec: + """Convert SDK Executor to API ExecutorSpec. + + Precedence rules: + - Instances: executor.num_instances > num_executors > default + - Resources: executor.resources_per_executor > resources_per_executor + + Args: + executor: SDK Executor configuration. + num_executors: Simple mode number of executors. + resources_per_executor: Simple mode resource requirements. + + Returns: + API ExecutorSpec model. + """ + # Determine number of instances + if executor and executor.num_instances is not None: + instances = executor.num_instances + elif num_executors is not None: + instances = num_executors + else: + instances = constants.DEFAULT_NUM_EXECUTORS + + # Determine resource dict + resource_dict = None + if executor and executor.resources_per_executor: + resource_dict = executor.resources_per_executor + elif resources_per_executor: + resource_dict = resources_per_executor + + # Extract cores and memory + cores = constants.DEFAULT_EXECUTOR_CPU + memory = _memory_kubernetes_to_spark(constants.DEFAULT_EXECUTOR_MEMORY) + + if resource_dict: + if "cpu" in resource_dict: + cores = int(resource_dict["cpu"]) + if "memory" in resource_dict: + memory = _memory_kubernetes_to_spark(resource_dict["memory"]) + + return models.SparkV1alpha1ExecutorSpec( + instances=instances, + cores=cores, + memory=memory, + ) + + def build_spark_connect_crd( name: str, namespace: str, @@ -99,8 +184,8 @@ def build_spark_connect_crd( executor: Optional[Executor] = None, options: Optional[list] = None, backend: Optional[Any] = None, -) -> dict[str, Any]: - """Build SparkConnect CRD manifest (KEP-107 compliant). +) -> models.SparkV1alpha1SparkConnect: + """Build SparkConnect CRD using typed API models (KEP-107 compliant). Precedence rules: - Executor instances: executor.num_instances > num_executors > default @@ -121,72 +206,18 @@ def build_spark_connect_crd( backend: Backend instance for option validation. Returns: - SparkConnect CRD as dictionary. + SparkConnect CRD as typed Pydantic model. """ spark_version = spark_version or constants.DEFAULT_SPARK_VERSION - # Precedence: executor.num_instances > num_executors > default - executor_spec: dict[str, Any] = {} - if executor and executor.num_instances is not None: - executor_spec["instances"] = executor.num_instances - elif num_executors is not None: - executor_spec["instances"] = num_executors - else: - executor_spec["instances"] = constants.DEFAULT_NUM_EXECUTORS + # Build server spec using conversion function + server_spec = get_server_spec_from_driver(driver) - # Precedence: executor.resources_per_executor > resources_per_executor - resource_dict = None - if executor and executor.resources_per_executor: - resource_dict = executor.resources_per_executor - elif resources_per_executor: - resource_dict = resources_per_executor + # Build executor spec using conversion function + executor_spec = get_executor_spec_from_executor(executor, num_executors, resources_per_executor) - if resource_dict: - if "cpu" in resource_dict: - executor_spec["cores"] = int(resource_dict["cpu"]) - if "memory" in resource_dict: - executor_spec["memory"] = _memory_kubernetes_to_spark(resource_dict["memory"]) - if "cores" not in executor_spec: - executor_spec["cores"] = constants.DEFAULT_EXECUTOR_CPU - if "memory" not in executor_spec: - executor_spec["memory"] = _memory_kubernetes_to_spark(constants.DEFAULT_EXECUTOR_MEMORY) - - server_spec: dict[str, Any] = {} - if driver: - if driver.resources: - if "cpu" in driver.resources: - server_spec["cores"] = int(driver.resources["cpu"]) - if "memory" in driver.resources: - server_spec["memory"] = _memory_kubernetes_to_spark(driver.resources["memory"]) - - if driver.service_account: - if "template" not in server_spec: - server_spec["template"] = {"spec": {}} - server_spec["template"]["spec"]["serviceAccountName"] = driver.service_account - if "cores" not in server_spec: - server_spec["cores"] = constants.DEFAULT_DRIVER_CPU - if "memory" not in server_spec: - server_spec["memory"] = _memory_kubernetes_to_spark(constants.DEFAULT_DRIVER_MEMORY) - - crd: dict[str, Any] = { - "apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", - "kind": constants.SPARK_CONNECT_KIND, - "metadata": { - "name": name, - "namespace": namespace, - }, - "spec": { - "sparkVersion": spark_version, - "server": server_spec, - "executor": executor_spec, - }, - } - - # Precedence: driver.image > default - if driver and driver.image: - crd["spec"]["image"] = driver.image - else: - crd["spec"]["image"] = constants.DEFAULT_SPARK_IMAGE + # Determine image (driver.image > default) + image = driver.image if driver and driver.image else constants.DEFAULT_SPARK_IMAGE # Use direct JAR URL to avoid Ivy cache (container may not have writable ~/.ivy2) connect_jar_url = ( @@ -206,49 +237,69 @@ def build_spark_connect_crd( for k, v in spark_conf.items(): if k != "spark.jars": base_conf[k] = v - crd["spec"]["sparkConf"] = base_conf + + # Build the typed SparkConnect model + spark_connect = models.SparkV1alpha1SparkConnect( + api_version=f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + kind=constants.SPARK_CONNECT_KIND, + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=name, + namespace=namespace, + ), + spec=models.SparkV1alpha1SparkConnectSpec( + spark_version=spark_version, + image=image, + server=server_spec, + executor=executor_spec, + spark_conf=base_conf, + ), + ) # Apply options - extensibility without API changes (callable pattern) if options and backend is not None: for option in options: if callable(option): - option(crd, backend) + option(spark_connect, backend) - return crd + return spark_connect -def parse_spark_connect_status(crd_response: dict[str, Any]) -> SparkConnectInfo: - """Parse SparkConnect CRD response into SparkConnectInfo. +def get_spark_connect_info_from_cr( + spark_connect_cr: models.SparkV1alpha1SparkConnect, +) -> SparkConnectInfo: + """Convert API SparkConnect model to SDK SparkConnectInfo. Args: - crd_response: Raw CRD response from Kubernetes API. + spark_connect_cr: API SparkConnect model. Returns: - SparkConnectInfo with parsed status. + SDK SparkConnectInfo dataclass. + + Raises: + ValueError: If the CR is invalid. """ - metadata = crd_response.get("metadata", {}) - status = crd_response.get("status", {}) - server_status = status.get("server", {}) - - state_str = status.get("state", "") - try: - state = SparkConnectState(state_str) if state_str else SparkConnectState.PROVISIONING - except ValueError: - state = SparkConnectState.PROVISIONING - - # Parse creation timestamp - creation_timestamp = None - creation_ts = metadata.get("creationTimestamp") - if creation_ts: - with contextlib.suppress(ValueError, AttributeError): - creation_timestamp = datetime.fromisoformat(creation_ts.replace("Z", "+00:00")) + if not (spark_connect_cr.metadata and spark_connect_cr.metadata.name): + raise ValueError(f"SparkConnect CR is invalid: {spark_connect_cr}") + + # Parse state + state = SparkConnectState.PROVISIONING + if spark_connect_cr.status and spark_connect_cr.status.state: + try: + state = SparkConnectState(spark_connect_cr.status.state) + except ValueError: + state = SparkConnectState.PROVISIONING + + # Extract server status + server_status = None + if spark_connect_cr.status and spark_connect_cr.status.server: + server_status = spark_connect_cr.status.server return SparkConnectInfo( - name=metadata.get("name", ""), - namespace=metadata.get("namespace", ""), + name=spark_connect_cr.metadata.name, + namespace=spark_connect_cr.metadata.namespace or "", state=state, - pod_name=server_status.get("podName"), - pod_ip=server_status.get("podIp"), - service_name=server_status.get("serviceName"), - creation_timestamp=creation_timestamp, + pod_name=server_status.pod_name if server_status else None, + pod_ip=server_status.pod_ip if server_status else None, + service_name=server_status.service_name if server_status else None, + creation_timestamp=spark_connect_cr.metadata.creation_timestamp, ) diff --git a/kubeflow/spark/backends/kubernetes/utils_test.py b/kubeflow/spark/backends/kubernetes/utils_test.py index 4c97418ba..81ebdb5c4 100644 --- a/kubeflow/spark/backends/kubernetes/utils_test.py +++ b/kubeflow/spark/backends/kubernetes/utils_test.py @@ -14,6 +14,7 @@ """Unit tests for Kubernetes Spark backend utilities.""" +from kubeflow_spark_api import models import pytest from kubeflow.spark.backends.kubernetes import constants @@ -22,7 +23,7 @@ build_service_url, build_spark_connect_crd, generate_session_name, - parse_spark_connect_status, + get_spark_connect_info_from_cr, validate_spark_connect_url, ) from kubeflow.spark.types.types import Driver, Executor, SparkConnectInfo, SparkConnectState @@ -111,7 +112,8 @@ class TestBuildSparkConnectCrd: def test_minimal_crd(self): """U01: Build SparkConnect CRD with minimal config.""" - crd = build_spark_connect_crd(name="test-session", namespace="default") + spark_connect = build_spark_connect_crd(name="test-session", namespace="default") + crd = spark_connect.to_dict() assert ( crd["apiVersion"] @@ -130,31 +132,34 @@ def test_minimal_crd(self): def test_with_num_executors(self): """U02: Build CRD with num_executors.""" - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", num_executors=3, ) + crd = spark_connect.to_dict() assert crd["spec"]["executor"]["instances"] == 3 def test_with_resources(self): """U03: Build CRD with resources_per_executor.""" - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", resources_per_executor={"cpu": "2", "memory": "4Gi"}, ) + crd = spark_connect.to_dict() assert crd["spec"]["executor"]["cores"] == 2 assert crd["spec"]["executor"]["memory"] == "4g" def test_with_spark_conf(self): """U04: Build CRD with spark_conf.""" spark_conf = {"spark.sql.adaptive.enabled": "true"} - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", spark_conf=spark_conf, ) + crd = spark_connect.to_dict() assert crd["spec"]["sparkConf"]["spark.jars"].endswith( f"spark-connect_2.12-{constants.DEFAULT_SPARK_VERSION}.jar" ) @@ -162,42 +167,46 @@ def test_with_spark_conf(self): def test_spark_conf_overrides_binding_address(self): """User spark_conf can override default grpc binding address.""" - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", spark_conf={"spark.connect.grpc.binding.address": "127.0.0.1"}, ) + crd = spark_connect.to_dict() assert crd["spec"]["sparkConf"]["spark.connect.grpc.binding.address"] == "127.0.0.1" def test_with_driver_image(self): """U05: Build CRD with custom image via Driver.""" driver = Driver(image="custom-spark:v1") - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", driver=driver, ) + crd = spark_connect.to_dict() assert crd["spec"]["image"] == "custom-spark:v1" def test_with_driver_config(self): """U06: Build CRD with Driver config (KEP-107 resources dict).""" driver = Driver(resources={"cpu": "2", "memory": "2Gi"}) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", driver=driver, ) + crd = spark_connect.to_dict() assert crd["spec"]["server"]["cores"] == 2 assert crd["spec"]["server"]["memory"] == "2g" def test_with_service_account(self): """U07: Build CRD with service account.""" driver = Driver(service_account="spark-sa") - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", driver=driver, ) + crd = spark_connect.to_dict() assert crd["spec"]["server"]["template"]["spec"]["serviceAccountName"] == "spark-sa" def test_with_executor_config(self): @@ -206,22 +215,24 @@ def test_with_executor_config(self): num_instances=5, resources_per_executor={"cpu": "4", "memory": "8Gi"}, ) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", executor=executor, ) + crd = spark_connect.to_dict() assert crd["spec"]["executor"]["instances"] == 5 assert crd["spec"]["executor"]["cores"] == 4 assert crd["spec"]["executor"]["memory"] == "8g" def test_app_name(self): """Build CRD with spark.app.name via spark_conf.""" - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", spark_conf={"spark.app.name": "my-spark-app"}, ) + crd = spark_connect.to_dict() assert crd["spec"]["sparkConf"]["spark.jars"].endswith( f"spark-connect_2.12-{constants.DEFAULT_SPARK_VERSION}.jar" ) @@ -230,12 +241,13 @@ def test_app_name(self): def test_precedence_executor_instances(self): """Test precedence: executor.num_instances > num_executors.""" executor = Executor(num_instances=10) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", num_executors=5, executor=executor, ) + crd = spark_connect.to_dict() # Executor object should override simple parameter assert crd["spec"]["executor"]["instances"] == 10 @@ -244,24 +256,26 @@ def test_precedence_executor_resources(self): executor = Executor( resources_per_executor={"cpu": "8", "memory": "16Gi"}, ) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", resources_per_executor={"cpu": "4", "memory": "8Gi"}, executor=executor, ) + crd = spark_connect.to_dict() # Executor object should override simple parameter assert crd["spec"]["executor"]["cores"] == 8 assert crd["spec"]["executor"]["memory"] == "16g" def test_kep107_level2_simple(self): """Test KEP-107 Level 2 (simple mode) example.""" - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", num_executors=5, resources_per_executor={"cpu": "5", "memory": "10Gi"}, ) + crd = spark_connect.to_dict() assert crd["spec"]["executor"]["instances"] == 5 assert crd["spec"]["executor"]["cores"] == 5 assert crd["spec"]["executor"]["memory"] == "10g" @@ -276,12 +290,13 @@ def test_kep107_level3_advanced(self): num_instances=20, resources_per_executor={"cpu": "8", "memory": "32Gi"}, ) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name="test-session", namespace="default", driver=driver, executor=executor, ) + crd = spark_connect.to_dict() assert crd["spec"]["server"]["cores"] == 4 assert crd["spec"]["server"]["memory"] == "8g" assert ( @@ -292,27 +307,37 @@ def test_kep107_level3_advanced(self): assert crd["spec"]["executor"]["memory"] == "32g" -class TestParseSparkConnectStatus: - """Tests for parse_spark_connect_status function.""" +class TestGetSparkConnectInfoFromCr: + """Tests for get_spark_connect_info_from_cr function.""" - def test_parse_ready_status(self): + @pytest.fixture + def minimal_spec(self): + """Create minimal spec required for SparkConnect model.""" + return models.SparkV1alpha1SparkConnectSpec( + sparkVersion=constants.DEFAULT_SPARK_VERSION, + server=models.SparkV1alpha1ServerSpec(), + executor=models.SparkV1alpha1ExecutorSpec(), + ) + + def test_parse_ready_status(self, minimal_spec): """U08: Parse CRD with Ready state.""" - crd_response = { - "metadata": { - "name": "my-session", - "namespace": "default", - "creationTimestamp": "2025-01-12T10:30:00Z", - }, - "status": { - "state": "Ready", - "server": { - "podName": "my-session-server-0", - "podIp": "10.0.0.5", - "serviceName": "my-session-svc", - }, - }, - } - info = parse_spark_connect_status(crd_response) + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="my-session", + namespace="default", + creationTimestamp="2025-01-12T10:30:00Z", + ), + spec=minimal_spec, + status=models.SparkV1alpha1SparkConnectStatus( + state="Ready", + server=models.SparkV1alpha1SparkConnectServerStatus( + podName="my-session-server-0", + podIp="10.0.0.5", + serviceName="my-session-svc", + ), + ), + ) + info = get_spark_connect_info_from_cr(spark_connect_cr) assert info.name == "my-session" assert info.namespace == "default" @@ -322,47 +347,77 @@ def test_parse_ready_status(self): assert info.service_name == "my-session-svc" assert info.creation_timestamp is not None - def test_parse_provisioning_status(self): + def test_parse_provisioning_status(self, minimal_spec): """U09: Parse CRD with Provisioning state.""" - crd_response = { - "metadata": {"name": "new-session", "namespace": "spark"}, - "status": {"state": "Provisioning"}, - } - info = parse_spark_connect_status(crd_response) + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="new-session", + namespace="spark", + ), + spec=minimal_spec, + status=models.SparkV1alpha1SparkConnectStatus(state="Provisioning"), + ) + info = get_spark_connect_info_from_cr(spark_connect_cr) assert info.name == "new-session" assert info.namespace == "spark" assert info.state == SparkConnectState.PROVISIONING - def test_parse_failed_status(self): + def test_parse_failed_status(self, minimal_spec): """U10: Parse CRD with Failed state.""" - crd_response = { - "metadata": {"name": "failed-session", "namespace": "default"}, - "status": {"state": "Failed"}, - } - info = parse_spark_connect_status(crd_response) + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="failed-session", + namespace="default", + ), + spec=minimal_spec, + status=models.SparkV1alpha1SparkConnectStatus(state="Failed"), + ) + info = get_spark_connect_info_from_cr(spark_connect_cr) assert info.state == SparkConnectState.FAILED - def test_parse_running_status(self): + def test_parse_running_status(self, minimal_spec): """Parse CRD with Running state (operator may set this when server is up).""" - crd_response = { - "metadata": {"name": "run-session", "namespace": "default"}, - "status": { - "state": "Running", - "server": {"podName": "run-session-server", "serviceName": "run-session-svc"}, - }, - } - info = parse_spark_connect_status(crd_response) + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="run-session", + namespace="default", + ), + spec=minimal_spec, + status=models.SparkV1alpha1SparkConnectStatus( + state="Running", + server=models.SparkV1alpha1SparkConnectServerStatus( + podName="run-session-server", + serviceName="run-session-svc", + ), + ), + ) + info = get_spark_connect_info_from_cr(spark_connect_cr) assert info.state == SparkConnectState.RUNNING assert info.service_name == "run-session-svc" - def test_parse_empty_status(self): + def test_parse_empty_status(self, minimal_spec): """Parse CRD with empty status.""" - crd_response = { - "metadata": {"name": "new-session", "namespace": "default"}, - } - info = parse_spark_connect_status(crd_response) + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name="new-session", + namespace="default", + ), + spec=minimal_spec, + ) + info = get_spark_connect_info_from_cr(spark_connect_cr) assert info.state == SparkConnectState.PROVISIONING assert info.pod_name is None + + def test_invalid_cr_missing_name_raises_error(self, minimal_spec): + """Test that CR without name in metadata raises ValueError.""" + spark_connect_cr = models.SparkV1alpha1SparkConnect( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + namespace="default", + ), + spec=minimal_spec, + ) + with pytest.raises(ValueError, match="SparkConnect CR is invalid"): + get_spark_connect_info_from_cr(spark_connect_cr) From 255b2ad2e3f953b3aa78deebd4b20a137eb0667c Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Sun, 15 Feb 2026 17:02:50 -0500 Subject: [PATCH 4/4] chore(spark): migrate backend to typed Pydantic models Signed-off-by: tariq-hasan --- kubeflow/spark/backends/kubernetes/backend.py | 17 ++-- .../spark/backends/kubernetes/backend_test.py | 82 ++++++++++++++++++- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/kubeflow/spark/backends/kubernetes/backend.py b/kubeflow/spark/backends/kubernetes/backend.py index 9505c858b..e5ace35cd 100644 --- a/kubeflow/spark/backends/kubernetes/backend.py +++ b/kubeflow/spark/backends/kubernetes/backend.py @@ -27,6 +27,7 @@ import time from typing import Optional +from kubeflow_spark_api import models from kubernetes import client, config from pyspark.sql import SparkSession @@ -38,7 +39,7 @@ build_service_url, build_spark_connect_crd, generate_session_name, - parse_spark_connect_status, + get_spark_connect_info_from_cr, ) from kubeflow.spark.types.options import Name from kubeflow.spark.types.types import Driver, Executor, SparkConnectInfo, SparkConnectState @@ -138,7 +139,7 @@ def _create_session( # Extract Name option if present, or auto-generate name, filtered_options = self._extract_name_option(options) - crd = build_spark_connect_crd( + spark_connect = build_spark_connect_crd( name=name, namespace=self.namespace, num_executors=num_executors, @@ -158,7 +159,7 @@ def _create_session( version=constants.SPARK_CONNECT_VERSION, namespace=self.namespace, plural=constants.SPARK_CONNECT_PLURAL, - body=crd, + body=spark_connect.to_dict(), async_req=True, ) response = thread.get(common_constants.DEFAULT_TIMEOUT) @@ -171,7 +172,8 @@ def _create_session( f"Failed to create {constants.SPARK_CONNECT_KIND}: {self.namespace}/{name}" ) from e - return parse_spark_connect_status(response) + spark_connect_cr = models.SparkV1alpha1SparkConnect.from_dict(response) + return get_spark_connect_info_from_cr(spark_connect_cr) def get_session(self, name: str) -> SparkConnectInfo: """Get information about a SparkConnect session.""" @@ -185,7 +187,9 @@ def get_session(self, name: str) -> SparkConnectInfo: async_req=True, ) response = thread.get(common_constants.DEFAULT_TIMEOUT) - return parse_spark_connect_status(response) + + spark_connect_cr = models.SparkV1alpha1SparkConnect.from_dict(response) + return get_spark_connect_info_from_cr(spark_connect_cr) except multiprocessing.TimeoutError as e: raise TimeoutError( f"Timeout to get {constants.SPARK_CONNECT_KIND}: {self.namespace}/{name}" @@ -223,7 +227,8 @@ def list_sessions(self) -> list[SparkConnectInfo]: f"Failed to list {constants.SPARK_CONNECT_KIND}s in namespace: {self.namespace}" ) from e - return [parse_spark_connect_status(item) for item in response.get("items", [])] + spark_connect_list = models.SparkV1alpha1SparkConnectList.from_dict(response) + return [get_spark_connect_info_from_cr(sc) for sc in spark_connect_list.items] def delete_session(self, name: str) -> None: """Delete a SparkConnect session.""" diff --git a/kubeflow/spark/backends/kubernetes/backend_test.py b/kubeflow/spark/backends/kubernetes/backend_test.py index 5c556c27b..864ffbd12 100644 --- a/kubeflow/spark/backends/kubernetes/backend_test.py +++ b/kubeflow/spark/backends/kubernetes/backend_test.py @@ -23,6 +23,7 @@ import pytest from kubeflow.common.types import KubernetesBackendConfig +from kubeflow.spark.backends.kubernetes import constants from kubeflow.spark.backends.kubernetes.backend import KubernetesBackend from kubeflow.spark.test.common import ( DEFAULT_NAMESPACE, @@ -63,9 +64,30 @@ def create_mock_thread_with_error(response=None, raise_timeout=False, raise_erro def mock_get_response(name: str) -> dict: - """Return mock CRD response based on session name.""" + """Return mock CRD response based on session name. + + Note: Responses must include all fields required by the Pydantic model's from_dict(). + """ + base_response = { + "apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + "kind": constants.SPARK_CONNECT_KIND, + "spec": { + "sparkVersion": constants.DEFAULT_SPARK_VERSION, + "image": constants.DEFAULT_SPARK_IMAGE, + "server": { + "cores": constants.DEFAULT_DRIVER_CPU, + "memory": constants.DEFAULT_DRIVER_MEMORY, + }, + "executor": { + "instances": 2, + "cores": constants.DEFAULT_EXECUTOR_CPU, + "memory": constants.DEFAULT_EXECUTOR_MEMORY, + }, + }, + } if name == SPARK_CONNECT_READY: return { + **base_response, "metadata": {"name": name, "namespace": DEFAULT_NAMESPACE}, "status": { "state": "Ready", @@ -74,11 +96,13 @@ def mock_get_response(name: str) -> dict: } elif name == SPARK_CONNECT_PROVISIONING: return { + **base_response, "metadata": {"name": name, "namespace": DEFAULT_NAMESPACE}, "status": {"state": "Provisioning"}, } elif name == SPARK_CONNECT_FAILED: return { + **base_response, "metadata": {"name": name, "namespace": DEFAULT_NAMESPACE}, "status": {"state": "Failed"}, } @@ -86,26 +110,76 @@ def mock_get_response(name: str) -> dict: def mock_list_response(*args, **kwargs) -> dict: - """Return mock list response.""" + """Return mock list response. + + Note: List responses must include all fields required by SparkConnectList.from_dict(). + """ + base_spec = { + "sparkVersion": constants.DEFAULT_SPARK_VERSION, + "image": constants.DEFAULT_SPARK_IMAGE, + "server": { + "cores": constants.DEFAULT_DRIVER_CPU, + "memory": constants.DEFAULT_DRIVER_MEMORY, + }, + "executor": { + "instances": 2, + "cores": constants.DEFAULT_EXECUTOR_CPU, + "memory": constants.DEFAULT_EXECUTOR_MEMORY, + }, + } return { + "apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + "kind": "SparkConnectList", "items": [ { + "apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + "kind": constants.SPARK_CONNECT_KIND, "metadata": {"name": "session-1", "namespace": DEFAULT_NAMESPACE}, + "spec": base_spec, "status": {"state": "Ready"}, }, { + "apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}", + "kind": constants.SPARK_CONNECT_KIND, "metadata": {"name": "session-2", "namespace": DEFAULT_NAMESPACE}, + "spec": base_spec, "status": {"state": "Provisioning"}, }, - ] + ], } def mock_create_response(*args, **kwargs) -> dict: - """Return mock create response.""" + """Return mock create response. + + Note: Create responses must include all fields required by SparkConnect.from_dict(). + """ body = kwargs.get("body", {}) return { + "apiVersion": body.get( + "apiVersion", f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}" + ), + "kind": body.get( + "kind", + constants.SPARK_CONNECT_KIND, + ), "metadata": body.get("metadata", {}), + "spec": body.get( + "spec", + { + "sparkVersion": constants.DEFAULT_SPARK_VERSION, + "image": constants.DEFAULT_SPARK_IMAGE, + "server": { + "cores": constants.DEFAULT_DRIVER_CPU, + "memory": constants.DEFAULT_DRIVER_MEMORY, + }, + "executor": { + "instances": 2, + "cores": constants.DEFAULT_EXECUTOR_CPU, + "memory": constants.DEFAULT_EXECUTOR_MEMORY, + }, + }, + ), "status": {"state": "Provisioning"}, }