diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 4ad3fef3..e7562934 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.18.0"
+ ".": "0.19.0"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index 9e4dbe62..125a84b7 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 65
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fkernel-015c11efc34c81d4d82a937c017f5eb789ea3ca21a05b70e2ed31c069b839293.yml
-openapi_spec_hash: 3dcab2044da305f376cecf4eca38caee
-config_hash: 0fbdda3a736cc2748ca33371871e61b3
+configured_endpoints: 66
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/kernel%2Fkernel-86854c41729a6b26f71e26c906f665f69939f23e2d7adcc43380aee64cf6d056.yml
+openapi_spec_hash: 270a40c8af29e83cbda77d3700fd456a
+config_hash: 9421eb86b7f3f4b274f123279da3858e
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81dc4886..d834d1b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,27 @@
# Changelog
+## 0.19.0 (2025-11-12)
+
+Full Changelog: [v0.18.0...v0.19.0](https://github.com/onkernel/kernel-python-sdk/compare/v0.18.0...v0.19.0)
+
+### Features
+
+* feat hide cursor v2 ([5e694ae](https://github.com/onkernel/kernel-python-sdk/commit/5e694aedf67d142bcf2e4e4a018846236c45af69))
+* Remove price gating on computer endpoints ([cbbdaee](https://github.com/onkernel/kernel-python-sdk/commit/cbbdaee2225d593a7abbe04e3bd768476d06340c))
+
+
+### Bug Fixes
+
+* compat with Python 3.14 ([297d7b3](https://github.com/onkernel/kernel-python-sdk/commit/297d7b3069073d8ea98e67abf123505b1b20b140))
+* **compat:** update signatures of `model_dump` and `model_dump_json` for Pydantic v1 ([0e796f3](https://github.com/onkernel/kernel-python-sdk/commit/0e796f368e0eeb3260b3558e89fd26ec32b9567f))
+
+
+### Chores
+
+* **internal/tests:** avoid race condition with implicit client cleanup ([4d011fd](https://github.com/onkernel/kernel-python-sdk/commit/4d011fd41e10e5aaa49bb2447e27b71c8240e205))
+* **internal:** grammar fix (it's -> its) ([fcdf068](https://github.com/onkernel/kernel-python-sdk/commit/fcdf06877b370bf8c2abe047c5dcff7597dfadd9))
+* **package:** drop Python 3.8 support ([5dcea22](https://github.com/onkernel/kernel-python-sdk/commit/5dcea220253de47542381b91401965198c3d3dd2))
+
## 0.18.0 (2025-10-30)
Full Changelog: [v0.17.0...v0.18.0](https://github.com/onkernel/kernel-python-sdk/compare/v0.17.0...v0.18.0)
diff --git a/README.md b/README.md
index ae0066ec..699f3c26 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[)](https://pypi.org/project/kernel/)
-The Kernel Python library provides convenient access to the Kernel REST API from any Python 3.8+
+The Kernel Python library provides convenient access to the Kernel REST API from any Python 3.9+
application. The library includes type definitions for all request params and response fields,
and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx).
@@ -487,7 +487,7 @@ print(kernel.__version__)
## Requirements
-Python 3.8 or higher.
+Python 3.9 or higher.
## Contributing
diff --git a/api.md b/api.md
index aa2ef0e8..8a015c3c 100644
--- a/api.md
+++ b/api.md
@@ -168,6 +168,12 @@ Methods:
## Computer
+Types:
+
+```python
+from kernel.types.browsers import ComputerSetCursorVisibilityResponse
+```
+
Methods:
- client.browsers.computer.capture_screenshot(id, \*\*params) -> BinaryAPIResponse
@@ -176,6 +182,7 @@ Methods:
- client.browsers.computer.move_mouse(id, \*\*params) -> None
- client.browsers.computer.press_key(id, \*\*params) -> None
- client.browsers.computer.scroll(id, \*\*params) -> None
+- client.browsers.computer.set_cursor_visibility(id, \*\*params) -> ComputerSetCursorVisibilityResponse
- client.browsers.computer.type_text(id, \*\*params) -> None
## Playwright
diff --git a/pyproject.toml b/pyproject.toml
index 737af0ac..518a22eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kernel"
-version = "0.18.0"
+version = "0.19.0"
description = "The official Python library for the kernel API"
dynamic = ["readme"]
license = "Apache-2.0"
@@ -15,11 +15,10 @@ dependencies = [
"distro>=1.7.0, <2",
"sniffio",
]
-requires-python = ">= 3.8"
+requires-python = ">= 3.9"
classifiers = [
"Typing :: Typed",
"Intended Audience :: Developers",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -141,7 +140,7 @@ filterwarnings = [
# there are a couple of flags that are still disabled by
# default in strict mode as they are experimental and niche.
typeCheckingMode = "strict"
-pythonVersion = "3.8"
+pythonVersion = "3.9"
exclude = [
"_dev",
diff --git a/src/kernel/_models.py b/src/kernel/_models.py
index 6a3cd1d2..ca9500b2 100644
--- a/src/kernel/_models.py
+++ b/src/kernel/_models.py
@@ -2,6 +2,7 @@
import os
import inspect
+import weakref
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
@@ -256,15 +257,16 @@ def model_dump(
mode: Literal["json", "python"] | str = "python",
include: IncEx | None = None,
exclude: IncEx | None = None,
+ context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
+ exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
- context: dict[str, Any] | None = None,
- serialize_as_any: bool = False,
fallback: Callable[[Any], Any] | None = None,
+ serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
@@ -272,16 +274,24 @@ def model_dump(
Args:
mode: The mode in which `to_python` should run.
- If mode is 'json', the dictionary will only contain JSON serializable types.
- If mode is 'python', the dictionary may contain any Python objects.
- include: A list of fields to include in the output.
- exclude: A list of fields to exclude from the output.
+ If mode is 'json', the output will only contain JSON serializable types.
+ If mode is 'python', the output may contain non-JSON-serializable Python objects.
+ include: A set of fields to include in the output.
+ exclude: A set of fields to exclude from the output.
+ context: Additional context to pass to the serializer.
by_alias: Whether to use the field's alias in the dictionary key if defined.
- exclude_unset: Whether to exclude fields that are unset or None from the output.
- exclude_defaults: Whether to exclude fields that are set to their default value from the output.
- exclude_none: Whether to exclude fields that have a value of `None` from the output.
- round_trip: Whether to enable serialization and deserialization round-trip support.
- warnings: Whether to log warnings when invalid fields are encountered.
+ exclude_unset: Whether to exclude fields that have not been explicitly set.
+ exclude_defaults: Whether to exclude fields that are set to their default value.
+ exclude_none: Whether to exclude fields that have a value of `None`.
+ exclude_computed_fields: Whether to exclude computed fields.
+ While this can be useful for round-tripping, it is usually recommended to use the dedicated
+ `round_trip` parameter instead.
+ round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
+ warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
+ "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
+ fallback: A function to call when an unknown value is encountered. If not provided,
+ a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
+ serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
Returns:
A dictionary representation of the model.
@@ -298,6 +308,8 @@ def model_dump(
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
+ if exclude_computed_fields != False:
+ raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
@@ -314,15 +326,17 @@ def model_dump_json(
self,
*,
indent: int | None = None,
+ ensure_ascii: bool = False,
include: IncEx | None = None,
exclude: IncEx | None = None,
+ context: Any | None = None,
by_alias: bool | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
+ exclude_computed_fields: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
- context: dict[str, Any] | None = None,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
) -> str:
@@ -354,6 +368,10 @@ def model_dump_json(
raise ValueError("serialize_as_any is only supported in Pydantic v2")
if fallback is not None:
raise ValueError("fallback is only supported in Pydantic v2")
+ if ensure_ascii != False:
+ raise ValueError("ensure_ascii is only supported in Pydantic v2")
+ if exclude_computed_fields != False:
+ raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
return super().json( # type: ignore[reportDeprecated]
indent=indent,
include=include,
@@ -573,6 +591,9 @@ class CachedDiscriminatorType(Protocol):
__discriminator__: DiscriminatorDetails
+DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
+
+
class DiscriminatorDetails:
field_name: str
"""The name of the discriminator field in the variant class, e.g.
@@ -615,8 +636,9 @@ def __init__(
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
- if isinstance(union, CachedDiscriminatorType):
- return union.__discriminator__
+ cached = DISCRIMINATOR_CACHE.get(union)
+ if cached is not None:
+ return cached
discriminator_field_name: str | None = None
@@ -669,7 +691,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
discriminator_field=discriminator_field_name,
discriminator_alias=discriminator_alias,
)
- cast(CachedDiscriminatorType, union).__discriminator__ = details
+ DISCRIMINATOR_CACHE.setdefault(union, details)
return details
diff --git a/src/kernel/_utils/_sync.py b/src/kernel/_utils/_sync.py
index ad7ec71b..f6027c18 100644
--- a/src/kernel/_utils/_sync.py
+++ b/src/kernel/_utils/_sync.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-import sys
import asyncio
import functools
-import contextvars
-from typing import Any, TypeVar, Callable, Awaitable
+from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec
import anyio
@@ -15,34 +13,11 @@
T_ParamSpec = ParamSpec("T_ParamSpec")
-if sys.version_info >= (3, 9):
- _asyncio_to_thread = asyncio.to_thread
-else:
- # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
- # for Python 3.8 support
- async def _asyncio_to_thread(
- func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
- ) -> Any:
- """Asynchronously run function *func* in a separate thread.
-
- Any *args and **kwargs supplied for this function are directly passed
- to *func*. Also, the current :class:`contextvars.Context` is propagated,
- allowing context variables from the main thread to be accessed in the
- separate thread.
-
- Returns a coroutine that can be awaited to get the eventual result of *func*.
- """
- loop = asyncio.events.get_running_loop()
- ctx = contextvars.copy_context()
- func_call = functools.partial(ctx.run, func, *args, **kwargs)
- return await loop.run_in_executor(None, func_call)
-
-
async def to_thread(
func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
) -> T_Retval:
if sniffio.current_async_library() == "asyncio":
- return await _asyncio_to_thread(func, *args, **kwargs)
+ return await asyncio.to_thread(func, *args, **kwargs)
return await anyio.to_thread.run_sync(
functools.partial(func, *args, **kwargs),
@@ -53,10 +28,7 @@ async def to_thread(
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
- positional and keyword arguments. For python version 3.9 and above, it uses
- asyncio.to_thread to run the function in a separate thread. For python version
- 3.8, it uses locally defined copy of the asyncio.to_thread function which was
- introduced in python 3.9.
+ positional and keyword arguments.
Usage:
diff --git a/src/kernel/_utils/_utils.py b/src/kernel/_utils/_utils.py
index 50d59269..eec7f4a1 100644
--- a/src/kernel/_utils/_utils.py
+++ b/src/kernel/_utils/_utils.py
@@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
-# care about the contained types we can safely use `object` in it's place.
+# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
diff --git a/src/kernel/_version.py b/src/kernel/_version.py
index abb53bdb..dea2694c 100644
--- a/src/kernel/_version.py
+++ b/src/kernel/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "kernel"
-__version__ = "0.18.0" # x-release-please-version
+__version__ = "0.19.0" # x-release-please-version
diff --git a/src/kernel/resources/apps.py b/src/kernel/resources/apps.py
index 34aa1299..b803299d 100644
--- a/src/kernel/resources/apps.py
+++ b/src/kernel/resources/apps.py
@@ -63,9 +63,9 @@ def list(
Args:
app_name: Filter results by application name.
- limit: Limit the number of app to return.
+ limit: Limit the number of apps to return.
- offset: Offset the number of app to return.
+ offset: Offset the number of apps to return.
version: Filter results by version label.
@@ -140,9 +140,9 @@ def list(
Args:
app_name: Filter results by application name.
- limit: Limit the number of app to return.
+ limit: Limit the number of apps to return.
- offset: Offset the number of app to return.
+ offset: Offset the number of apps to return.
version: Filter results by version label.
diff --git a/src/kernel/resources/browsers/computer.py b/src/kernel/resources/browsers/computer.py
index 68cee420..87d377fd 100644
--- a/src/kernel/resources/browsers/computer.py
+++ b/src/kernel/resources/browsers/computer.py
@@ -34,7 +34,9 @@
computer_move_mouse_params,
computer_click_mouse_params,
computer_capture_screenshot_params,
+ computer_set_cursor_visibility_params,
)
+from ...types.browsers.computer_set_cursor_visibility_response import ComputerSetCursorVisibilityResponse
__all__ = ["ComputerResource", "AsyncComputerResource"]
@@ -390,6 +392,45 @@ def scroll(
cast_to=NoneType,
)
+ def set_cursor_visibility(
+ self,
+ id: str,
+ *,
+ hidden: bool,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> ComputerSetCursorVisibilityResponse:
+ """
+ Set cursor visibility
+
+ Args:
+ hidden: Whether the cursor should be hidden or visible
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not id:
+ raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
+ return self._post(
+ f"/browsers/{id}/computer/cursor",
+ body=maybe_transform(
+ {"hidden": hidden}, computer_set_cursor_visibility_params.ComputerSetCursorVisibilityParams
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ComputerSetCursorVisibilityResponse,
+ )
+
def type_text(
self,
id: str,
@@ -789,6 +830,45 @@ async def scroll(
cast_to=NoneType,
)
+ async def set_cursor_visibility(
+ self,
+ id: str,
+ *,
+ hidden: bool,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
+ ) -> ComputerSetCursorVisibilityResponse:
+ """
+ Set cursor visibility
+
+ Args:
+ hidden: Whether the cursor should be hidden or visible
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not id:
+ raise ValueError(f"Expected a non-empty value for `id` but received {id!r}")
+ return await self._post(
+ f"/browsers/{id}/computer/cursor",
+ body=await async_maybe_transform(
+ {"hidden": hidden}, computer_set_cursor_visibility_params.ComputerSetCursorVisibilityParams
+ ),
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ComputerSetCursorVisibilityResponse,
+ )
+
async def type_text(
self,
id: str,
@@ -860,6 +940,9 @@ def __init__(self, computer: ComputerResource) -> None:
self.scroll = to_raw_response_wrapper(
computer.scroll,
)
+ self.set_cursor_visibility = to_raw_response_wrapper(
+ computer.set_cursor_visibility,
+ )
self.type_text = to_raw_response_wrapper(
computer.type_text,
)
@@ -888,6 +971,9 @@ def __init__(self, computer: AsyncComputerResource) -> None:
self.scroll = async_to_raw_response_wrapper(
computer.scroll,
)
+ self.set_cursor_visibility = async_to_raw_response_wrapper(
+ computer.set_cursor_visibility,
+ )
self.type_text = async_to_raw_response_wrapper(
computer.type_text,
)
@@ -916,6 +1002,9 @@ def __init__(self, computer: ComputerResource) -> None:
self.scroll = to_streamed_response_wrapper(
computer.scroll,
)
+ self.set_cursor_visibility = to_streamed_response_wrapper(
+ computer.set_cursor_visibility,
+ )
self.type_text = to_streamed_response_wrapper(
computer.type_text,
)
@@ -944,6 +1033,9 @@ def __init__(self, computer: AsyncComputerResource) -> None:
self.scroll = async_to_streamed_response_wrapper(
computer.scroll,
)
+ self.set_cursor_visibility = async_to_streamed_response_wrapper(
+ computer.set_cursor_visibility,
+ )
self.type_text = async_to_streamed_response_wrapper(
computer.type_text,
)
diff --git a/src/kernel/types/app_list_params.py b/src/kernel/types/app_list_params.py
index 1e2f5278..296ded55 100644
--- a/src/kernel/types/app_list_params.py
+++ b/src/kernel/types/app_list_params.py
@@ -12,10 +12,10 @@ class AppListParams(TypedDict, total=False):
"""Filter results by application name."""
limit: int
- """Limit the number of app to return."""
+ """Limit the number of apps to return."""
offset: int
- """Offset the number of app to return."""
+ """Offset the number of apps to return."""
version: str
"""Filter results by version label."""
diff --git a/src/kernel/types/browsers/__init__.py b/src/kernel/types/browsers/__init__.py
index f8a263c1..546fdc64 100644
--- a/src/kernel/types/browsers/__init__.py
+++ b/src/kernel/types/browsers/__init__.py
@@ -39,3 +39,9 @@
from .f_set_file_permissions_params import FSetFilePermissionsParams as FSetFilePermissionsParams
from .process_stdout_stream_response import ProcessStdoutStreamResponse as ProcessStdoutStreamResponse
from .computer_capture_screenshot_params import ComputerCaptureScreenshotParams as ComputerCaptureScreenshotParams
+from .computer_set_cursor_visibility_params import (
+ ComputerSetCursorVisibilityParams as ComputerSetCursorVisibilityParams,
+)
+from .computer_set_cursor_visibility_response import (
+ ComputerSetCursorVisibilityResponse as ComputerSetCursorVisibilityResponse,
+)
diff --git a/src/kernel/types/browsers/computer_set_cursor_visibility_params.py b/src/kernel/types/browsers/computer_set_cursor_visibility_params.py
new file mode 100644
index 00000000..f003ee91
--- /dev/null
+++ b/src/kernel/types/browsers/computer_set_cursor_visibility_params.py
@@ -0,0 +1,12 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, TypedDict
+
+__all__ = ["ComputerSetCursorVisibilityParams"]
+
+
+class ComputerSetCursorVisibilityParams(TypedDict, total=False):
+ hidden: Required[bool]
+ """Whether the cursor should be hidden or visible"""
diff --git a/src/kernel/types/browsers/computer_set_cursor_visibility_response.py b/src/kernel/types/browsers/computer_set_cursor_visibility_response.py
new file mode 100644
index 00000000..c82302ef
--- /dev/null
+++ b/src/kernel/types/browsers/computer_set_cursor_visibility_response.py
@@ -0,0 +1,10 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from ..._models import BaseModel
+
+__all__ = ["ComputerSetCursorVisibilityResponse"]
+
+
+class ComputerSetCursorVisibilityResponse(BaseModel):
+ ok: bool
+ """Indicates success."""
diff --git a/tests/api_resources/browsers/test_computer.py b/tests/api_resources/browsers/test_computer.py
index 9e245481..7634b89b 100644
--- a/tests/api_resources/browsers/test_computer.py
+++ b/tests/api_resources/browsers/test_computer.py
@@ -10,12 +10,16 @@
from respx import MockRouter
from kernel import Kernel, AsyncKernel
+from tests.utils import assert_matches_type
from kernel._response import (
BinaryAPIResponse,
AsyncBinaryAPIResponse,
StreamedBinaryAPIResponse,
AsyncStreamedBinaryAPIResponse,
)
+from kernel.types.browsers import (
+ ComputerSetCursorVisibilityResponse,
+)
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -396,6 +400,52 @@ def test_path_params_scroll(self, client: Kernel) -> None:
y=0,
)
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ def test_method_set_cursor_visibility(self, client: Kernel) -> None:
+ computer = client.browsers.computer.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ )
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ def test_raw_response_set_cursor_visibility(self, client: Kernel) -> None:
+ response = client.browsers.computer.with_raw_response.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ computer = response.parse()
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ def test_streaming_response_set_cursor_visibility(self, client: Kernel) -> None:
+ with client.browsers.computer.with_streaming_response.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ computer = response.parse()
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ def test_path_params_set_cursor_visibility(self, client: Kernel) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
+ client.browsers.computer.with_raw_response.set_cursor_visibility(
+ id="",
+ hidden=True,
+ )
+
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_type_text(self, client: Kernel) -> None:
@@ -835,6 +885,52 @@ async def test_path_params_scroll(self, async_client: AsyncKernel) -> None:
y=0,
)
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ async def test_method_set_cursor_visibility(self, async_client: AsyncKernel) -> None:
+ computer = await async_client.browsers.computer.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ )
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ async def test_raw_response_set_cursor_visibility(self, async_client: AsyncKernel) -> None:
+ response = await async_client.browsers.computer.with_raw_response.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ computer = await response.parse()
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ async def test_streaming_response_set_cursor_visibility(self, async_client: AsyncKernel) -> None:
+ async with async_client.browsers.computer.with_streaming_response.set_cursor_visibility(
+ id="id",
+ hidden=True,
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ computer = await response.parse()
+ assert_matches_type(ComputerSetCursorVisibilityResponse, computer, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip(reason="Prism tests are disabled")
+ @parametrize
+ async def test_path_params_set_cursor_visibility(self, async_client: AsyncKernel) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"):
+ await async_client.browsers.computer.with_raw_response.set_cursor_visibility(
+ id="",
+ hidden=True,
+ )
+
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_type_text(self, async_client: AsyncKernel) -> None:
diff --git a/tests/test_client.py b/tests/test_client.py
index 74ffb05e..6425811f 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -59,51 +59,49 @@ def _get_open_connections(client: Kernel | AsyncKernel) -> int:
class TestKernel:
- client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- def test_raw_response(self, respx_mock: MockRouter) -> None:
+ def test_raw_response(self, respx_mock: MockRouter, client: Kernel) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Kernel) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, client: Kernel) -> None:
+ copied = client.copy()
+ assert id(copied) != id(client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, client: Kernel) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(client.timeout, httpx.Timeout)
+ copied = client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(client.timeout, httpx.Timeout)
def test_copy_default_headers(self) -> None:
client = Kernel(
@@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ client.close()
def test_copy_default_query(self) -> None:
client = Kernel(
@@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ client.close()
+
+ def test_copy_signature(self, client: Kernel) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -192,12 +193,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, client: Kernel) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ def test_request_timeout(self, client: Kernel) -> None:
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
- FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
- )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(100.0)
@@ -272,6 +271,8 @@ def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ client.close()
+
def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
with httpx.Client(timeout=None) as http_client:
@@ -283,6 +284,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ client.close()
+
# no timeout given to the httpx client should not use the httpx default
with httpx.Client() as http_client:
client = Kernel(
@@ -293,6 +296,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ client.close()
+
# explicitly passing the default timeout currently results in it being ignored
with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = Kernel(
@@ -303,6 +308,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ client.close()
+
async def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
async with httpx.AsyncClient() as http_client:
@@ -314,14 +321,14 @@ async def test_invalid_http_client(self) -> None:
)
def test_default_headers_option(self) -> None:
- client = Kernel(
+ test_client = Kernel(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = Kernel(
+ test_client2 = Kernel(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -330,10 +337,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ test_client.close()
+ test_client2.close()
+
def test_validate_headers(self) -> None:
client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -362,8 +372,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ client.close()
+
+ def test_request_extra_json(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -374,7 +386,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -385,7 +397,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -396,8 +408,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -407,7 +419,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -418,8 +430,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -432,7 +444,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -446,7 +458,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -489,7 +501,7 @@ def test_multipart_repeating_array(self, client: Kernel) -> None:
]
@pytest.mark.respx(base_url=base_url)
- def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ def test_basic_union_response(self, respx_mock: MockRouter, client: Kernel) -> None:
class Model1(BaseModel):
name: str
@@ -498,12 +510,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ def test_union_response_different_types(self, respx_mock: MockRouter, client: Kernel) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -514,18 +526,18 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Kernel) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -541,7 +553,7 @@ class Model(BaseModel):
)
)
- response = self.client.get("/foo", cast_to=Model)
+ response = client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
@@ -553,6 +565,8 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
+ client.close()
+
def test_base_url_env(self) -> None:
with update_env(KERNEL_BASE_URL="http://localhost:5000/from/env"):
client = Kernel(api_key=api_key, _strict_response_validation=True)
@@ -566,6 +580,8 @@ def test_base_url_env(self) -> None:
client = Kernel(base_url=None, api_key=api_key, _strict_response_validation=True, environment="production")
assert str(client.base_url).startswith("https://api.onkernel.com/")
+ client.close()
+
@pytest.mark.parametrize(
"client",
[
@@ -588,6 +604,7 @@ def test_base_url_trailing_slash(self, client: Kernel) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -611,6 +628,7 @@ def test_base_url_no_trailing_slash(self, client: Kernel) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -634,35 +652,36 @@ def test_absolute_request_url(self, client: Kernel) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ client.close()
def test_copied_client_does_not_close_http(self) -> None:
- client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
- assert not client.is_closed()
+ assert not test_client.is_closed()
def test_client_context_manager(self) -> None:
- client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- with client as c2:
- assert c2 is client
+ test_client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ def test_client_response_validation_error(self, respx_mock: MockRouter, client: Kernel) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- self.client.get("/foo", cast_to=Model)
+ client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -682,11 +701,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
strict_client.get("/foo", cast_to=Model)
- client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = client.get("/foo", cast_to=Model)
+ response = non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ strict_client.close()
+ non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -709,9 +731,9 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = Kernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, client: Kernel
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
@@ -726,7 +748,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien
with pytest.raises(APITimeoutError):
client.browsers.with_streaming_response.create().__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@pytest.mark.skip() # SDK-2615
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -736,7 +758,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client
with pytest.raises(APIStatusError):
client.browsers.with_streaming_response.create().__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -838,83 +860,77 @@ def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects(self, respx_mock: MockRouter, client: Kernel) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Kernel) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- self.client.post(
- "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
- )
+ client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
assert exc_info.value.response.status_code == 302
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
class TestAsyncKernel:
- client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, async_client: AsyncKernel) -> None:
+ copied = async_client.copy()
+ assert id(copied) != id(async_client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = async_client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert async_client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, async_client: AsyncKernel) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = async_client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert async_client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(async_client.timeout, httpx.Timeout)
+ copied = async_client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(async_client.timeout, httpx.Timeout)
- def test_copy_default_headers(self) -> None:
+ async def test_copy_default_headers(self) -> None:
client = AsyncKernel(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
@@ -947,8 +963,9 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ await client.close()
- def test_copy_default_query(self) -> None:
+ async def test_copy_default_query(self) -> None:
client = AsyncKernel(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
)
@@ -984,13 +1001,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ await client.close()
+
+ def test_copy_signature(self, async_client: AsyncKernel) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ async_client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(async_client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -1001,12 +1020,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, async_client: AsyncKernel) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = async_client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -1063,12 +1082,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- async def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ async def test_request_timeout(self, async_client: AsyncKernel) -> None:
+ request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
+ request = async_client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -1083,6 +1102,8 @@ async def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ await client.close()
+
async def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
async with httpx.AsyncClient(timeout=None) as http_client:
@@ -1094,6 +1115,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ await client.close()
+
# no timeout given to the httpx client should not use the httpx default
async with httpx.AsyncClient() as http_client:
client = AsyncKernel(
@@ -1104,6 +1127,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ await client.close()
+
# explicitly passing the default timeout currently results in it being ignored
async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = AsyncKernel(
@@ -1114,6 +1139,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ await client.close()
+
def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
with httpx.Client() as http_client:
@@ -1124,15 +1151,15 @@ def test_invalid_http_client(self) -> None:
http_client=cast(Any, http_client),
)
- def test_default_headers_option(self) -> None:
- client = AsyncKernel(
+ async def test_default_headers_option(self) -> None:
+ test_client = AsyncKernel(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = AsyncKernel(
+ test_client2 = AsyncKernel(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -1141,10 +1168,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ await test_client.close()
+ await test_client2.close()
+
def test_validate_headers(self) -> None:
client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -1155,7 +1185,7 @@ def test_validate_headers(self) -> None:
client2 = AsyncKernel(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
- def test_default_query_option(self) -> None:
+ async def test_default_query_option(self) -> None:
client = AsyncKernel(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
)
@@ -1173,8 +1203,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ await client.close()
+
+ def test_request_extra_json(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1185,7 +1217,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1196,7 +1228,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1207,8 +1239,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1218,7 +1250,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1229,8 +1261,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Kernel) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1243,7 +1275,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1257,7 +1289,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1300,7 +1332,7 @@ def test_multipart_repeating_array(self, async_client: AsyncKernel) -> None:
]
@pytest.mark.respx(base_url=base_url)
- async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
class Model1(BaseModel):
name: str
@@ -1309,12 +1341,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -1325,18 +1357,20 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ async def test_non_application_json_content_type_for_json_data(
+ self, respx_mock: MockRouter, async_client: AsyncKernel
+ ) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -1352,11 +1386,11 @@ class Model(BaseModel):
)
)
- response = await self.client.get("/foo", cast_to=Model)
+ response = await async_client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
- def test_base_url_setter(self) -> None:
+ async def test_base_url_setter(self) -> None:
client = AsyncKernel(
base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
)
@@ -1366,7 +1400,9 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
- def test_base_url_env(self) -> None:
+ await client.close()
+
+ async def test_base_url_env(self) -> None:
with update_env(KERNEL_BASE_URL="http://localhost:5000/from/env"):
client = AsyncKernel(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"
@@ -1381,6 +1417,8 @@ def test_base_url_env(self) -> None:
)
assert str(client.base_url).startswith("https://api.onkernel.com/")
+ await client.close()
+
@pytest.mark.parametrize(
"client",
[
@@ -1396,7 +1434,7 @@ def test_base_url_env(self) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_trailing_slash(self, client: AsyncKernel) -> None:
+ async def test_base_url_trailing_slash(self, client: AsyncKernel) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1405,6 +1443,7 @@ def test_base_url_trailing_slash(self, client: AsyncKernel) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1421,7 +1460,7 @@ def test_base_url_trailing_slash(self, client: AsyncKernel) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_no_trailing_slash(self, client: AsyncKernel) -> None:
+ async def test_base_url_no_trailing_slash(self, client: AsyncKernel) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1430,6 +1469,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncKernel) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1446,7 +1486,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncKernel) -> None:
],
ids=["standard", "custom http client"],
)
- def test_absolute_request_url(self, client: AsyncKernel) -> None:
+ async def test_absolute_request_url(self, client: AsyncKernel) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1455,37 +1495,37 @@ def test_absolute_request_url(self, client: AsyncKernel) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ await client.close()
async def test_copied_client_does_not_close_http(self) -> None:
- client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
await asyncio.sleep(0.2)
- assert not client.is_closed()
+ assert not test_client.is_closed()
async def test_client_context_manager(self) -> None:
- client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- async with client as c2:
- assert c2 is client
+ test_client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ async with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- await self.client.get("/foo", cast_to=Model)
+ await async_client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -1496,7 +1536,6 @@ async def test_client_max_retries_validation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str
@@ -1508,11 +1547,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
await strict_client.get("/foo", cast_to=Model)
- client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = await client.get("/foo", cast_to=Model)
+ response = await non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ await strict_client.close()
+ await non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -1535,13 +1577,12 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- @pytest.mark.asyncio
- async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = AsyncKernel(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ async def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncKernel
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
- calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
@pytest.mark.skip() # SDK-2615
@@ -1553,7 +1594,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter,
with pytest.raises(APITimeoutError):
await async_client.browsers.with_streaming_response.create().__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@pytest.mark.skip() # SDK-2615
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1563,12 +1604,11 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter,
with pytest.raises(APIStatusError):
await async_client.browsers.with_streaming_response.create().__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
self,
@@ -1600,7 +1640,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_omit_retry_count_header(
self, async_client: AsyncKernel, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1624,7 +1663,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("kernel._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_overwrite_retry_count_header(
self, async_client: AsyncKernel, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1672,26 +1710,26 @@ async def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncKernel) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- await self.client.post(
+ await async_client.post(
"/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
)
diff --git a/tests/test_models.py b/tests/test_models.py
index ff5955d4..78f0fd32 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -9,7 +9,7 @@
from kernel._utils import PropertyInfo
from kernel._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
-from kernel._models import BaseModel, construct_type
+from kernel._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
class BasicModel(BaseModel):
@@ -809,7 +809,7 @@ class B(BaseModel):
UnionType = cast(Any, Union[A, B])
- assert not hasattr(UnionType, "__discriminator__")
+ assert not DISCRIMINATOR_CACHE.get(UnionType)
m = construct_type(
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -818,7 +818,7 @@ class B(BaseModel):
assert m.type == "b"
assert m.data == "foo" # type: ignore[comparison-overlap]
- discriminator = UnionType.__discriminator__
+ discriminator = DISCRIMINATOR_CACHE.get(UnionType)
assert discriminator is not None
m = construct_type(
@@ -830,7 +830,7 @@ class B(BaseModel):
# if the discriminator details object stays the same between invocations then
# we hit the cache
- assert UnionType.__discriminator__ is discriminator
+ assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")