Skip to content

Commit 4478b85

Browse files
committed
Add integration tests for upsert
1 parent 0c9ecb7 commit 4478b85

File tree

7 files changed

+468
-5
lines changed

7 files changed

+468
-5
lines changed

pinecone/exceptions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313
from .exceptions import PineconeConfigurationError, PineconeProtocolError, ListConversionException
1414

15+
PineconeNotFoundException = NotFoundException
16+
1517
__all__ = [
1618
"PineconeConfigurationError",
1719
"PineconeProtocolError",
@@ -22,6 +24,7 @@
2224
"PineconeApiKeyError",
2325
"PineconeApiException",
2426
"NotFoundException",
27+
"PineconeNotFoundException",
2528
"UnauthorizedException",
2629
"ForbiddenException",
2730
"ServiceException",

pinecone/grpc/grpc_runner.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,19 @@
88
from .utils import _generate_request_id
99
from .config import GRPCClientConfig
1010
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
11-
from pinecone.exceptions.exceptions import PineconeException
12-
from grpc import CallCredentials, Compression
11+
from grpc import CallCredentials, Compression, StatusCode
12+
from grpc.aio import AioRpcError
1313
from google.protobuf.message import Message
1414

15+
from pinecone.exceptions import (
16+
PineconeException,
17+
PineconeApiValueError,
18+
PineconeApiException,
19+
UnauthorizedException,
20+
PineconeNotFoundException,
21+
ServiceException,
22+
)
23+
1524

1625
class GrpcRunner:
1726
def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig):
@@ -50,7 +59,7 @@ def wrapped():
5059
compression=compression,
5160
)
5261
except _InactiveRpcError as e:
53-
raise PineconeException(e._state.debug_error_string) from e
62+
self._map_exception(e, e._state.code, e._state.details)
5463

5564
return wrapped()
5665

@@ -89,8 +98,8 @@ async def wrapped():
8998
wait_for_ready=wait_for_ready,
9099
compression=compression,
91100
)
92-
except _InactiveRpcError as e:
93-
raise PineconeException(e._state.debug_error_string) from e
101+
except AioRpcError as e:
102+
self._map_exception(e, e.code(), e.details())
94103

95104
return await wrapped()
96105

