diff --git a/kubeflow/spark/api/spark_client_test.py b/kubeflow/spark/api/spark_client_test.py
index 21ff32919..d18c04595 100644
--- a/kubeflow/spark/api/spark_client_test.py
+++ b/kubeflow/spark/api/spark_client_test.py
@@ -20,6 +20,7 @@
 
 from kubeflow.common.types import KubernetesBackendConfig
 from kubeflow.spark.api.spark_client import SparkClient
+from kubeflow.spark.test.common import FAILED, SUCCESS, TestCase
 from kubeflow.spark.types.options import Name
 from kubeflow.spark.types.types import SparkConnectInfo, SparkConnectState
 
@@ -37,9 +38,16 @@ def mock_backend():
     backend.list_sessions.return_value = [
         SparkConnectInfo(name="s1", namespace="default", state=SparkConnectState.READY),
     ]
-    backend.get_session.return_value = SparkConnectInfo(
-        name="test", namespace="default", state=SparkConnectState.READY
-    )
+
+    # Configure mock to handle both existing and non-existent sessions
+    def mock_get_session(session_name):
+        if session_name == "nonexistent":
+            raise ValueError("Session not found")
+        return SparkConnectInfo(
+            name=session_name, namespace="default", state=SparkConnectState.READY
+        )
+
+    backend.get_session.side_effect = mock_get_session
     backend.create_session.return_value = SparkConnectInfo(
         name="new-session", namespace="default", state=SparkConnectState.PROVISIONING
     )
@@ -47,6 +55,7 @@ def mock_backend():
     backend._create_session.return_value = ready_info
     backend._wait_for_session_ready.return_value = ready_info
     backend.get_connect_url.return_value = ("sc://localhost:15002", None)
+    backend.get_session_logs.return_value = iter(["log1", "log2"])
     return backend
 
 
@@ -62,110 +71,208 @@ def spark_client(mock_backend):
         yield client
 
 
