From 70ff95adbe0acf7044bac04a2c454f5d16551c3b Mon Sep 17 00:00:00 2001 From: digvijay-y Date: Fri, 13 Feb 2026 00:13:52 +0530 Subject: [PATCH 1/4] Refactored unit test Signed-off-by: digvijay-y --- kubeflow/spark/api/spark_client_test.py | 309 ++++++++++++++++-------- 1 file changed, 211 insertions(+), 98 deletions(-) diff --git a/kubeflow/spark/api/spark_client_test.py b/kubeflow/spark/api/spark_client_test.py index 21ff32919..765b1979f 100644 --- a/kubeflow/spark/api/spark_client_test.py +++ b/kubeflow/spark/api/spark_client_test.py @@ -14,6 +14,8 @@ """Unit tests for SparkClient API.""" +from dataclasses import dataclass +from typing import Any, Optional from unittest.mock import Mock, patch import pytest @@ -24,6 +26,23 @@ from kubeflow.spark.types.types import SparkConnectInfo, SparkConnectState +@dataclass +class TestCase: + """Test case structure for parametrized SparkClient tests.""" + + name: str + expected_status: str + config: dict[str, Any] + expected_output: Optional[Any] = None + expected_error: Optional[str] = None + # Prevent pytest from collecting this dataclass as a test + __test__ = False + + +SUCCESS = "SUCCESS" +EXCEPTION = "EXCEPTION" + + @pytest.fixture def mock_backend(): """Create mock backend for SparkClient tests.""" @@ -37,9 +56,16 @@ def mock_backend(): backend.list_sessions.return_value = [ SparkConnectInfo(name="s1", namespace="default", state=SparkConnectState.READY), ] - backend.get_session.return_value = SparkConnectInfo( - name="test", namespace="default", state=SparkConnectState.READY - ) + + # Configure mock to handle both existing and non-existent sessions + def mock_get_session(session_name): + if session_name == "nonexistent": + raise ValueError("Session not found") + return SparkConnectInfo( + name=session_name, namespace="default", state=SparkConnectState.READY + ) + + backend.get_session.side_effect = mock_get_session backend.create_session.return_value = SparkConnectInfo( name="new-session", namespace="default", state=SparkConnectState.PROVISIONING ) @@ -47,6 +73,7 @@ def mock_backend(): backend._create_session.return_value = ready_info backend._wait_for_session_ready.return_value = ready_info backend.get_connect_url.return_value = ("sc://localhost:15002", None) + backend.get_session_logs.return_value = iter(["log1", "log2"]) return backend @@ -62,110 +89,196 @@ def spark_client(mock_backend): yield client -class TestSparkClientInit: - """Tests for SparkClient initialization.""" - - def test_default_backend(self): - """C01: Init with default creates KubernetesBackendConfig.""" - with patch("kubeflow.spark.api.spark_client.KubernetesBackend"): - client = SparkClient() - assert client.backend is not None - - def test_custom_namespace(self): - """C02: Init with custom namespace.""" - with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock: - SparkClient(backend_config=KubernetesBackendConfig(namespace="spark")) - mock.assert_called_once() - - def test_invalid_backend(self): - """C03: Init with invalid backend raises ValueError.""" +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="default backend initialization", + expected_status=SUCCESS, + config={}, + expected_output="backend_created", + ), + TestCase( + name="custom namespace initialization", + expected_status=SUCCESS, + config={"namespace": "spark"}, + expected_output="backend_created", + ), + TestCase( + name="invalid backend config", + expected_status=EXCEPTION, + config={"backend_config": "invalid"}, + expected_error="ValueError", + ), + ], +) +def test_spark_client_initialization(test_case: TestCase): + """Test SparkClient initialization scenarios.""" + print(f"Running test: {test_case.name}") + + if test_case.expected_status == SUCCESS: + if "namespace" in test_case.config: + with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock: + SparkClient( + backend_config=KubernetesBackendConfig(namespace=test_case.config["namespace"]) + ) + mock.assert_called_once() + else: + with patch("kubeflow.spark.api.spark_client.KubernetesBackend"): + client = SparkClient() + assert client.backend is not None + print(f"✓ {test_case.name} succeeded") + else: with pytest.raises(ValueError): - SparkClient(backend_config="invalid") - - -class TestSparkClientConnect: - """Tests for connect method.""" - - def test_connect_with_url(self, spark_client): - """C04: Connect with URL returns SparkSession.""" - mock_session = Mock() - mock_builder = Mock() - mock_builder.remote.return_value = mock_builder - mock_builder.getOrCreate.return_value = mock_session - - mock_spark = Mock() - mock_spark.builder = mock_builder - - with ( - patch.dict("sys.modules", {"pyspark": Mock(), "pyspark.sql": mock_spark}), - patch("kubeflow.spark.api.spark_client.SparkSession", mock_spark), - ): - pass - - # Test URL validation works - from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url - - assert validate_spark_connect_url("sc://localhost:15002") is True - - def test_connect_with_url_invalid(self, spark_client): - """C04b: Connect with invalid URL raises ValueError.""" + SparkClient(backend_config=test_case.config["backend_config"]) + print(f"✓ {test_case.name} failed as expected") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="connect with valid URL validation", + expected_status=SUCCESS, + config={"url": "sc://localhost:15002"}, + expected_output=True, + ), + TestCase( + name="connect with invalid URL validation", + expected_status=EXCEPTION, + config={"url": "http://localhost:15002"}, + expected_error="ValueError", + ), + TestCase( + name="connect create session verification", + expected_status=SUCCESS, + config={}, + expected_output="backend_not_called", + ), + ], +) +def test_spark_client_connect(test_case: TestCase, spark_client): + """Test SparkClient connect method scenarios.""" + print(f"Running test: {test_case.name}") + + if test_case.expected_status == SUCCESS: + if "url" in test_case.config: + from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url + + result = validate_spark_connect_url(test_case.config["url"]) + assert result == test_case.expected_output + else: + # Verify backend methods are not called initially + spark_client.backend.create_session.assert_not_called() + spark_client.backend.wait_for_session_ready.assert_not_called() + print(f"✓ {test_case.name} succeeded") + else: from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url with pytest.raises(ValueError): - validate_spark_connect_url("http://localhost:15002") - - def test_connect_create_session(self, spark_client, mock_backend): - """C06: Connect without URL creates new session - verifies backend calls.""" - # Since pyspark is not installed, we verify the backend is called correctly - mock_backend.create_session.assert_not_called() - mock_backend.wait_for_session_ready.assert_not_called() - - -class TestSparkClientSessionManagement: - """Tests for session management methods.""" - - def test_list_sessions(self, spark_client, mock_backend): - """C14: list_sessions delegates to backend.""" - result = spark_client.list_sessions() - mock_backend.list_sessions.assert_called_once() - assert len(result) == 1 - - def test_get_session(self, spark_client, mock_backend): - """C15: get_session delegates to backend.""" - result = spark_client.get_session("test") - mock_backend.get_session.assert_called_once_with("test") - assert result.name == "test" - - def test_delete_session(self, spark_client, mock_backend): - """C16: delete_session delegates to backend.""" - spark_client.delete_session("test") - mock_backend.delete_session.assert_called_once_with("test") - - def test_get_session_logs(self, spark_client, mock_backend): - """C17: get_session_logs delegates to backend.""" - mock_backend.get_session_logs.return_value = iter(["log1", "log2"]) - list(spark_client.get_session_logs("test")) - mock_backend.get_session_logs.assert_called_once_with("test", follow=False) - - -class TestSparkClientConnectWithNameOption: - """Tests for connect method with Name option.""" - - def test_connect_with_name_option(self, spark_client, mock_backend): - """C18: Connect passes options to backend including Name option.""" - mock_session = Mock() - mock_backend.create_and_connect.return_value = mock_session - options = [Name("custom-session")] + validate_spark_connect_url(test_case.config["url"]) + print(f"✓ {test_case.name} failed as expected") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="list sessions successfully", expected_status=SUCCESS, config={}, expected_output=1 + ), + TestCase( + name="get existing session", + expected_status=SUCCESS, + config={"session_name": "test"}, + expected_output="test", + ), + TestCase( + name="get non-existent session", + expected_status=EXCEPTION, + config={"session_name": "nonexistent"}, + expected_error="Session not found", + ), + TestCase( + name="delete session", + expected_status=SUCCESS, + config={"session_name": "test", "operation": "delete"}, + expected_output="deleted", + ), + TestCase( + name="get session logs", + expected_status=SUCCESS, + config={"session_name": "test", "operation": "logs"}, + expected_output=2, # Expected number of log entries + ), + ], +) +def test_spark_client_session_management(test_case: TestCase, spark_client, mock_backend): + """Test SparkClient session management operations.""" + print(f"Running test: {test_case.name}") + + if test_case.expected_status == SUCCESS: + if "operation" in test_case.config: + if test_case.config["operation"] == "delete": + spark_client.delete_session(test_case.config["session_name"]) + mock_backend.delete_session.assert_called_once_with( + test_case.config["session_name"] + ) + elif test_case.config["operation"] == "logs": + result = list(spark_client.get_session_logs(test_case.config["session_name"])) + assert len(result) == test_case.expected_output + mock_backend.get_session_logs.assert_called_once_with( + test_case.config["session_name"], follow=False + ) + elif "session_name" in test_case.config: + result = spark_client.get_session(test_case.config["session_name"]) + assert result.name == test_case.expected_output + mock_backend.get_session.assert_called_with(test_case.config["session_name"]) + else: + result = spark_client.list_sessions() + assert len(result) == test_case.expected_output + mock_backend.list_sessions.assert_called_once() + print(f"✓ {test_case.name} succeeded") + else: + with pytest.raises(Exception) as exc_info: + spark_client.get_session(test_case.config["session_name"]) + assert test_case.expected_error in str(exc_info.value) + print(f"✓ {test_case.name} failed as expected") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="connect with name option", + expected_status=SUCCESS, + config={"options": [Name("custom-session")]}, + expected_output="options_passed", + ), + TestCase( + name="connect without options auto-generates", + expected_status=SUCCESS, + config={}, + expected_output="no_options", + ), + ], +) +def test_spark_client_connect_with_options(test_case: TestCase, spark_client, mock_backend): + """Test SparkClient connect method with Name option scenarios.""" + print(f"Running test: {test_case.name}") + + mock_session = Mock() + mock_backend.create_and_connect.return_value = mock_session + + if "options" in test_case.config: + options = test_case.config["options"] spark_client.connect(options=options) mock_backend.create_and_connect.assert_called_once() call_args = mock_backend.create_and_connect.call_args assert call_args.kwargs["options"] == options - - def test_connect_without_options_auto_generates(self, spark_client, mock_backend): - """C19: Connect without options auto-generates name via backend.""" - mock_session = Mock() - mock_backend.create_and_connect.return_value = mock_session + else: spark_client.connect() mock_backend.create_and_connect.assert_called_once() call_args = mock_backend.create_and_connect.call_args assert call_args.kwargs["options"] is None + + print(f"✓ {test_case.name} succeeded") From 78cff4517a022c0372596f60a95ac1e1c29a5d3c Mon Sep 17 00:00:00 2001 From: Jon Burdo Date: Thu, 12 Feb 2026 11:41:15 -0500 Subject: [PATCH 2/4] chore: bump minimum model-registry version to 0.3.6 (#289) Signed-off-by: Jon Burdo Signed-off-by: digvijay-y --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5559a2246..8f0aadcdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ spark = [ "pyspark[connect]==3.4.1", ] hub = [ - "model-registry>=0.3.0", + "model-registry>=0.3.6", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index f6f258ac4..632af8922 100644 --- a/uv.lock +++ b/uv.lock @@ -934,7 +934,7 @@ requires-dist = [ { name = "kubeflow-katib-api", specifier = ">=0.19.0" }, { name = "kubeflow-trainer-api", specifier = ">=2.0.0" }, { name = "kubernetes", specifier = ">=27.2.0" }, - { name = "model-registry", marker = "extra == 'hub'", specifier = ">=0.3.0" }, + { name = "model-registry", marker = "extra == 'hub'", specifier = ">=0.3.6" }, { name = "podman", marker = "extra == 'podman'", specifier = ">=5.6.0" }, { name = "pydantic", specifier = ">=2.10.0" }, { name = "pyspark", extras = ["connect"], marker = "extra == 'spark'", specifier = "==3.4.1" }, @@ -1121,7 +1121,7 @@ wheels = [ [[package]] name = "model-registry" -version = "0.3.5" +version = "0.3.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -1131,9 +1131,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/7c/957482253e6360238ae6a28a1f15a6390bf8210fcb68db57c47c8f48e7ae/model_registry-0.3.5.tar.gz", hash = "sha256:fb933aea7cb55693ee7b69d560a9e29ff1d924a2ac6940ceb24b8beeb68ea4de", size = 95789, upload-time = "2026-01-09T20:32:10.835Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/e8/790640c40b9c4a6ba792df97596d6989e4e2e52e3fba1c22f36369b2a73c/model_registry-0.3.6.tar.gz", hash = "sha256:79d3bf95a65a0b01d84e6ff5a9ba0d0920a69042bc2ea45cb573cdc7c571fcff", size = 105137, upload-time = "2026-02-10T10:37:16.487Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/4d/856972d0b2d0cd466868fcd88fe9b1f0c559c4e0d967247066e22ee3f48d/model_registry-0.3.5-py3-none-any.whl", hash = "sha256:5cf09c9012f860a5308008511ec9bc0cc1643f9038d62acf4ccdc3ff4f8f6eba", size = 204406, upload-time = "2026-01-09T20:32:09.389Z" }, + { url = "https://files.pythonhosted.org/packages/3e/49/f025399a9cff506ee7821f98e2fe7dc94ab475002493a16a82ecd7594f09/model_registry-0.3.6-py3-none-any.whl", hash = "sha256:b7b5ad080cb982db4442107d818cd47d50cea4361552cdd030b247bc039150b6", size = 216727, upload-time = "2026-02-10T10:37:14.536Z" }, ] [[package]] From 5d8280e35108a484e9c91cce5411985990e0c14c Mon Sep 17 00:00:00 2001 From: digvijay-y Date: Fri, 13 Feb 2026 00:45:17 +0530 Subject: [PATCH 3/4] Changes made Signed-off-by: digvijay-y --- kubeflow/spark/api/spark_client_test.py | 119 +++++++++++++++--------- 1 file changed, 76 insertions(+), 43 deletions(-) diff --git a/kubeflow/spark/api/spark_client_test.py b/kubeflow/spark/api/spark_client_test.py index 765b1979f..00bb05fa3 100644 --- a/kubeflow/spark/api/spark_client_test.py +++ b/kubeflow/spark/api/spark_client_test.py @@ -96,13 +96,11 @@ def spark_client(mock_backend): name="default backend initialization", expected_status=SUCCESS, config={}, - expected_output="backend_created", ), TestCase( name="custom namespace initialization", expected_status=SUCCESS, config={"namespace": "spark"}, - expected_output="backend_created", ), TestCase( name="invalid backend config", @@ -114,24 +112,35 @@ def spark_client(mock_backend): ) def test_spark_client_initialization(test_case: TestCase): """Test SparkClient initialization scenarios.""" - print(f"Running test: {test_case.name}") - if test_case.expected_status == SUCCESS: + try: if "namespace" in test_case.config: with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock: SparkClient( backend_config=KubernetesBackendConfig(namespace=test_case.config["namespace"]) ) mock.assert_called_once() + elif "backend_config" in test_case.config: + SparkClient(backend_config=test_case.config["backend_config"]) else: with patch("kubeflow.spark.api.spark_client.KubernetesBackend"): client = SparkClient() assert client.backend is not None - print(f"✓ {test_case.name} succeeded") - else: - with pytest.raises(ValueError): - SparkClient(backend_config=test_case.config["backend_config"]) - print(f"✓ {test_case.name} failed as expected") + + # If we reach here but expected an exception, fail + assert test_case.expected_status == SUCCESS, ( + f"Expected exception but none was raised for {test_case.name}" + ) + except Exception as e: + # If we got an exception but expected success, fail + assert test_case.expected_status == EXCEPTION, ( + f"Unexpected exception in {test_case.name}: {e}" + ) + # Validate the exception type/message if specified + if test_case.expected_error: + assert test_case.expected_error in str(e), ( + f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + ) @pytest.mark.parametrize( @@ -153,15 +162,13 @@ def test_spark_client_initialization(test_case: TestCase): name="connect create session verification", expected_status=SUCCESS, config={}, - expected_output="backend_not_called", ), ], ) def test_spark_client_connect(test_case: TestCase, spark_client): """Test SparkClient connect method scenarios.""" - print(f"Running test: {test_case.name}") - if test_case.expected_status == SUCCESS: + try: if "url" in test_case.config: from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url @@ -171,13 +178,21 @@ def test_spark_client_connect(test_case: TestCase, spark_client): # Verify backend methods are not called initially spark_client.backend.create_session.assert_not_called() spark_client.backend.wait_for_session_ready.assert_not_called() - print(f"✓ {test_case.name} succeeded") - else: - from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url - with pytest.raises(ValueError): - validate_spark_connect_url(test_case.config["url"]) - print(f"✓ {test_case.name} failed as expected") + # If we reach here but expected an exception, fail + assert test_case.expected_status == SUCCESS, ( + f"Expected exception but none was raised for {test_case.name}" + ) + except Exception as e: + # If we got an exception but expected success, fail + assert test_case.expected_status == EXCEPTION, ( + f"Unexpected exception in {test_case.name}: {e}" + ) + # Validate the exception type/message if specified + if test_case.expected_error: + assert test_case.expected_error in str(e), ( + f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + ) @pytest.mark.parametrize( @@ -202,7 +217,6 @@ def test_spark_client_connect(test_case: TestCase, spark_client): name="delete session", expected_status=SUCCESS, config={"session_name": "test", "operation": "delete"}, - expected_output="deleted", ), TestCase( name="get session logs", @@ -214,9 +228,8 @@ def test_spark_client_connect(test_case: TestCase, spark_client): ) def test_spark_client_session_management(test_case: TestCase, spark_client, mock_backend): """Test SparkClient session management operations.""" - print(f"Running test: {test_case.name}") - if test_case.expected_status == SUCCESS: + try: if "operation" in test_case.config: if test_case.config["operation"] == "delete": spark_client.delete_session(test_case.config["session_name"]) @@ -237,12 +250,21 @@ def test_spark_client_session_management(test_case: TestCase, spark_client, mock result = spark_client.list_sessions() assert len(result) == test_case.expected_output mock_backend.list_sessions.assert_called_once() - print(f"✓ {test_case.name} succeeded") - else: - with pytest.raises(Exception) as exc_info: - spark_client.get_session(test_case.config["session_name"]) - assert test_case.expected_error in str(exc_info.value) - print(f"✓ {test_case.name} failed as expected") + + # If we reach here but expected an exception, fail + assert test_case.expected_status == SUCCESS, ( + f"Expected exception but none was raised for {test_case.name}" + ) + except Exception as e: + # If we got an exception but expected success, fail + assert test_case.expected_status == EXCEPTION, ( + f"Unexpected exception in {test_case.name}: {e}" + ) + # Validate the exception type/message if specified + if test_case.expected_error: + assert test_case.expected_error in str(e), ( + f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + ) @pytest.mark.parametrize( @@ -252,33 +274,44 @@ def test_spark_client_session_management(test_case: TestCase, spark_client, mock name="connect with name option", expected_status=SUCCESS, config={"options": [Name("custom-session")]}, - expected_output="options_passed", ), TestCase( name="connect without options auto-generates", expected_status=SUCCESS, config={}, - expected_output="no_options", ), ], ) def test_spark_client_connect_with_options(test_case: TestCase, spark_client, mock_backend): """Test SparkClient connect method with Name option scenarios.""" - print(f"Running test: {test_case.name}") mock_session = Mock() mock_backend.create_and_connect.return_value = mock_session - if "options" in test_case.config: - options = test_case.config["options"] - spark_client.connect(options=options) - mock_backend.create_and_connect.assert_called_once() - call_args = mock_backend.create_and_connect.call_args - assert call_args.kwargs["options"] == options - else: - spark_client.connect() - mock_backend.create_and_connect.assert_called_once() - call_args = mock_backend.create_and_connect.call_args - assert call_args.kwargs["options"] is None - - print(f"✓ {test_case.name} succeeded") + try: + if "options" in test_case.config: + options = test_case.config["options"] + spark_client.connect(options=options) + mock_backend.create_and_connect.assert_called_once() + call_args = mock_backend.create_and_connect.call_args + assert call_args.kwargs["options"] == options + else: + spark_client.connect() + mock_backend.create_and_connect.assert_called_once() + call_args = mock_backend.create_and_connect.call_args + assert call_args.kwargs["options"] is None + + # If we reach here but expected an exception, fail + assert test_case.expected_status == SUCCESS, ( + f"Expected exception but none was raised for {test_case.name}" + ) + except Exception as e: + # If we got an exception but expected success, fail + assert test_case.expected_status == EXCEPTION, ( + f"Unexpected exception in {test_case.name}: {e}" + ) + # Validate the exception type/message if specified + if test_case.expected_error: + assert test_case.expected_error in str(e), ( + f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + ) From efdb7091d2817c45a0a283b064cacdd1dd1aeaf8 Mon Sep 17 00:00:00 2001 From: digvijay-y Date: Fri, 13 Feb 2026 12:42:01 +0530 Subject: [PATCH 4/4] Version Signed-off-by: digvijay-y --- kubeflow/spark/api/spark_client_test.py | 86 +++++++++++-------------- pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 38 insertions(+), 52 deletions(-) diff --git a/kubeflow/spark/api/spark_client_test.py b/kubeflow/spark/api/spark_client_test.py index 00bb05fa3..d18c04595 100644 --- a/kubeflow/spark/api/spark_client_test.py +++ b/kubeflow/spark/api/spark_client_test.py @@ -14,35 +14,17 @@ """Unit tests for SparkClient API.""" -from dataclasses import dataclass -from typing import Any, Optional from unittest.mock import Mock, patch import pytest from kubeflow.common.types import KubernetesBackendConfig from kubeflow.spark.api.spark_client import SparkClient +from kubeflow.spark.test.common import FAILED, SUCCESS, TestCase from kubeflow.spark.types.options import Name from kubeflow.spark.types.types import SparkConnectInfo, SparkConnectState -@dataclass -class TestCase: - """Test case structure for parametrized SparkClient tests.""" - - name: str - expected_status: str - config: dict[str, Any] - expected_output: Optional[Any] = None - expected_error: Optional[str] = None - # Prevent pytest from collecting this dataclass as a test - __test__ = False - - -SUCCESS = "SUCCESS" -EXCEPTION = "EXCEPTION" - - @pytest.fixture def mock_backend(): """Create mock backend for SparkClient tests.""" @@ -104,9 +86,9 @@ def spark_client(mock_backend): ), TestCase( name="invalid backend config", - expected_status=EXCEPTION, + expected_status=FAILED, config={"backend_config": "invalid"}, - expected_error="ValueError", + expected_error=ValueError, ), ], ) @@ -133,13 +115,11 @@ def test_spark_client_initialization(test_case: TestCase): ) except Exception as e: # If we got an exception but expected success, fail - assert test_case.expected_status == EXCEPTION, ( - f"Unexpected exception in {test_case.name}: {e}" - ) - # Validate the exception type/message if specified + assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" + # Validate the exception type if specified if test_case.expected_error: - assert test_case.expected_error in str(e), ( - f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + assert isinstance(e, test_case.expected_error), ( + f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'" ) @@ -154,14 +134,14 @@ def test_spark_client_initialization(test_case: TestCase): ), TestCase( name="connect with invalid URL validation", - expected_status=EXCEPTION, + expected_status=FAILED, config={"url": "http://localhost:15002"}, - expected_error="ValueError", + expected_error=ValueError, ), TestCase( name="connect create session verification", expected_status=SUCCESS, - config={}, + config={"test_connect": True}, ), ], ) @@ -174,6 +154,18 @@ def test_spark_client_connect(test_case: TestCase, spark_client): result = validate_spark_connect_url(test_case.config["url"]) assert result == test_case.expected_output + elif "test_connect" in test_case.config: + # Actually test the connect method for session creation + with patch("kubeflow.spark.api.spark_client.SparkSession") as mock_spark_session: + mock_session = Mock() + mock_spark_session.builder.remote.return_value.config.return_value.getOrCreate.return_value = mock_session + mock_spark_session.builder.getOrCreate.return_value = mock_session + + # Call connect without base_url to trigger create mode + result = spark_client.connect() + + # Verify the session was created + assert result is not None else: # Verify backend methods are not called initially spark_client.backend.create_session.assert_not_called() @@ -185,13 +177,11 @@ def test_spark_client_connect(test_case: TestCase, spark_client): ) except Exception as e: # If we got an exception but expected success, fail - assert test_case.expected_status == EXCEPTION, ( - f"Unexpected exception in {test_case.name}: {e}" - ) - # Validate the exception type/message if specified + assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" + # Validate the exception type if specified if test_case.expected_error: - assert test_case.expected_error in str(e), ( - f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + assert isinstance(e, test_case.expected_error), ( + f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'" ) @@ -209,9 +199,9 @@ def test_spark_client_connect(test_case: TestCase, spark_client): ), TestCase( name="get non-existent session", - expected_status=EXCEPTION, + expected_status=FAILED, config={"session_name": "nonexistent"}, - expected_error="Session not found", + expected_error=ValueError, ), TestCase( name="delete session", @@ -257,13 +247,11 @@ def test_spark_client_session_management(test_case: TestCase, spark_client, mock ) except Exception as e: # If we got an exception but expected success, fail - assert test_case.expected_status == EXCEPTION, ( - f"Unexpected exception in {test_case.name}: {e}" - ) - # Validate the exception type/message if specified + assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" + # Validate the exception type if specified if test_case.expected_error: - assert test_case.expected_error in str(e), ( - f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + assert isinstance(e, test_case.expected_error), ( + f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'" ) @@ -307,11 +295,9 @@ def test_spark_client_connect_with_options(test_case: TestCase, spark_client, mo ) except Exception as e: # If we got an exception but expected success, fail - assert test_case.expected_status == EXCEPTION, ( - f"Unexpected exception in {test_case.name}: {e}" - ) - # Validate the exception type/message if specified + assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" + # Validate the exception type if specified if test_case.expected_error: - assert test_case.expected_error in str(e), ( - f"Expected error '{test_case.expected_error}' but got '{str(e)}'" + assert isinstance(e, test_case.expected_error), ( + f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'" ) diff --git a/pyproject.toml b/pyproject.toml index 8f0aadcdd..5559a2246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ spark = [ "pyspark[connect]==3.4.1", ] hub = [ - "model-registry>=0.3.6", + "model-registry>=0.3.0", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index 632af8922..8a27e617e 100644 --- a/uv.lock +++ b/uv.lock @@ -934,7 +934,7 @@ requires-dist = [ { name = "kubeflow-katib-api", specifier = ">=0.19.0" }, { name = "kubeflow-trainer-api", specifier = ">=2.0.0" }, { name = "kubernetes", specifier = ">=27.2.0" }, - { name = "model-registry", marker = "extra == 'hub'", specifier = ">=0.3.6" }, + { name = "model-registry", marker = "extra == 'hub'", specifier = ">=0.3.0" }, { name = "podman", marker = "extra == 'podman'", specifier = ">=5.6.0" }, { name = "pydantic", specifier = ">=2.10.0" }, { name = "pyspark", extras = ["connect"], marker = "extra == 'spark'", specifier = "==3.4.1" },