@@ -108,3 +117,37 @@ def _prepare_metadata(
108117

109118
def _request_metadata(self) -> Dict[str, str]:
110119
return {REQUEST_ID: _generate_request_id()}
120+
121+
def _map_exception(self, e: Exception, code: Optional[StatusCode], details: Optional[str]):
122+
# Client / connection issues
123+
details = details or ""
124+
125+
if code in [StatusCode.DEADLINE_EXCEEDED]:
126+
raise TimeoutError(details) from e
127+
128+
# Permissions stuff
129+
if code in [StatusCode.PERMISSION_DENIED, StatusCode.UNAUTHENTICATED]:
130+
raise UnauthorizedException(status=code, reason=details) from e
131+
132+
# 400ish stuff
133+
if code in [StatusCode.NOT_FOUND]:
134+
raise PineconeNotFoundException(status=code, reason=details) from e
135+
if code in [StatusCode.INVALID_ARGUMENT, StatusCode.OUT_OF_RANGE]:
136+
raise PineconeApiValueError(details) from e
137+
if code in [
138+
StatusCode.ALREADY_EXISTS,
139+
StatusCode.FAILED_PRECONDITION,
140+
StatusCode.UNIMPLEMENTED,
141+
StatusCode.RESOURCE_EXHAUSTED,
142+
]:
143+
raise PineconeApiException(status=code, reason=details) from e
144+
145+
# 500ish stuff
146+
if code in [StatusCode.INTERNAL, StatusCode.UNAVAILABLE]:
147+
raise ServiceException(status=code, reason=details) from e
148+
if code in [StatusCode.UNKNOWN, StatusCode.DATA_LOSS, StatusCode.ABORTED]:
149+
# abandon hope, all ye who enter here
150+
raise PineconeException(code, details) from e
151+
152+
# If you get here, you're in a bad place
153+
raise PineconeException(code, details) from e

tests/integration/data_asyncio/__init__.py

Whitespace-only changes.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import os
3+
from ..helpers import get_environment_var, random_string
4+
5+
6+
@pytest.fixture(scope="session")
7+
def api_key():
8+
return get_environment_var("PINECONE_API_KEY")
9+
10+
11+
@pytest.fixture(scope="session")
12+
def host():
13+
return get_environment_var("INDEX_HOST")
14+
15+
16+
@pytest.fixture(scope="session")
17+
def dimension():
18+
return int(get_environment_var("DIMENSION"))
19+
20+
21+
def use_grpc():
22+
return os.environ.get("USE_GRPC", "false") == "true"
23+
24+
25+
def build_client(api_key):
26+
if use_grpc():
27+
from pinecone.grpc import PineconeGRPC
28+
29+
return PineconeGRPC(api_key=api_key)
30+
else:
31+
from pinecone import Pinecone
32+
33+
return Pinecone(
34+
api_key=api_key, additional_headers={"sdk-test-suite": "pinecone-python-client"}
35+
)
36+
37+
38+
@pytest.fixture(scope="session")
39+
async def pc(api_key):
40+
return build_client(api_key=api_key)
41+
42+
43+
@pytest.fixture(scope="session")
44+
async def asyncio_idx(pc, host):
45+
return pc.AsyncioIndex(host=host)
46+
47+
48+
@pytest.fixture(scope="session")
49+
async def namespace():
50+
return random_string(10)
51+
52+
53+
@pytest.fixture(scope="session")
54+
async def list_namespace():
55+
return random_string(10)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pytest
2+
from pinecone import Vector
3+
from .conftest import use_grpc
4+
from ..helpers import random_string
5+
from .utils import build_asyncio_idx, embedding_values, poll_for_freshness
6+
7+
8+
@pytest.mark.parametrize("target_namespace", ["", random_string(20)])
9+
@pytest.mark.skipif(use_grpc() == False, reason="Currently only GRPC supports asyncio")
10+
async def test_upsert_to_default_namespace(host, dimension, target_namespace):
11+
asyncio_idx = build_asyncio_idx(host)
12+
13+
def emb():
14+
return embedding_values(dimension)
15+
16+
# Upsert with tuples
17+
await asyncio_idx.upsert(
18+
vectors=[("1", emb()), ("2", emb()), ("3", emb())], namespace=target_namespace
19+
)
20+
21+
# Upsert with objects
22+
await asyncio_idx.upsert(
23+
vectors=[
24+
Vector(id="4", values=emb()),
25+
Vector(id="5", values=emb()),
26+
Vector(id="6", values=emb()),
27+
],
28+
namespace=target_namespace,
29+
)
30+
31+
# Upsert with dict
32+
await asyncio_idx.upsert(
33+
vectors=[
34+
{"id": "7", "values": emb()},
35+
{"id": "8", "values": emb()},
36+
{"id": "9", "values": emb()},
37+
],
38+
namespace=target_namespace,
39+
)
40+
41+
await poll_for_freshness(asyncio_idx, target_namespace, 9)
42+
43+
# # Check the vector count reflects some data has been upserted
44+
stats = await asyncio_idx.describe_index_stats()
45+
assert stats.total_vector_count >= 9
46+
# default namespace could have other stuff from other tests
47+
if target_namespace != "":
48+
assert stats.namespaces[target_namespace].vector_count == 9
49+
50+
51+
# @pytest.mark.parametrize("target_namespace", [
52+
# "",
53+
# random_string(20),
54+
# ])
55+
# @pytest.mark.skipif(
56+
# os.getenv("METRIC") != "dotproduct", reason="Only metric=dotprodouct indexes support hybrid"
57+
# )
58+
# async def test_upsert_to_namespace_with_sparse_embedding_values(pc, host, dimension, target_namespace):
59+
# asyncio_idx = pc.AsyncioIndex(host=host)
60+
61+
# # Upsert with sparse values object
62+
# await asyncio_idx.upsert(
63+
# vectors=[
64+
# Vector(
65+
# id="1",
66+
# values=embedding_values(dimension),
67+
# sparse_values=SparseValues(indices=[0, 1], values=embedding_values()),
68+
# )
69+
# ],
70+
# namespace=target_namespace,
71+
# )
72+
73+
# # Upsert with sparse values dict
74+
# await asyncio_idx.upsert(
75+
# vectors=[
76+
# {
77+
# "id": "2",
78+
# "values": embedding_values(dimension),
79+
# "sparse_values": {"indices": [0, 1], "values": embedding_values()},
80+
# },
81+
# {
82+
# "id": "3",
83+
# "values": embedding_values(dimension),
84+
# "sparse_values": {"indices": [0, 1], "values": embedding_values()},
85+
# },
86+
# ],
87+
# namespace=target_namespace,
88+
# )
89+
90+
# await poll_for_freshness(asyncio_idx, target_namespace, 9)
91+
92+
# # Check the vector count reflects some data has been upserted
93+
# stats = await asyncio_idx.describe_index_stats()
94+
# assert stats.total_vector_count >= 9
95+
96+
# if (target_namespace != ""):
97+
# assert stats.namespaces[target_namespace].vector_count == 9

0 commit comments

Comments
 (0)