-class TestSparkClientInit:
-    """Tests for SparkClient initialization."""
-
-    def test_default_backend(self):
-        """C01: Init with default creates KubernetesBackendConfig."""
-        with patch("kubeflow.spark.api.spark_client.KubernetesBackend"):
-            client = SparkClient()
-            assert client.backend is not None
-
-    def test_custom_namespace(self):
-        """C02: Init with custom namespace."""
-        with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock:
-            SparkClient(backend_config=KubernetesBackendConfig(namespace="spark"))
-            mock.assert_called_once()
-
-    def test_invalid_backend(self):
-        """C03: Init with invalid backend raises ValueError."""
-        with pytest.raises(ValueError):
-            SparkClient(backend_config="invalid")
-
-
-class TestSparkClientConnect:
-    """Tests for connect method."""
-
-    def test_connect_with_url(self, spark_client):
-        """C04: Connect with URL returns SparkSession."""
-        mock_session = Mock()
-        mock_builder = Mock()
-        mock_builder.remote.return_value = mock_builder
-        mock_builder.getOrCreate.return_value = mock_session
-
-        mock_spark = Mock()
-        mock_spark.builder = mock_builder
-
-        with (
-            patch.dict("sys.modules", {"pyspark": Mock(), "pyspark.sql": mock_spark}),
-            patch("kubeflow.spark.api.spark_client.SparkSession", mock_spark),
-        ):
-            pass
-
-        # Test URL validation works
-        from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url
-
-        assert validate_spark_connect_url("sc://localhost:15002") is True
-
-    def test_connect_with_url_invalid(self, spark_client):
-        """C04b: Connect with invalid URL raises ValueError."""
-        from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url
-
-        with pytest.raises(ValueError):
-            validate_spark_connect_url("http://localhost:15002")
-
-    def test_connect_create_session(self, spark_client, mock_backend):
-        """C06: Connect without URL creates new session - verifies backend calls."""
-        # Since pyspark is not installed, we verify the backend is called correctly
-        mock_backend.create_session.assert_not_called()
-        mock_backend.wait_for_session_ready.assert_not_called()
-
-
-class TestSparkClientSessionManagement:
-    """Tests for session management methods."""
-
-    def test_list_sessions(self, spark_client, mock_backend):
-        """C14: list_sessions delegates to backend."""
-        result = spark_client.list_sessions()
-        mock_backend.list_sessions.assert_called_once()
-        assert len(result) == 1
-
-    def test_get_session(self, spark_client, mock_backend):
-        """C15: get_session delegates to backend."""
-        result = spark_client.get_session("test")
-        mock_backend.get_session.assert_called_once_with("test")
-        assert result.name == "test"
-
-    def test_delete_session(self, spark_client, mock_backend):
-        """C16: delete_session delegates to backend."""
-        spark_client.delete_session("test")
-        mock_backend.delete_session.assert_called_once_with("test")
-
-    def test_get_session_logs(self, spark_client, mock_backend):
-        """C17: get_session_logs delegates to backend."""
-        mock_backend.get_session_logs.return_value = iter(["log1", "log2"])
-        list(spark_client.get_session_logs("test"))
-        mock_backend.get_session_logs.assert_called_once_with("test", follow=False)
-
-
-class TestSparkClientConnectWithNameOption:
-    """Tests for connect method with Name option."""
-
-    def test_connect_with_name_option(self, spark_client, mock_backend):
-        """C18: Connect passes options to backend including Name option."""
-        mock_session = Mock()
-        mock_backend.create_and_connect.return_value = mock_session
-        options = [Name("custom-session")]
-        spark_client.connect(options=options)
-        mock_backend.create_and_connect.assert_called_once()
-        call_args = mock_backend.create_and_connect.call_args
-        assert call_args.kwargs["options"] == options
-
-    def test_connect_without_options_auto_generates(self, spark_client, mock_backend):
-        """C19: Connect without options auto-generates name via backend."""
-        mock_session = Mock()
-        mock_backend.create_and_connect.return_value = mock_session
-        spark_client.connect()
-        mock_backend.create_and_connect.assert_called_once()
-        call_args = mock_backend.create_and_connect.call_args
-        assert call_args.kwargs["options"] is None
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="default backend initialization",
+            expected_status=SUCCESS,
+            config={},
+        ),
+        TestCase(
+            name="custom namespace initialization",
+            expected_status=SUCCESS,
+            config={"namespace": "spark"},
+        ),
+        TestCase(
+            name="invalid backend config",
+            expected_status=FAILED,
+            config={"backend_config": "invalid"},
+            expected_error=ValueError,
+        ),
+    ],
+)
+def test_spark_client_initialization(test_case: TestCase):
+    """Test SparkClient initialization scenarios.
+
+    FAILED cases pass ``config["backend_config"]`` straight to SparkClient and must
+    raise ``expected_error``; SUCCESS cases build the client with KubernetesBackend
+    patched out.
+    """
+    config = test_case.config
+
+    if test_case.expected_status == FAILED:
+        # Only the constructor runs under pytest.raises, so a failing assertion can
+        # never be mistaken for the expected error.
+        with pytest.raises(test_case.expected_error or Exception):
+            SparkClient(backend_config=config["backend_config"])
+        return
+
+    with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as backend_cls:
+        if "namespace" in config:
+            SparkClient(backend_config=KubernetesBackendConfig(namespace=config["namespace"]))
+            backend_cls.assert_called_once()
+        else:
+            client = SparkClient()
+            assert client.backend is not None
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="connect with valid URL validation",
+            expected_status=SUCCESS,
+            config={"url": "sc://localhost:15002"},
+            expected_output=True,
+        ),
+        TestCase(
+            name="connect with invalid URL validation",
+            expected_status=FAILED,
+            config={"url": "http://localhost:15002"},
+            expected_error=ValueError,
+        ),
+        TestCase(
+            name="connect create session verification",
+            expected_status=SUCCESS,
+            config={"test_connect": True},
+        ),
+    ],
+)
+def test_spark_client_connect(test_case: TestCase, spark_client):
+    """Test SparkClient connect method scenarios.
+
+    Cases with a ``url`` exercise validate_spark_connect_url (raising ``expected_error``
+    when FAILED); the remaining case calls ``spark_client.connect()`` without a URL.
+    """
+    config = test_case.config
+
+    if "url" in config:
+        from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url
+
+        if test_case.expected_status == FAILED:
+            with pytest.raises(test_case.expected_error or Exception):
+                validate_spark_connect_url(config["url"])
+        else:
+            assert validate_spark_connect_url(config["url"]) == test_case.expected_output
+        return
+
+    # Call connect without base_url to trigger create mode
+    with patch("kubeflow.spark.api.spark_client.SparkSession") as mock_spark_session:
+        mock_session = Mock()
+        builder = mock_spark_session.builder
+        builder.remote.return_value.config.return_value.getOrCreate.return_value = mock_session
+        builder.getOrCreate.return_value = mock_session
+
+        result = spark_client.connect()
+
+    # Verify the session was created
+    assert result is not None
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="list sessions successfully", expected_status=SUCCESS, config={}, expected_output=1
+        ),
+        TestCase(
+            name="get existing session",
+            expected_status=SUCCESS,
+            config={"session_name": "test"},
+            expected_output="test",
+        ),
+        TestCase(
+            name="get non-existent session",
+            expected_status=FAILED,
+            config={"session_name": "nonexistent"},
+            expected_error=ValueError,
+        ),
+        TestCase(
+            name="delete session",
+            expected_status=SUCCESS,
+            config={"session_name": "test", "operation": "delete"},
+        ),
+        TestCase(
+            name="get session logs",
+            expected_status=SUCCESS,
+            config={"session_name": "test", "operation": "logs"},
+            expected_output=2,  # Expected number of log entries
+        ),
+    ],
+)
+def test_spark_client_session_management(test_case: TestCase, spark_client, mock_backend):
+    """Test SparkClient session management operations.
+
+    FAILED cases look up ``config["session_name"]`` via ``get_session`` and must raise
+    ``expected_error``. SUCCESS cases dispatch on ``operation`` and ``session_name``
+    to the delete, logs, get or list path.
+    """
+    config = test_case.config
+    session_name = config.get("session_name")
+    operation = config.get("operation")
+
+    if test_case.expected_status == FAILED:
+        # Keep assertions outside pytest.raises so only the lookup itself may raise.
+        with pytest.raises(test_case.expected_error or Exception):
+            spark_client.get_session(session_name)
+        mock_backend.get_session.assert_called_with(session_name)
+        return
+
+    if operation == "delete":
+        spark_client.delete_session(session_name)
+        mock_backend.delete_session.assert_called_once_with(session_name)
+    elif operation == "logs":
+        logs = list(spark_client.get_session_logs(session_name))
+        assert len(logs) == test_case.expected_output
+        mock_backend.get_session_logs.assert_called_once_with(session_name, follow=False)
+    elif session_name is not None:
+        session = spark_client.get_session(session_name)
+        assert session.name == test_case.expected_output
+        mock_backend.get_session.assert_called_with(session_name)
+    else:
+        sessions = spark_client.list_sessions()
+        assert len(sessions) == test_case.expected_output
+        mock_backend.list_sessions.assert_called_once()
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="connect with name option",
+            expected_status=SUCCESS,
+            config={"options": [Name("custom-session")]},
+        ),
+        TestCase(
+            name="connect without options auto-generates",
+            expected_status=SUCCESS,
+            config={},
+        ),
+    ],
+)
+def test_spark_client_connect_with_options(test_case: TestCase, spark_client, mock_backend):
+    """Test SparkClient connect method with Name option scenarios.
+
+    ``config["options"]`` is passed to ``connect`` only when present; either way the
+    backend's ``create_and_connect`` must receive it as the ``options`` keyword
+    (None when omitted).
+    """
+    mock_backend.create_and_connect.return_value = Mock()
+    connect_kwargs = {}
+    if "options" in test_case.config:
+        connect_kwargs["options"] = test_case.config["options"]
+
+    if test_case.expected_status == FAILED:
+        with pytest.raises(test_case.expected_error or Exception):
+            spark_client.connect(**connect_kwargs)
+        return
+
+    spark_client.connect(**connect_kwargs)
+
+    mock_backend.create_and_connect.assert_called_once()
+    forwarded = mock_backend.create_and_connect.call_args.kwargs["options"]
+    if "options" in connect_kwargs:
+        assert forwarded == connect_kwargs["options"]
+    else:
+        assert forwarded is None