Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hack/Dockerfile.spark-e2e-runner
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ COPY pyproject.toml README.md LICENSE ./
COPY kubeflow/ kubeflow/
COPY examples/ examples/

RUN pip install --no-cache-dir .[spark]
RUN pip install --no-cache-dir --pre .[spark]

ENV SPARK_TEST_NAMESPACE=spark-test
ENV PYTHONUNBUFFERED=1
Expand Down
17 changes: 11 additions & 6 deletions kubeflow/spark/backends/kubernetes/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import time
from typing import Optional

from kubeflow_spark_api import models
from kubernetes import client, config
from pyspark.sql import SparkSession

Expand All @@ -38,7 +39,7 @@
build_service_url,
build_spark_connect_crd,
generate_session_name,
parse_spark_connect_status,
get_spark_connect_info_from_cr,
)
from kubeflow.spark.types.options import Name
from kubeflow.spark.types.types import Driver, Executor, SparkConnectInfo, SparkConnectState
Expand Down Expand Up @@ -138,7 +139,7 @@ def _create_session(
# Extract Name option if present, or auto-generate
name, filtered_options = self._extract_name_option(options)

crd = build_spark_connect_crd(
spark_connect = build_spark_connect_crd(
name=name,
namespace=self.namespace,
num_executors=num_executors,
Expand All @@ -158,7 +159,7 @@ def _create_session(
version=constants.SPARK_CONNECT_VERSION,
namespace=self.namespace,
plural=constants.SPARK_CONNECT_PLURAL,
body=crd,
body=spark_connect.to_dict(),
async_req=True,
)
response = thread.get(common_constants.DEFAULT_TIMEOUT)
Expand All @@ -171,7 +172,8 @@ def _create_session(
f"Failed to create {constants.SPARK_CONNECT_KIND}: {self.namespace}/{name}"
) from e

return parse_spark_connect_status(response)
spark_connect_cr = models.SparkV1alpha1SparkConnect.from_dict(response)
return get_spark_connect_info_from_cr(spark_connect_cr)

def get_session(self, name: str) -> SparkConnectInfo:
"""Get information about a SparkConnect session."""
Expand All @@ -185,7 +187,9 @@ def get_session(self, name: str) -> SparkConnectInfo:
async_req=True,
)
response = thread.get(common_constants.DEFAULT_TIMEOUT)
return parse_spark_connect_status(response)

spark_connect_cr = models.SparkV1alpha1SparkConnect.from_dict(response)
return get_spark_connect_info_from_cr(spark_connect_cr)
except multiprocessing.TimeoutError as e:
raise TimeoutError(
f"Timeout to get {constants.SPARK_CONNECT_KIND}: {self.namespace}/{name}"
Expand Down Expand Up @@ -223,7 +227,8 @@ def list_sessions(self) -> list[SparkConnectInfo]:
f"Failed to list {constants.SPARK_CONNECT_KIND}s in namespace: {self.namespace}"
) from e

return [parse_spark_connect_status(item) for item in response.get("items", [])]
spark_connect_list = models.SparkV1alpha1SparkConnectList.from_dict(response)
return [get_spark_connect_info_from_cr(sc) for sc in spark_connect_list.items]

def delete_session(self, name: str) -> None:
"""Delete a SparkConnect session."""
Expand Down
82 changes: 78 additions & 4 deletions kubeflow/spark/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest

from kubeflow.common.types import KubernetesBackendConfig
from kubeflow.spark.backends.kubernetes import constants
from kubeflow.spark.backends.kubernetes.backend import KubernetesBackend
from kubeflow.spark.test.common import (
DEFAULT_NAMESPACE,
Expand Down Expand Up @@ -63,9 +64,30 @@ def create_mock_thread_with_error(response=None, raise_timeout=False, raise_erro


def mock_get_response(name: str) -> dict:
"""Return mock CRD response based on session name."""
"""Return mock CRD response based on session name.

Note: Responses must include all fields required by the Pydantic model's from_dict().
"""
base_response = {
"apiVersion": f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}",
"kind": constants.SPARK_CONNECT_KIND,
"spec": {
"sparkVersion": constants.DEFAULT_SPARK_VERSION,
"image": constants.DEFAULT_SPARK_IMAGE,
"server": {
"cores": constants.DEFAULT_DRIVER_CPU,
"memory": constants.DEFAULT_DRIVER_MEMORY,
},
"executor": {
"instances": 2,
"cores": constants.DEFAULT_EXECUTOR_CPU,
"memory": constants.DEFAULT_EXECUTOR_MEMORY,
},
},
}
if name == SPARK_CONNECT_READY:
return {
**base_response,
"metadata": {"name": name, "namespace": DEFAULT_NAMESPACE},
"status": {
"state": "Ready",
Expand All @@ -74,38 +96,90 @@ def mock_get_response(name: str) -> dict:
}
elif name == SPARK_CONNECT_PROVISIONING:
return {
**base_response,
"metadata": {"name": name, "namespace": DEFAULT_NAMESPACE},
"status": {"state": "Provisioning"},
}
elif name == SPARK_CONNECT_FAILED:
return {
**base_response,
"metadata": {"name": name, "namespace": DEFAULT_NAMESPACE},
"status": {"state": "Failed"},
}
raise ApiException(status=404, reason="Not Found")


def mock_list_response(*args, **kwargs) -> dict:
    """Return a mock SparkConnect list response containing two sessions.

    The payload lists ``session-1`` (state ``Ready``) and ``session-2`` (state
    ``Provisioning``), both in ``DEFAULT_NAMESPACE``.

    Note: List responses must include all fields required by SparkConnectList.from_dict().

    Args:
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A dict with ``apiVersion``, ``kind`` and ``items`` keys; every item has
        ``apiVersion``, ``kind``, ``metadata``, ``spec`` and ``status``.
    """
    api_version = f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}"

    def build_item(name: str, state: str) -> dict:
        """Build one SparkConnect item with its own spec dictionary."""
        # The spec is created per call so items never share nested dicts;
        # mutating one item in a test cannot leak into the other.
        return {
            "apiVersion": api_version,
            "kind": constants.SPARK_CONNECT_KIND,
            "metadata": {"name": name, "namespace": DEFAULT_NAMESPACE},
            "spec": {
                "sparkVersion": constants.DEFAULT_SPARK_VERSION,
                "image": constants.DEFAULT_SPARK_IMAGE,
                "server": {
                    "cores": constants.DEFAULT_DRIVER_CPU,
                    "memory": constants.DEFAULT_DRIVER_MEMORY,
                },
                "executor": {
                    "instances": 2,
                    "cores": constants.DEFAULT_EXECUTOR_CPU,
                    "memory": constants.DEFAULT_EXECUTOR_MEMORY,
                },
            },
            "status": {"state": state},
        }

    return {
        "apiVersion": api_version,
        "kind": "SparkConnectList",
        "items": [
            build_item("session-1", "Ready"),
            build_item("session-2", "Provisioning"),
        ],
    }


def mock_create_response(*args, **kwargs) -> dict:
    """Return a mock create response that echoes the submitted body.

    ``apiVersion``, ``kind``, ``metadata`` and ``spec`` are copied from the
    ``body`` keyword argument when present; otherwise defaults built from
    ``constants`` are used. The status is always ``Provisioning``.

    Note: Create responses must include all fields required by SparkConnect.from_dict().

    Args:
        *args: Accepted and ignored.
        **kwargs: Only ``body`` (a dict) is read; a missing or ``None`` body is
            treated as an empty dict.

    Returns:
        A dict with ``apiVersion``, ``kind``, ``metadata``, ``spec`` and
        ``status`` keys.
    """
    # ``or {}`` also covers an explicit ``body=None``, which would otherwise
    # fail on ``None.get``.
    body = kwargs.get("body") or {}

    default_api_version = f"{constants.SPARK_CONNECT_GROUP}/{constants.SPARK_CONNECT_VERSION}"
    default_spec = {
        "sparkVersion": constants.DEFAULT_SPARK_VERSION,
        "image": constants.DEFAULT_SPARK_IMAGE,
        "server": {
            "cores": constants.DEFAULT_DRIVER_CPU,
            "memory": constants.DEFAULT_DRIVER_MEMORY,
        },
        "executor": {
            "instances": 2,
            "cores": constants.DEFAULT_EXECUTOR_CPU,
            "memory": constants.DEFAULT_EXECUTOR_MEMORY,
        },
    }

    return {
        "apiVersion": body.get("apiVersion", default_api_version),
        "kind": body.get("kind", constants.SPARK_CONNECT_KIND),
        "metadata": body.get("metadata", {}),
        "spec": body.get("spec", default_spec),
        "status": {"state": "Provisioning"},
    }

Expand Down
Loading