-
Notifications
You must be signed in to change notification settings - Fork 105
feat(spark): Refactor unit tests to sdk coding standards #293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
70ff95a
78cff45
5d8280e
efdb709
a0f3194
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |||||||||||||
|
|
||||||||||||||
| from kubeflow.common.types import KubernetesBackendConfig | ||||||||||||||
| from kubeflow.spark.api.spark_client import SparkClient | ||||||||||||||
| from kubeflow.spark.test.common import FAILED, SUCCESS, TestCase | ||||||||||||||
| from kubeflow.spark.types.options import Name | ||||||||||||||
| from kubeflow.spark.types.types import SparkConnectInfo, SparkConnectState | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -37,16 +38,24 @@ def mock_backend(): | |||||||||||||
| backend.list_sessions.return_value = [ | ||||||||||||||
| SparkConnectInfo(name="s1", namespace="default", state=SparkConnectState.READY), | ||||||||||||||
| ] | ||||||||||||||
| backend.get_session.return_value = SparkConnectInfo( | ||||||||||||||
| name="test", namespace="default", state=SparkConnectState.READY | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # Configure mock to handle both existing and non-existent sessions | ||||||||||||||
| def mock_get_session(session_name): | ||||||||||||||
| if session_name == "nonexistent": | ||||||||||||||
| raise ValueError("Session not found") | ||||||||||||||
| return SparkConnectInfo( | ||||||||||||||
| name=session_name, namespace="default", state=SparkConnectState.READY | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| backend.get_session.side_effect = mock_get_session | ||||||||||||||
| backend.create_session.return_value = SparkConnectInfo( | ||||||||||||||
| name="new-session", namespace="default", state=SparkConnectState.PROVISIONING | ||||||||||||||
| ) | ||||||||||||||
| backend.wait_for_session_ready.return_value = ready_info | ||||||||||||||
| backend._create_session.return_value = ready_info | ||||||||||||||
| backend._wait_for_session_ready.return_value = ready_info | ||||||||||||||
| backend.get_connect_url.return_value = ("sc://localhost:15002", None) | ||||||||||||||
| backend.get_session_logs.return_value = iter(["log1", "log2"]) | ||||||||||||||
| return backend | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
|
|
@@ -62,110 +71,233 @@ def spark_client(mock_backend): | |||||||||||||
| yield client | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class TestSparkClientInit: | ||||||||||||||
| """Tests for SparkClient initialization.""" | ||||||||||||||
|
|
||||||||||||||
| def test_default_backend(self): | ||||||||||||||
| """C01: Init with default creates KubernetesBackendConfig.""" | ||||||||||||||
| with patch("kubeflow.spark.api.spark_client.KubernetesBackend"): | ||||||||||||||
| client = SparkClient() | ||||||||||||||
| assert client.backend is not None | ||||||||||||||
|
|
||||||||||||||
| def test_custom_namespace(self): | ||||||||||||||
| """C02: Init with custom namespace.""" | ||||||||||||||
| with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock: | ||||||||||||||
| SparkClient(backend_config=KubernetesBackendConfig(namespace="spark")) | ||||||||||||||
| mock.assert_called_once() | ||||||||||||||
|
|
||||||||||||||
| def test_invalid_backend(self): | ||||||||||||||
| """C03: Init with invalid backend raises ValueError.""" | ||||||||||||||
| with pytest.raises(ValueError): | ||||||||||||||
| SparkClient(backend_config="invalid") | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class TestSparkClientConnect: | ||||||||||||||
| """Tests for connect method.""" | ||||||||||||||
|
|
||||||||||||||
| def test_connect_with_url(self, spark_client): | ||||||||||||||
| """C04: Connect with URL returns SparkSession.""" | ||||||||||||||
| mock_session = Mock() | ||||||||||||||
| mock_builder = Mock() | ||||||||||||||
| mock_builder.remote.return_value = mock_builder | ||||||||||||||
| mock_builder.getOrCreate.return_value = mock_session | ||||||||||||||
|
|
||||||||||||||
| mock_spark = Mock() | ||||||||||||||
| mock_spark.builder = mock_builder | ||||||||||||||
|
|
||||||||||||||
| with ( | ||||||||||||||
| patch.dict("sys.modules", {"pyspark": Mock(), "pyspark.sql": mock_spark}), | ||||||||||||||
| patch("kubeflow.spark.api.spark_client.SparkSession", mock_spark), | ||||||||||||||
| ): | ||||||||||||||
| pass | ||||||||||||||
|
|
||||||||||||||
| # Test URL validation works | ||||||||||||||
| from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url | ||||||||||||||
|
|
||||||||||||||
| assert validate_spark_connect_url("sc://localhost:15002") is True | ||||||||||||||
|
|
||||||||||||||
| def test_connect_with_url_invalid(self, spark_client): | ||||||||||||||
| """C04b: Connect with invalid URL raises ValueError.""" | ||||||||||||||
| from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url | ||||||||||||||
|
|
||||||||||||||
| with pytest.raises(ValueError): | ||||||||||||||
| validate_spark_connect_url("http://localhost:15002") | ||||||||||||||
|
|
||||||||||||||
| def test_connect_create_session(self, spark_client, mock_backend): | ||||||||||||||
| """C06: Connect without URL creates new session - verifies backend calls.""" | ||||||||||||||
| # Since pyspark is not installed, we verify the backend is called correctly | ||||||||||||||
| mock_backend.create_session.assert_not_called() | ||||||||||||||
| mock_backend.wait_for_session_ready.assert_not_called() | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class TestSparkClientSessionManagement: | ||||||||||||||
| """Tests for session management methods.""" | ||||||||||||||
|
|
||||||||||||||
| def test_list_sessions(self, spark_client, mock_backend): | ||||||||||||||
| """C14: list_sessions delegates to backend.""" | ||||||||||||||
| result = spark_client.list_sessions() | ||||||||||||||
| mock_backend.list_sessions.assert_called_once() | ||||||||||||||
| assert len(result) == 1 | ||||||||||||||
|
|
||||||||||||||
| def test_get_session(self, spark_client, mock_backend): | ||||||||||||||
| """C15: get_session delegates to backend.""" | ||||||||||||||
| result = spark_client.get_session("test") | ||||||||||||||
| mock_backend.get_session.assert_called_once_with("test") | ||||||||||||||
| assert result.name == "test" | ||||||||||||||
|
|
||||||||||||||
| def test_delete_session(self, spark_client, mock_backend): | ||||||||||||||
| """C16: delete_session delegates to backend.""" | ||||||||||||||
| spark_client.delete_session("test") | ||||||||||||||
| mock_backend.delete_session.assert_called_once_with("test") | ||||||||||||||
|
|
||||||||||||||
| def test_get_session_logs(self, spark_client, mock_backend): | ||||||||||||||
| """C17: get_session_logs delegates to backend.""" | ||||||||||||||
| mock_backend.get_session_logs.return_value = iter(["log1", "log2"]) | ||||||||||||||
| list(spark_client.get_session_logs("test")) | ||||||||||||||
| mock_backend.get_session_logs.assert_called_once_with("test", follow=False) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class TestSparkClientConnectWithNameOption: | ||||||||||||||
| """Tests for connect method with Name option.""" | ||||||||||||||
|
|
||||||||||||||
| def test_connect_with_name_option(self, spark_client, mock_backend): | ||||||||||||||
| """C18: Connect passes options to backend including Name option.""" | ||||||||||||||
| mock_session = Mock() | ||||||||||||||
| mock_backend.create_and_connect.return_value = mock_session | ||||||||||||||
| options = [Name("custom-session")] | ||||||||||||||
| spark_client.connect(options=options) | ||||||||||||||
| mock_backend.create_and_connect.assert_called_once() | ||||||||||||||
| call_args = mock_backend.create_and_connect.call_args | ||||||||||||||
| assert call_args.kwargs["options"] == options | ||||||||||||||
|
|
||||||||||||||
| def test_connect_without_options_auto_generates(self, spark_client, mock_backend): | ||||||||||||||
| """C19: Connect without options auto-generates name via backend.""" | ||||||||||||||
| mock_session = Mock() | ||||||||||||||
| mock_backend.create_and_connect.return_value = mock_session | ||||||||||||||
| spark_client.connect() | ||||||||||||||
| mock_backend.create_and_connect.assert_called_once() | ||||||||||||||
| call_args = mock_backend.create_and_connect.call_args | ||||||||||||||
| assert call_args.kwargs["options"] is None | ||||||||||||||
| @pytest.mark.parametrize( | ||||||||||||||
| "test_case", | ||||||||||||||
| [ | ||||||||||||||
| TestCase( | ||||||||||||||
| name="default backend initialization", | ||||||||||||||
| expected_status=SUCCESS, | ||||||||||||||
| config={}, | ||||||||||||||
| ), | ||||||||||||||
| TestCase( | ||||||||||||||
| name="custom namespace initialization", | ||||||||||||||
| expected_status=SUCCESS, | ||||||||||||||
| config={"namespace": "spark"}, | ||||||||||||||
| ), | ||||||||||||||
| TestCase( | ||||||||||||||
| name="invalid backend config", | ||||||||||||||
| expected_status=FAILED, | ||||||||||||||
| config={"backend_config": "invalid"}, | ||||||||||||||
| expected_error=ValueError, | ||||||||||||||
| ), | ||||||||||||||
| ], | ||||||||||||||
| ) | ||||||||||||||
| def test_spark_client_initialization(test_case: TestCase): | ||||||||||||||
| """Test SparkClient initialization scenarios.""" | ||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| if "namespace" in test_case.config: | ||||||||||||||
| with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock: | ||||||||||||||
| SparkClient( | ||||||||||||||
| backend_config=KubernetesBackendConfig(namespace=test_case.config["namespace"]) | ||||||||||||||
| ) | ||||||||||||||
| mock.assert_called_once() | ||||||||||||||
| elif "backend_config" in test_case.config: | ||||||||||||||
| SparkClient(backend_config=test_case.config["backend_config"]) | ||||||||||||||
| else: | ||||||||||||||
| with patch("kubeflow.spark.api.spark_client.KubernetesBackend"): | ||||||||||||||
| client = SparkClient() | ||||||||||||||
| assert client.backend is not None | ||||||||||||||
|
|
||||||||||||||
| # If we reach here but expected an exception, fail | ||||||||||||||
| assert test_case.expected_status == SUCCESS, ( | ||||||||||||||
| f"Expected exception but none was raised for {test_case.name}" | ||||||||||||||
| ) | ||||||||||||||
| except Exception as e: | ||||||||||||||
| # If we got an exception but expected success, fail | ||||||||||||||
| assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" | ||||||||||||||
| # Validate the exception type if specified | ||||||||||||||
| if test_case.expected_error: | ||||||||||||||
| assert isinstance(e, test_case.expected_error), ( | ||||||||||||||
| f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @pytest.mark.parametrize( | ||||||||||||||
| "test_case", | ||||||||||||||
| [ | ||||||||||||||
| TestCase( | ||||||||||||||
| name="connect with valid URL validation", | ||||||||||||||
| expected_status=SUCCESS, | ||||||||||||||
| config={"url": "sc://localhost:15002"}, | ||||||||||||||
| expected_output=True, | ||||||||||||||
| ), | ||||||||||||||
| TestCase( | ||||||||||||||
| name="connect with invalid URL validation", | ||||||||||||||
| expected_status=FAILED, | ||||||||||||||
| config={"url": "http://localhost:15002"}, | ||||||||||||||
| expected_error=ValueError, | ||||||||||||||
| ), | ||||||||||||||
| TestCase( | ||||||||||||||
| name="connect create session verification", | ||||||||||||||
| expected_status=SUCCESS, | ||||||||||||||
| config={"test_connect": True}, | ||||||||||||||
| ), | ||||||||||||||
| ], | ||||||||||||||
| ) | ||||||||||||||
| def test_spark_client_connect(test_case: TestCase, spark_client): | ||||||||||||||
| """Test SparkClient connect method scenarios.""" | ||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| if "url" in test_case.config: | ||||||||||||||
| from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url | ||||||||||||||
|
|
||||||||||||||
| result = validate_spark_connect_url(test_case.config["url"]) | ||||||||||||||
| assert result == test_case.expected_output | ||||||||||||||
| elif "test_connect" in test_case.config: | ||||||||||||||
| # Actually test the connect method for session creation | ||||||||||||||
| with patch("kubeflow.spark.api.spark_client.SparkSession") as mock_spark_session: | ||||||||||||||
| mock_session = Mock() | ||||||||||||||
| mock_spark_session.builder.remote.return_value.config.return_value.getOrCreate.return_value = mock_session | ||||||||||||||
| mock_spark_session.builder.getOrCreate.return_value = mock_session | ||||||||||||||
|
|
||||||||||||||
| # Call connect without base_url to trigger create mode | ||||||||||||||
| result = spark_client.connect() | ||||||||||||||
|
|
||||||||||||||
| # Verify the session was created | ||||||||||||||
| assert result is not None | ||||||||||||||
| else: | ||||||||||||||
| # Verify backend methods are not called initially | ||||||||||||||
| spark_client.backend.create_session.assert_not_called() | ||||||||||||||
| spark_client.backend.wait_for_session_ready.assert_not_called() | ||||||||||||||
|
|
||||||||||||||
|
Comment on lines
141
to
173
|
||||||||||||||
| # If we reach here but expected an exception, fail | ||||||||||||||
| assert test_case.expected_status == SUCCESS, ( | ||||||||||||||
| f"Expected exception but none was raised for {test_case.name}" | ||||||||||||||
| ) | ||||||||||||||
| except Exception as e: | ||||||||||||||
| # If we got an exception but expected success, fail | ||||||||||||||
| assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}" | ||||||||||||||
| # Validate the exception type if specified | ||||||||||||||
| if test_case.expected_error: | ||||||||||||||
|
||||||||||||||
| if test_case.expected_error: | |
| if test_case.expected_error: | |
| # First, allow matching by exception type name (e.g., "ValueError") | |
| if type(e).__name__ == test_case.expected_error: | |
| return | |
| # Otherwise, fall back to matching against the exception message |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
expected_erroris set to the string "ValueError", but the assertion checks"ValueError" in str(e);str(ValueError(...))is just the message and typically won’t contain the class name, so this test will fail even when the correct exception is raised (prefer asserting the exception type, e.g.type(e) is ValueError, or usingpytest.raises(ValueError)).