352 changes: 242 additions & 110 deletions kubeflow/spark/api/spark_client_test.py
@@ -20,6 +20,7 @@

from kubeflow.common.types import KubernetesBackendConfig
from kubeflow.spark.api.spark_client import SparkClient
from kubeflow.spark.test.common import FAILED, SUCCESS, TestCase
from kubeflow.spark.types.options import Name
from kubeflow.spark.types.types import SparkConnectInfo, SparkConnectState

@@ -37,16 +38,24 @@ def mock_backend():
backend.list_sessions.return_value = [
SparkConnectInfo(name="s1", namespace="default", state=SparkConnectState.READY),
]
backend.get_session.return_value = SparkConnectInfo(
name="test", namespace="default", state=SparkConnectState.READY
)

# Configure mock to handle both existing and non-existent sessions
def mock_get_session(session_name):
if session_name == "nonexistent":
raise ValueError("Session not found")
return SparkConnectInfo(
name=session_name, namespace="default", state=SparkConnectState.READY
)

backend.get_session.side_effect = mock_get_session
backend.create_session.return_value = SparkConnectInfo(
name="new-session", namespace="default", state=SparkConnectState.PROVISIONING
)
backend.wait_for_session_ready.return_value = ready_info
backend._create_session.return_value = ready_info
backend._wait_for_session_ready.return_value = ready_info
backend.get_connect_url.return_value = ("sc://localhost:15002", None)
backend.get_session_logs.return_value = iter(["log1", "log2"])
return backend


@@ -62,110 +71,233 @@ def spark_client(mock_backend):
yield client


class TestSparkClientInit:
"""Tests for SparkClient initialization."""

def test_default_backend(self):
"""C01: Init with default creates KubernetesBackendConfig."""
with patch("kubeflow.spark.api.spark_client.KubernetesBackend"):
client = SparkClient()
assert client.backend is not None

def test_custom_namespace(self):
"""C02: Init with custom namespace."""
with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock:
SparkClient(backend_config=KubernetesBackendConfig(namespace="spark"))
mock.assert_called_once()

def test_invalid_backend(self):
"""C03: Init with invalid backend raises ValueError."""
with pytest.raises(ValueError):
SparkClient(backend_config="invalid")


class TestSparkClientConnect:
"""Tests for connect method."""

def test_connect_with_url(self, spark_client):
"""C04: Connect with URL returns SparkSession."""
mock_session = Mock()
mock_builder = Mock()
mock_builder.remote.return_value = mock_builder
mock_builder.getOrCreate.return_value = mock_session

mock_spark = Mock()
mock_spark.builder = mock_builder

with (
patch.dict("sys.modules", {"pyspark": Mock(), "pyspark.sql": mock_spark}),
patch("kubeflow.spark.api.spark_client.SparkSession", mock_spark),
):
pass

# Test URL validation works
from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url

assert validate_spark_connect_url("sc://localhost:15002") is True

def test_connect_with_url_invalid(self, spark_client):
"""C04b: Connect with invalid URL raises ValueError."""
from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url

with pytest.raises(ValueError):
validate_spark_connect_url("http://localhost:15002")

def test_connect_create_session(self, spark_client, mock_backend):
"""C06: Connect without URL creates new session - verifies backend calls."""
# Since pyspark is not installed, we verify the backend is called correctly
mock_backend.create_session.assert_not_called()
mock_backend.wait_for_session_ready.assert_not_called()


class TestSparkClientSessionManagement:
"""Tests for session management methods."""

def test_list_sessions(self, spark_client, mock_backend):
"""C14: list_sessions delegates to backend."""
result = spark_client.list_sessions()
mock_backend.list_sessions.assert_called_once()
assert len(result) == 1

def test_get_session(self, spark_client, mock_backend):
"""C15: get_session delegates to backend."""
result = spark_client.get_session("test")
mock_backend.get_session.assert_called_once_with("test")
assert result.name == "test"

def test_delete_session(self, spark_client, mock_backend):
"""C16: delete_session delegates to backend."""
spark_client.delete_session("test")
mock_backend.delete_session.assert_called_once_with("test")

def test_get_session_logs(self, spark_client, mock_backend):
"""C17: get_session_logs delegates to backend."""
mock_backend.get_session_logs.return_value = iter(["log1", "log2"])
list(spark_client.get_session_logs("test"))
mock_backend.get_session_logs.assert_called_once_with("test", follow=False)


class TestSparkClientConnectWithNameOption:
"""Tests for connect method with Name option."""

def test_connect_with_name_option(self, spark_client, mock_backend):
"""C18: Connect passes options to backend including Name option."""
mock_session = Mock()
mock_backend.create_and_connect.return_value = mock_session
options = [Name("custom-session")]
spark_client.connect(options=options)
mock_backend.create_and_connect.assert_called_once()
call_args = mock_backend.create_and_connect.call_args
assert call_args.kwargs["options"] == options

def test_connect_without_options_auto_generates(self, spark_client, mock_backend):
"""C19: Connect without options auto-generates name via backend."""
mock_session = Mock()
mock_backend.create_and_connect.return_value = mock_session
spark_client.connect()
mock_backend.create_and_connect.assert_called_once()
call_args = mock_backend.create_and_connect.call_args
assert call_args.kwargs["options"] is None
@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="default backend initialization",
expected_status=SUCCESS,
config={},
),
TestCase(
name="custom namespace initialization",
expected_status=SUCCESS,
config={"namespace": "spark"},
),
TestCase(
name="invalid backend config",
expected_status=FAILED,
config={"backend_config": "invalid"},
expected_error=ValueError,
),
],
)
def test_spark_client_initialization(test_case: TestCase):
"""Test SparkClient initialization scenarios."""

try:
if "namespace" in test_case.config:
with patch("kubeflow.spark.api.spark_client.KubernetesBackend") as mock:
SparkClient(
backend_config=KubernetesBackendConfig(namespace=test_case.config["namespace"])
)
mock.assert_called_once()
elif "backend_config" in test_case.config:
SparkClient(backend_config=test_case.config["backend_config"])
else:
with patch("kubeflow.spark.api.spark_client.KubernetesBackend"):
client = SparkClient()
assert client.backend is not None

# If we reach here but expected an exception, fail
assert test_case.expected_status == SUCCESS, (
f"Expected exception but none was raised for {test_case.name}"
)
except Exception as e:
# If we got an exception but expected success, fail
assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}"
# Validate the exception type if specified
if test_case.expected_error:
assert isinstance(e, test_case.expected_error), (
f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'"
)
Comment on lines 116 to 123

Copilot AI Feb 12, 2026


expected_error is set to the string "ValueError", but the assertion checks "ValueError" in str(e); str(ValueError(...)) is just the message and typically won’t contain the class name, so this test will fail even when the correct exception is raised (prefer asserting the exception type, e.g. type(e) is ValueError, or using pytest.raises(ValueError)).
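
A minimal sketch of the pytest.raises pattern the reviewer recommends, restated here for clarity. It mirrors the invalid-backend case that the parametrized test above covers, and assumes, as that test does, that SparkClient raises ValueError when backend_config is not a KubernetesBackendConfig:

# Hedged sketch: standalone variant of the invalid-config case using
# pytest.raises instead of the try/except + expected_status bookkeeping.
import pytest

from kubeflow.spark.api.spark_client import SparkClient


def test_invalid_backend_config_raises_value_error():
    # SparkClient is expected to reject a backend_config that is not a
    # KubernetesBackendConfig; "invalid" is just an arbitrary bad value.
    with pytest.raises(ValueError):
        SparkClient(backend_config="invalid")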



@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="connect with valid URL validation",
expected_status=SUCCESS,
config={"url": "sc://localhost:15002"},
expected_output=True,
),
TestCase(
name="connect with invalid URL validation",
expected_status=FAILED,
config={"url": "http://localhost:15002"},
expected_error=ValueError,
),
TestCase(
name="connect create session verification",
expected_status=SUCCESS,
config={"test_connect": True},
),
],
)
def test_spark_client_connect(test_case: TestCase, spark_client):
"""Test SparkClient connect method scenarios."""

try:
if "url" in test_case.config:
from kubeflow.spark.backends.kubernetes.utils import validate_spark_connect_url

result = validate_spark_connect_url(test_case.config["url"])
assert result == test_case.expected_output
elif "test_connect" in test_case.config:
# Actually test the connect method for session creation
with patch("kubeflow.spark.api.spark_client.SparkSession") as mock_spark_session:
mock_session = Mock()
mock_spark_session.builder.remote.return_value.config.return_value.getOrCreate.return_value = mock_session
mock_spark_session.builder.getOrCreate.return_value = mock_session

# Call connect without base_url to trigger create mode
result = spark_client.connect()

# Verify the session was created
assert result is not None
else:
# Verify backend methods are not called initially
spark_client.backend.create_session.assert_not_called()
spark_client.backend.wait_for_session_ready.assert_not_called()

Comment on lines 141 to 173

Copilot AI Feb 12, 2026


The "connect create session verification" case doesn’t call SparkClient.connect() and only asserts some backend methods haven’t been called, so it doesn’t validate any behavior and can’t fail if connect() regresses; either remove this case or make it call connect() and assert the expected backend interaction (e.g., create_and_connect is invoked).

# If we reach here but expected an exception, fail
assert test_case.expected_status == SUCCESS, (
f"Expected exception but none was raised for {test_case.name}"
)
except Exception as e:
# If we got an exception but expected success, fail
assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}"
# Validate the exception type if specified
if test_case.expected_error:

Copilot AI Feb 12, 2026


Same issue here: the invalid-URL case sets expected_error="ValueError" but validates it via substring match against str(e), which won’t include the exception type (use pytest.raises(ValueError, match=...) or compare type(e) to the expected exception).

Suggested change
-        if test_case.expected_error:
+        if test_case.expected_error:
+            # First, allow matching by exception type name (e.g., "ValueError")
+            if type(e).__name__ == test_case.expected_error:
+                return
+            # Otherwise, fall back to matching against the exception message

assert isinstance(e, test_case.expected_error), (
f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'"
)

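Picking up the reviewer's note above about the "connect create session verification" case: a minimal hedged sketch of what that case could assert if it called connect() directly, reusing the spark_client and mock_backend fixtures defined earlier. It assumes, as the Name-option tests below do, that SparkClient.connect() delegates session creation to backend.create_and_connect, and that Mock is imported from unittest.mock at the top of the file.

# Hedged sketch, not part of the PR: exercises connect() in create mode and
# asserts the backend interaction instead of only checking that nothing was
# called. Reuses the spark_client/mock_backend fixtures defined above.
def test_connect_creates_session_via_backend(spark_client, mock_backend):
    mock_backend.create_and_connect.return_value = Mock()

    session = spark_client.connect()  # no URL, so create mode is used

    assert session is not None
    mock_backend.create_and_connect.assert_called_once()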

@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="list sessions successfully", expected_status=SUCCESS, config={}, expected_output=1
),
TestCase(
name="get existing session",
expected_status=SUCCESS,
config={"session_name": "test"},
expected_output="test",
),
TestCase(
name="get non-existent session",
expected_status=FAILED,
config={"session_name": "nonexistent"},
expected_error=ValueError,
),
TestCase(
name="delete session",
expected_status=SUCCESS,
config={"session_name": "test", "operation": "delete"},
),
TestCase(
name="get session logs",
expected_status=SUCCESS,
config={"session_name": "test", "operation": "logs"},
expected_output=2, # Expected number of log entries
),
],
)
def test_spark_client_session_management(test_case: TestCase, spark_client, mock_backend):
"""Test SparkClient session management operations."""

try:
if "operation" in test_case.config:
if test_case.config["operation"] == "delete":
spark_client.delete_session(test_case.config["session_name"])
mock_backend.delete_session.assert_called_once_with(
test_case.config["session_name"]
)
elif test_case.config["operation"] == "logs":
result = list(spark_client.get_session_logs(test_case.config["session_name"]))
assert len(result) == test_case.expected_output
mock_backend.get_session_logs.assert_called_once_with(
test_case.config["session_name"], follow=False
)
elif "session_name" in test_case.config:
result = spark_client.get_session(test_case.config["session_name"])
assert result.name == test_case.expected_output
mock_backend.get_session.assert_called_with(test_case.config["session_name"])
else:
result = spark_client.list_sessions()
assert len(result) == test_case.expected_output
mock_backend.list_sessions.assert_called_once()

# If we reach here but expected an exception, fail
assert test_case.expected_status == SUCCESS, (
f"Expected exception but none was raised for {test_case.name}"
)
except Exception as e:
# If we got an exception but expected success, fail
assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}"
# Validate the exception type if specified
if test_case.expected_error:
assert isinstance(e, test_case.expected_error), (
f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'"
)


@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="connect with name option",
expected_status=SUCCESS,
config={"options": [Name("custom-session")]},
),
TestCase(
name="connect without options auto-generates",
expected_status=SUCCESS,
config={},
),
],
)
def test_spark_client_connect_with_options(test_case: TestCase, spark_client, mock_backend):
"""Test SparkClient connect method with Name option scenarios."""

mock_session = Mock()
mock_backend.create_and_connect.return_value = mock_session

try:
if "options" in test_case.config:
options = test_case.config["options"]
spark_client.connect(options=options)
mock_backend.create_and_connect.assert_called_once()
call_args = mock_backend.create_and_connect.call_args
assert call_args.kwargs["options"] == options
else:
spark_client.connect()
mock_backend.create_and_connect.assert_called_once()
call_args = mock_backend.create_and_connect.call_args
assert call_args.kwargs["options"] is None

# If we reach here but expected an exception, fail
assert test_case.expected_status == SUCCESS, (
f"Expected exception but none was raised for {test_case.name}"
)
except Exception as e:
# If we got an exception but expected success, fail
assert test_case.expected_status == FAILED, f"Unexpected exception in {test_case.name}: {e}"
# Validate the exception type if specified
if test_case.expected_error:
assert isinstance(e, test_case.expected_error), (
f"Expected exception type '{test_case.expected_error.__name__}' but got '{type(e).__name__}: {str(e)}'"
)