From 9fb521e4b64faa57d638cc71e0affbd09cb44e92 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Thu, 8 Jan 2026 14:20:59 -0800 Subject: [PATCH 1/9] Add context utility for request-time API keys --- datacommons_client/utils/context.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 datacommons_client/utils/context.py diff --git a/datacommons_client/utils/context.py b/datacommons_client/utils/context.py new file mode 100644 index 00000000..fcf1ec6a --- /dev/null +++ b/datacommons_client/utils/context.py @@ -0,0 +1,32 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextvars import ContextVar +from contextlib import contextmanager +from typing import Optional, Generator + +_API_KEY_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar("api_key", default=None) + +@contextmanager +def use_api_key(api_key: str) -> Generator[None, None, None]: + """Context manager to set the API key for the current execution context. + + Args: + api_key: The API key to use. + """ + token = _API_KEY_CONTEXT_VAR.set(api_key) + try: + yield + finally: + _API_KEY_CONTEXT_VAR.reset(token) From 89b78b964d47c30790d939fe4265a4525adc2b0e Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Thu, 8 Jan 2026 14:26:20 -0800 Subject: [PATCH 2/9] Add tests for context API key utility --- datacommons_client/tests/test_context.py | 36 ++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 datacommons_client/tests/test_context.py diff --git a/datacommons_client/tests/test_context.py b/datacommons_client/tests/test_context.py new file mode 100644 index 00000000..8bf8baa0 --- /dev/null +++ b/datacommons_client/tests/test_context.py @@ -0,0 +1,36 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datacommons_client.utils.context import use_api_key +from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR + + +class TestContext(unittest.TestCase): + + def test_use_api_key_sets_var(self): + """Test that use_api_key sets the context variable.""" + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key("test-key"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key") + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + + def test_use_api_key_nested(self): + """Test nested usage of use_api_key.""" + with use_api_key("outer"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") + with use_api_key("inner"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner") + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) From a7ff930db37f76396a056439c61737a98ab0ed49 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 11:01:43 -0800 Subject: [PATCH 3/9] Use context API key in base API endpoint --- datacommons_client/endpoints/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datacommons_client/endpoints/base.py b/datacommons_client/endpoints/base.py index 1452ea87..f5076937 100644 --- a/datacommons_client/endpoints/base.py +++ b/datacommons_client/endpoints/base.py @@ -4,6 +4,7 @@ from datacommons_client.utils.request_handling import check_instance_is_valid from datacommons_client.utils.request_handling import post_request from datacommons_client.utils.request_handling import resolve_instance_url +from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR class API: @@ -94,9 +95,15 @@ def post(self, url = (self.base_url if endpoint is None else f"{self.base_url}/{endpoint}") + headers = self.headers + ctx_api_key = _API_KEY_CONTEXT_VAR.get() + if ctx_api_key: + headers = self.headers.copy() + headers["X-API-Key"] = ctx_api_key + return post_request(url=url, payload=payload, - headers=self.headers, + headers=headers, all_pages=all_pages, next_token=next_token) From 7fc300ceb63cd9d4d41bc29a2f363cc5650f25d4 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 11:28:51 -0800 Subject: [PATCH 4/9] Support optional API keys --- datacommons_client/tests/test_context.py | 9 +++++++ datacommons_client/utils/context.py | 30 ++++++++++++++++++++---- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/datacommons_client/tests/test_context.py b/datacommons_client/tests/test_context.py index 8bf8baa0..4a4528ce 100644 --- a/datacommons_client/tests/test_context.py +++ b/datacommons_client/tests/test_context.py @@ -34,3 +34,12 @@ def test_use_api_key_nested(self): self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner") self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + + def test_use_api_key_none(self): + """Test that use_api_key with None/empty does not set the variable.""" + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key(None): + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key(""): + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + diff --git a/datacommons_client/utils/context.py b/datacommons_client/utils/context.py index fcf1ec6a..587029d8 100644 --- a/datacommons_client/utils/context.py +++ b/datacommons_client/utils/context.py @@ -14,17 +14,39 @@ from contextvars import ContextVar from contextlib import contextmanager -from typing import Optional, Generator +from contextvars import ContextVar +from contextlib import contextmanager +from typing import Generator -_API_KEY_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar("api_key", default=None) +_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key", default=None) @contextmanager -def use_api_key(api_key: str) -> Generator[None, None, None]: +def use_api_key(api_key: str | None) -> Generator[None, None, None]: """Context manager to set the API key for the current execution context. + If api_key is None or empty, this context manager does nothing, allowing + the underlying client to use its default API key. + Args: - api_key: The API key to use. + api_key: The API key to use. If None or empty, no change is made. + + Example: + client = DataCommonsClient(api_key="default-key") + + # Uses "default-key" + client.observation.fetch(...) + + with use_api_key("temp-key"): + # Uses "temp-key" + client.observation.fetch(...) + + # Back to "default-key" + client.observation.fetch(...) """ + if not api_key: + yield + return + token = _API_KEY_CONTEXT_VAR.set(api_key) try: yield From 7074bd8a5174de6caac28b2b4a8e3cb43e924488 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 11:35:25 -0800 Subject: [PATCH 5/9] Propagate API key context to graph worker threads --- datacommons_client/utils/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py index b9d14495..a0336a0a 100644 --- a/datacommons_client/utils/graph.py +++ b/datacommons_client/utils/graph.py @@ -5,6 +5,7 @@ from concurrent.futures import wait from functools import lru_cache from typing import Callable, Literal, Optional, TypeAlias +import contextvars from datacommons_client.models.node import Node @@ -108,6 +109,7 @@ def build_graph_map( original_root = root + ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_workers) as executor: queue = deque([root]) @@ -119,7 +121,7 @@ def build_graph_map( # Check if the node has already been visited or is in progress if dcid not in visited and dcid not in in_progress: # Submit the fetch task - in_progress[dcid] = executor.submit(fetch_fn, dcid=dcid) + in_progress[dcid] = executor.submit(ctx.run, fetch_fn, dcid=dcid) # Check if any futures are still in progress if not in_progress: From 102cc52c77337e9f58677febe55738aafd7d7287 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 12:19:18 -0800 Subject: [PATCH 6/9] Fix context propagation in NodeEndpoint and add integration tests --- datacommons_client/__init__.py | 2 + datacommons_client/endpoints/node.py | 4 +- datacommons_client/tests/test_client.py | 55 +++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/datacommons_client/__init__.py b/datacommons_client/__init__.py index 95b203a1..f7d02654 100644 --- a/datacommons_client/__init__.py +++ b/datacommons_client/__init__.py @@ -10,6 +10,7 @@ from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.observation import ObservationEndpoint from datacommons_client.endpoints.resolve import ResolveEndpoint +from datacommons_client.utils.context import use_api_key __all__ = [ "DataCommonsClient", @@ -17,4 +18,5 @@ "NodeEndpoint", "ObservationEndpoint", "ResolveEndpoint", + "use_api_key", ] diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 69bf3e9d..794ee16e 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -447,9 +447,11 @@ def _fetch_place_relationships( ) # Use a thread pool to fetch ancestry graphs in parallel for each input entity + import contextvars + ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: futures = [ - executor.submit(build_graph_map, root=dcid, fetch_fn=fetch_fn) + executor.submit(ctx.run, build_graph_map, root=dcid, fetch_fn=fetch_fn) for dcid in place_dcids ] # Gather ancestry maps and postprocess into flat or nested form diff --git a/datacommons_client/tests/test_client.py b/datacommons_client/tests/test_client.py index a17a2d9c..22071c42 100644 --- a/datacommons_client/tests/test_client.py +++ b/datacommons_client/tests/test_client.py @@ -5,6 +5,7 @@ import pytest from datacommons_client.client import DataCommonsClient +from datacommons_client import use_api_key from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.observation import ObservationEndpoint @@ -419,3 +420,57 @@ def test_client_end_to_end_surface_header_propagation_observation( assert headers is not None assert headers.get("x-surface") == "datagemma" assert headers.get("X-API-Key") == "test_key" + + + +@patch("datacommons_client.endpoints.base.post_request") +def test_use_api_key_with_observation_fetch(mock_post_request): + """Test use_api_key override for observation fetches (non-threaded).""" + + # Setup client with default key + client = DataCommonsClient(api_key="default-key") + + # Configure mock to return valid response structure + mock_post_request.return_value = {"byVariable": {}, "facets": {}} + + # Default usage + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + mock_post_request.assert_called() + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + # Context override + with use_api_key("context-key"): + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" + + # Back to default + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + +@patch("datacommons_client.endpoints.base.post_request") +def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request): + """Test use_api_key propagation for node graph methods (threaded).""" + + client = DataCommonsClient(api_key="default-key") + + # Configure mock. fetch_place_ancestors expects a dict response or list of nodes. + # NodeResponse.data is a dict. + mock_post_request.return_value = {"data": {}} + + # Default usage + client.node.fetch_place_ancestors(place_dcids=["geoId/06"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "default-key" + + # Context override + with use_api_key("context-key"): + # Use a different DCID to avoid hitting fetch_relationship_lru cache + client.node.fetch_place_ancestors(place_dcids=["geoId/07"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" + + From 7697c25a019436f4bca60c23d108287c58c747ea Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 12:23:43 -0800 Subject: [PATCH 7/9] Format files --- datacommons_client/endpoints/base.py | 2 +- datacommons_client/endpoints/node.py | 6 ++-- datacommons_client/tests/test_client.py | 33 ++++++++--------- datacommons_client/tests/test_context.py | 46 ++++++++++++------------ datacommons_client/utils/context.py | 26 +++++++------- datacommons_client/utils/graph.py | 2 +- 6 files changed, 57 insertions(+), 58 deletions(-) diff --git a/datacommons_client/endpoints/base.py b/datacommons_client/endpoints/base.py index f5076937..a0a5403d 100644 --- a/datacommons_client/endpoints/base.py +++ b/datacommons_client/endpoints/base.py @@ -1,10 +1,10 @@ import re from typing import Any, Dict, Optional +from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR from datacommons_client.utils.request_handling import check_instance_is_valid from datacommons_client.utils.request_handling import post_request from datacommons_client.utils.request_handling import resolve_instance_url -from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR class API: diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index 794ee16e..d1e6f54b 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -451,8 +451,10 @@ def _fetch_place_relationships( ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: futures = [ - executor.submit(ctx.run, build_graph_map, root=dcid, fetch_fn=fetch_fn) - for dcid in place_dcids + executor.submit(ctx.run, + build_graph_map, + root=dcid, + fetch_fn=fetch_fn) for dcid in place_dcids ] # Gather ancestry maps and postprocess into flat or nested form for future in futures: diff --git a/datacommons_client/tests/test_client.py b/datacommons_client/tests/test_client.py index 22071c42..221befff 100644 --- a/datacommons_client/tests/test_client.py +++ b/datacommons_client/tests/test_client.py @@ -4,8 +4,8 @@ import pandas as pd import pytest -from datacommons_client.client import DataCommonsClient from datacommons_client import use_api_key +from datacommons_client.client import DataCommonsClient from datacommons_client.endpoints.base import API from datacommons_client.endpoints.node import NodeEndpoint from datacommons_client.endpoints.observation import ObservationEndpoint @@ -422,29 +422,28 @@ def test_client_end_to_end_surface_header_propagation_observation( assert headers.get("X-API-Key") == "test_key" - @patch("datacommons_client.endpoints.base.post_request") def test_use_api_key_with_observation_fetch(mock_post_request): """Test use_api_key override for observation fetches (non-threaded).""" - + # Setup client with default key client = DataCommonsClient(api_key="default-key") - + # Configure mock to return valid response structure mock_post_request.return_value = {"byVariable": {}, "facets": {}} - + # Default usage client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) mock_post_request.assert_called() _, kwargs = mock_post_request.call_args assert kwargs["headers"]["X-API-Key"] == "default-key" - + # Context override with use_api_key("context-key"): - client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) - _, kwargs = mock_post_request.call_args - assert kwargs["headers"]["X-API-Key"] == "context-key" - + client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" + # Back to default client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"]) _, kwargs = mock_post_request.call_args @@ -454,9 +453,9 @@ def test_use_api_key_with_observation_fetch(mock_post_request): @patch("datacommons_client.endpoints.base.post_request") def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request): """Test use_api_key propagation for node graph methods (threaded).""" - + client = DataCommonsClient(api_key="default-key") - + # Configure mock. fetch_place_ancestors expects a dict response or list of nodes. # NodeResponse.data is a dict. mock_post_request.return_value = {"data": {}} @@ -468,9 +467,7 @@ def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request): # Context override with use_api_key("context-key"): - # Use a different DCID to avoid hitting fetch_relationship_lru cache - client.node.fetch_place_ancestors(place_dcids=["geoId/07"]) - _, kwargs = mock_post_request.call_args - assert kwargs["headers"]["X-API-Key"] == "context-key" - - + # Use a different DCID to avoid hitting fetch_relationship_lru cache + client.node.fetch_place_ancestors(place_dcids=["geoId/07"]) + _, kwargs = mock_post_request.call_args + assert kwargs["headers"]["X-API-Key"] == "context-key" diff --git a/datacommons_client/tests/test_context.py b/datacommons_client/tests/test_context.py index 4a4528ce..70f33df0 100644 --- a/datacommons_client/tests/test_context.py +++ b/datacommons_client/tests/test_context.py @@ -13,33 +13,33 @@ # limitations under the License. import unittest -from datacommons_client.utils.context import use_api_key + from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR +from datacommons_client.utils.context import use_api_key class TestContext(unittest.TestCase): - def test_use_api_key_sets_var(self): - """Test that use_api_key sets the context variable.""" - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key("test-key"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key") - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - - def test_use_api_key_nested(self): - """Test nested usage of use_api_key.""" - with use_api_key("outer"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") - with use_api_key("inner"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner") - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + def test_use_api_key_sets_var(self): + """Test that use_api_key sets the context variable.""" + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key("test-key"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key") + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - def test_use_api_key_none(self): - """Test that use_api_key with None/empty does not set the variable.""" - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key(None): - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key(""): - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + def test_use_api_key_nested(self): + """Test nested usage of use_api_key.""" + with use_api_key("outer"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") + with use_api_key("inner"): + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner") + self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + def test_use_api_key_none(self): + """Test that use_api_key with None/empty does not set the variable.""" + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key(None): + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) + with use_api_key(""): + self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) diff --git a/datacommons_client/utils/context.py b/datacommons_client/utils/context.py index 587029d8..02415947 100644 --- a/datacommons_client/utils/context.py +++ b/datacommons_client/utils/context.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextvars import ContextVar from contextlib import contextmanager from contextvars import ContextVar -from contextlib import contextmanager from typing import Generator -_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key", default=None) +_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key", + default=None) + @contextmanager def use_api_key(api_key: str | None) -> Generator[None, None, None]: - """Context manager to set the API key for the current execution context. + """Context manager to set the API key for the current execution context. If api_key is None or empty, this context manager does nothing, allowing the underlying client to use its default API key. @@ -43,12 +43,12 @@ def use_api_key(api_key: str | None) -> Generator[None, None, None]: # Back to "default-key" client.observation.fetch(...) """ - if not api_key: - yield - return - - token = _API_KEY_CONTEXT_VAR.set(api_key) - try: - yield - finally: - _API_KEY_CONTEXT_VAR.reset(token) + if not api_key: + yield + return + + token = _API_KEY_CONTEXT_VAR.set(api_key) + try: + yield + finally: + _API_KEY_CONTEXT_VAR.reset(token) diff --git a/datacommons_client/utils/graph.py b/datacommons_client/utils/graph.py index a0336a0a..db636b5d 100644 --- a/datacommons_client/utils/graph.py +++ b/datacommons_client/utils/graph.py @@ -3,9 +3,9 @@ from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait +import contextvars from functools import lru_cache from typing import Callable, Literal, Optional, TypeAlias -import contextvars from datacommons_client.models.node import Node From 1c404fcd4d8a5bccdcd37a8845dd9e36f182c791 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Fri, 9 Jan 2026 13:31:09 -0800 Subject: [PATCH 8/9] Address gemini comments --- datacommons_client/endpoints/node.py | 2 +- datacommons_client/tests/test_context.py | 46 ++++++++++++------------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/datacommons_client/endpoints/node.py b/datacommons_client/endpoints/node.py index d1e6f54b..808c3495 100644 --- a/datacommons_client/endpoints/node.py +++ b/datacommons_client/endpoints/node.py @@ -1,4 +1,5 @@ from concurrent.futures import ThreadPoolExecutor +import contextvars from functools import partial from functools import wraps from typing import Literal, Optional @@ -447,7 +448,6 @@ def _fetch_place_relationships( ) # Use a thread pool to fetch ancestry graphs in parallel for each input entity - import contextvars ctx = contextvars.copy_context() with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: futures = [ diff --git a/datacommons_client/tests/test_context.py b/datacommons_client/tests/test_context.py index 70f33df0..ef7117cb 100644 --- a/datacommons_client/tests/test_context.py +++ b/datacommons_client/tests/test_context.py @@ -12,34 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR from datacommons_client.utils.context import use_api_key -class TestContext(unittest.TestCase): +def test_use_api_key_sets_var(): + """Test that use_api_key sets the context variable.""" + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key("test-key"): + assert _API_KEY_CONTEXT_VAR.get() == "test-key" + assert _API_KEY_CONTEXT_VAR.get() is None + - def test_use_api_key_sets_var(self): - """Test that use_api_key sets the context variable.""" - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key("test-key"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "test-key") - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) +def test_use_api_key_nested(): + """Test nested usage of use_api_key.""" + with use_api_key("outer"): + assert _API_KEY_CONTEXT_VAR.get() == "outer" + with use_api_key("inner"): + assert _API_KEY_CONTEXT_VAR.get() == "inner" + assert _API_KEY_CONTEXT_VAR.get() == "outer" + assert _API_KEY_CONTEXT_VAR.get() is None - def test_use_api_key_nested(self): - """Test nested usage of use_api_key.""" - with use_api_key("outer"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") - with use_api_key("inner"): - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "inner") - self.assertEqual(_API_KEY_CONTEXT_VAR.get(), "outer") - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - def test_use_api_key_none(self): - """Test that use_api_key with None/empty does not set the variable.""" - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key(None): - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) - with use_api_key(""): - self.assertIsNone(_API_KEY_CONTEXT_VAR.get()) +def test_use_api_key_none(): + """Test that use_api_key with None/empty does not set the variable.""" + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key(None): + assert _API_KEY_CONTEXT_VAR.get() is None + with use_api_key(""): + assert _API_KEY_CONTEXT_VAR.get() is None From 72e9b0c04bb8744284a7d524ed73ae91d97276cb Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Sat, 10 Jan 2026 18:05:32 -0800 Subject: [PATCH 9/9] Address review comments --- datacommons_client/endpoints/base.py | 1 + datacommons_client/utils/context.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/datacommons_client/endpoints/base.py b/datacommons_client/endpoints/base.py index a0a5403d..ff4adfdc 100644 --- a/datacommons_client/endpoints/base.py +++ b/datacommons_client/endpoints/base.py @@ -98,6 +98,7 @@ def post(self, headers = self.headers ctx_api_key = _API_KEY_CONTEXT_VAR.get() if ctx_api_key: + # Copy headers to avoid mutating the shared client state headers = self.headers.copy() headers["X-API-Key"] = ctx_api_key diff --git a/datacommons_client/utils/context.py b/datacommons_client/utils/context.py index 02415947..c76944c5 100644 --- a/datacommons_client/utils/context.py +++ b/datacommons_client/utils/context.py @@ -31,6 +31,8 @@ def use_api_key(api_key: str | None) -> Generator[None, None, None]: api_key: The API key to use. If None or empty, no change is made. Example: + from datacommons_client import use_api_key + # ... client = DataCommonsClient(api_key="default-key") # Uses "default-key"