From f1e228d34ad7b29f8d075346ddc7f38fe24671ff Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Thu, 29 May 2025 09:57:10 +0000 Subject: [PATCH 01/10] index factory refactor --- aixplain/factories/index_factory/__init__.py | 2 +- aixplain/factories/model_factory/utils.py | 12 +- aixplain/modules/__init__.py | 2 +- .../modules/model/index_models/__init__.py | 5 + .../model/{ => index_models}/index_model.py | 128 +++++------------- .../knowledge_graph_index_model.py | 103 ++++++++++++++ .../model/index_models/vector_index_model.py | 90 ++++++++++++ tests/functional/model/run_model_test.py | 14 +- tests/unit/index_model_test.py | 4 +- 9 files changed, 259 insertions(+), 101 deletions(-) create mode 100644 aixplain/modules/model/index_models/__init__.py rename aixplain/modules/model/{ => index_models}/index_model.py (61%) create mode 100644 aixplain/modules/model/index_models/knowledge_graph_index_model.py create mode 100644 aixplain/modules/model/index_models/vector_index_model.py diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 189f9417..21e45e00 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -21,7 +21,7 @@ Index Factory Class """ -from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.index_models.index_model import IndexModel from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index ca3e1eec..b78b894a 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -2,7 +2,7 @@ import logging from aixplain.modules.model import Model from aixplain.modules.model.llm_model import LLM -from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.index_models import IndexModel, VectorIndexModel, KnowledgeGraphIndexModel from aixplain.modules.model.integration import Integration from aixplain.modules.model.connection import ConnectionTool from aixplain.modules.model.utility_model import UtilityModel @@ -67,7 +67,15 @@ def create_model_from_response(response: Dict) -> Model: if len(f) > 0 and len(f[0].get("defaultValues", [])) > 0: temperature = float(f[0]["defaultValues"][0]["value"]) elif function == Function.SEARCH: - ModelClass = IndexModel + version = response.get("version", None) + if version and version.get("id", None) is not None and "-" in version["id"]: + collection_type = version["id"].split("-", 1)[0] + ModelClass = IndexModel + if collection_type in VectorIndexModel.supported_indices: + ModelClass = VectorIndexModel + elif collection_type in KnowledgeGraphIndexModel.supported_indices: + ModelClass = KnowledgeGraphIndexModel + additional_kwargs["llm"] = next((item["code"] for item in attributes if item["name"] == "llm"), None) elif function_type == FunctionType.INTEGRATION: ModelClass = Integration elif function_type == FunctionType.CONNECTION: diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index f8a64650..5e8fe8cb 100644 --- a/aixplain/modules/__init__.py +++ b/aixplain/modules/__init__.py @@ -37,4 +37,4 @@ from .agent.tool import Tool from .team_agent import TeamAgent from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit -from .model.index_model import IndexModel +from .model.index_models.index_model import IndexModel diff --git a/aixplain/modules/model/index_models/__init__.py b/aixplain/modules/model/index_models/__init__.py new file mode 100644 index 00000000..ff8d6921 --- /dev/null +++ b/aixplain/modules/model/index_models/__init__.py @@ -0,0 +1,5 @@ +from .knowledge_graph_index_model import KnowledgeGraphIndexModel +from .vector_index_model import VectorIndexModel, IndexFilter, IndexFilterOperator +from .index_model import IndexModel + +__all__ = ["KnowledgeGraphIndexModel", "VectorIndexModel", "IndexModel", "IndexFilter", "IndexFilterOperator"] diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_models/index_model.py similarity index 61% rename from aixplain/modules/model/index_model.py rename to aixplain/modules/model/index_models/index_model.py index 04ed3df1..876af611 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_models/index_model.py @@ -1,55 +1,9 @@ -from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType, FunctionType +from typing import Text, Optional, Union, Dict, List +from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse -from typing import Text, Optional, Union, Dict from aixplain.modules.model.record import Record -from enum import Enum -from typing import List -from aixplain.enums.splitting_options import SplittingOptions - - -class IndexFilterOperator(Enum): - EQUALS = "==" - NOT_EQUALS = "!=" - CONTAINS = "in" - NOT_CONTAINS = "not in" - GREATER_THAN = ">" - LESS_THAN = "<" - GREATER_THAN_OR_EQUALS = ">=" - LESS_THAN_OR_EQUALS = "<=" - - -class IndexFilter: - field: str - value: str - operator: Union[IndexFilterOperator, str] - - def __init__(self, field: str, value: str, operator: Union[IndexFilterOperator, str]): - self.field = field - self.value = value - self.operator = operator - - def to_dict(self): - return { - "field": self.field, - "value": self.value, - "operator": self.operator.value if isinstance(self.operator, IndexFilterOperator) else self.operator, - } - - -class Splitter: - def __init__( - self, - split: bool = False, - split_by: SplittingOptions = SplittingOptions.WORD, - split_length: int = 1, - split_overlap: int = 0, - ): - self.split = split - self.split_by = split_by - self.split_length = split_length - self.split_overlap = split_overlap class IndexModel(Model): @@ -64,8 +18,7 @@ def __init__( function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - embedding_model: Union[EmbeddingModel, str] = None, - function_type: Optional[FunctionType] = FunctionType.SEARCH, + embedding_model: Optional[EmbeddingModel] = None, **additional_info, ) -> None: """Index Init @@ -80,7 +33,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. - embedding_model (Union[EmbeddingModel, str], optional): embedding model. Defaults to None. + embedding_model (EmbeddingModel, optional): embedding model. Defaults to None. **additional_info: Any additional Model info to be saved """ assert function == Function.SEARCH, "Index only supports search function" @@ -94,12 +47,12 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, - function_type=function_type, **additional_info, ) self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.embedding_model = embedding_model + self.embedding_size = None if embedding_model: try: from aixplain.factories import ModelFactory @@ -110,7 +63,6 @@ def __init__( import warnings warnings.warn(f"Failed to get embedding size for embedding model {embedding_model}: {e}") - self.embedding_size = None def to_dict(self) -> Dict: data = super().to_dict() @@ -119,53 +71,40 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + def search(self, query: str, top_k: int = 10) -> ModelResponse: """Search for documents in the index Args: query (str): Query to be searched top_k (int, optional): Number of results to be returned. Defaults to 10. - filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. Returns: ModelResponse: Response from the indexing service Example: - index_model.search("Hello") - - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) """ - from aixplain.factories import FileFactory - - uri, value_type = "", "text" - storage_type = FileFactory.check_storage_type(query) - if storage_type in [StorageType.FILE, StorageType.URL]: - uri = FileFactory.to_link(query) - query = "" - value_type = "image" data = { "action": "search", - "data": query or uri, - "dataType": value_type, - "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + "data": query, + "dataType": "text", + "filters": [], + "payload": {"top_k": top_k}, } return self.run(data=data) - def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + def upsert(self, documents: List[Record]) -> ModelResponse: """Upsert documents into the index Args: documents (List[Record]): List of documents to be upserted - splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: ModelResponse: Response from the indexing service - Examples: + Example: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) - Splitter in the above example is optional and can be used to split the documents into smaller chunks. """ # Validate documents for doc in documents: @@ -173,19 +112,7 @@ def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) - # Convert documents to payloads payloads = [doc.to_dict() for doc in documents] # Build payload - data = { - "action": "ingest", - "data": payloads, - } - if splitter and splitter.split: - data["additional_params"] = { - "splitter": { - "split": splitter.split, - "split_by": splitter.split_by, - "split_length": splitter.split_length, - "split_overlap": splitter.split_overlap, - } - } + data = {"action": "ingest", "data": payloads} # Run the indexing service response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: @@ -193,13 +120,6 @@ def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) - return response raise Exception(f"Failed to upsert documents: {response.error_message}") - def count(self) -> float: - data = {"action": "count", "data": ""} - response = self.run(data=data) - if response.status == "SUCCESS": - return int(response.data) - raise Exception(f"Failed to count documents: {response.error_message}") - def get_record(self, record_id: Text) -> ModelResponse: """ Get a document from the index. @@ -243,3 +163,25 @@ def delete_record(self, record_id: Text) -> ModelResponse: if response.status == "SUCCESS": return response raise Exception(f"Failed to delete record: {response.error_message}") + + def count(self) -> float: + data = {"action": "count", "data": ""} + response = self.run(data=data) + if response.status == "SUCCESS": + return int(response.data) + raise Exception(f"Failed to count documents: {response.error_message}") + + @staticmethod + def parse_file(file_path: Text) -> ModelResponse: + """ + Parse a file using the Docling model. + """ + try: + from aixplain.factories import ModelFactory + + docling_model_id = "677bee6c6eb56331f9192a91" + model = ModelFactory.get(docling_model_id) + response = model.run(file_path) + return response + except Exception as e: + raise Exception(f"Failed to parse file: {e}") diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py new file mode 100644 index 00000000..534c8381 --- /dev/null +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -0,0 +1,103 @@ +from aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.enums import ResponseStatus +from typing import Text, Optional, Union, Dict, List, Any +from aixplain.modules.model.record import Record +from aixplain.modules.model.response import ModelResponse +from aixplain.enums import Function, Supplier, EmbeddingModel + + +class KnowledgeGraphIndexModel(IndexModel): + supported_indices = ["graphrag"] + + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + embedding_model: Optional[EmbeddingModel] = None, + llm: Optional[Text] = None, + **additional_info, + ): + super().__init__( + id, name, description, api_key, supplier, version, function, is_subscribed, cost, embedding_model, **additional_info + ) + self.llm = llm + + def to_dict(self) -> Dict[str, Any]: + data = super().to_dict() + data["llm"] = self.llm + return data + + def get_prompts(self) -> Dict[str, str]: + data = {"action": "get_prompts", "data": ""} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = response.data + return response + raise Exception(f"Failed to get prompts: {response.error_message}") + + # Add documents to the storage. Later can run indexing process incrementally based on the new documents added. + def add_documents(self, documents: List[Record]): + # Validate documents + for doc in documents: + doc.validate() + # Convert documents to payloads + payloads = [doc.to_dict() for doc in documents] + # Build payload + data = {"action": "upload_documents", "data": payloads} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = documents + return response + raise Exception(f"Failed to add documents: {response.error_message}") + + # Run prompt auto tuning, while also adding new documents to the storage if provided. + def auto_prompt_tune(self, documents: List[Record]) -> Dict[str, str]: + # TODO: Check if any files are already uploaded. If none and documents is also None, then raise an error. + self.add_documents(documents) + # Build payload + data = {"action": "auto_prompt_tune", "data": ""} + # Run the indexing service + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + return response.data + raise Exception(f"Failed to upsert documents: {response.error_message}") + + def manual_prompt_tune(self, prompts: Dict[str, str]) -> Dict[str, str]: + data = {"action": "manual_prompt_tune", "data": prompts} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = prompts + return response + raise Exception(f"Failed to prompt tune: {response.error_message}") + + # Start the indexing process based on files already uploaded. + def graph_indexing(self): + data = {"action": "ingest", "data": []} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = response.data + return response + raise Exception(f"Failed to run indexing: {response.error_message}") + + # Add current documents to the storage and start the indexing process. + def upsert(self, documents: List[Record]) -> ModelResponse: + """Upsert documents into the index + + Args: + documents (List[Record]): List of documents to be upserted + + Returns: + ModelResponse: Response from the indexing service + + Example: + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + """ + self.add_documents(documents) + return self.graph_indexing() diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py new file mode 100644 index 00000000..04ac829d --- /dev/null +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -0,0 +1,90 @@ +from typing import Union, List, Text, Optional, Dict +from enum import Enum +from aixplain.enums import StorageType +from aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.modules.model.response import ModelResponse +from aixplain.enums import Function, Supplier, EmbeddingModel + + +class IndexFilterOperator(Enum): + EQUALS = "==" + NOT_EQUALS = "!=" + CONTAINS = "in" + NOT_CONTAINS = "not in" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUALS = ">=" + LESS_THAN_OR_EQUALS = "<=" + + +class IndexFilter: + field: str + value: str + operator: Union[IndexFilterOperator, str] + + def __init__(self, field: str, value: str, operator: Union[IndexFilterOperator, str]): + self.field = field + self.value = value + self.operator = operator + + def to_dict(self): + return { + "field": self.field, + "value": self.value, + "operator": self.operator.value if isinstance(self.operator, IndexFilterOperator) else self.operator, + } + + +class VectorIndexModel(IndexModel): + supported_indices = ["airv2", "vectara", "zeroentropy"] + + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + embedding_model: Optional[EmbeddingModel] = None, + **additional_info + ): + super().__init__( + id, name, description, api_key, supplier, version, function, is_subscribed, cost, embedding_model, **additional_info + ) + + def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + """Search for documents in the index + + Args: + query (str): Query to be searched + top_k (int, optional): Number of results to be returned. Defaults to 10. + filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. + + Returns: + ModelResponse: Response from the indexing service + + Example: + - index_model.search("Hello") + - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) + """ + from aixplain.factories import FileFactory + + uri, value_type = "", "text" + storage_type = FileFactory.check_storage_type(query) + if storage_type in [StorageType.FILE, StorageType.URL]: + uri = FileFactory.to_link(query) + query = "" + value_type = "image" + + data = { + "action": "search", + "data": query or uri, + "dataType": value_type, + "filters": [filter.to_dict() for filter in filters], + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + } + return self.run(data=data) diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 71838034..61e59c0c 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -11,6 +11,7 @@ from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams import time import os +import json def pytest_generate_tests(metafunc): @@ -91,6 +92,7 @@ def run_index_model(index_model, retries): except Exception as e: time.sleep(180) + response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" assert "germany" in response.data.lower() @@ -128,6 +130,7 @@ def test_index_model(embedding_model, supplier_params): run_index_model(index_model, retries) + @pytest.mark.parametrize( "embedding_model,supplier_params", [ @@ -143,7 +146,7 @@ def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory - from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + from aiXplain.aixplain.modules.model.index_models.index_model import IndexFilter, IndexFilterOperator for index in IndexFactory.list()["results"]: index.delete() @@ -159,6 +162,9 @@ def test_index_model_with_filter(embedding_model, supplier_params): retries = 1 for _ in range(retries): try: + index_model.upsert( + [Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})] + ) index_model.upsert( [Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})] ) @@ -167,6 +173,9 @@ def test_index_model_with_filter(embedding_model, supplier_params): time.sleep(180) for _ in range(retries): try: + index_model.upsert( + [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] + ) index_model.upsert( [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] ) @@ -174,6 +183,7 @@ def test_index_model_with_filter(embedding_model, supplier_params): except Exception: time.sleep(180) + assert index_model.count() == 2 response = index_model.search( "", filters=[IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)] @@ -304,7 +314,7 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record from uuid import uuid4 - from aixplain.modules.model.index_model import Splitter + from aixplain.modules.model.index_models.index_model import Splitter from aixplain.enums.splitting_options import SplittingOptions for index in IndexFactory.list()["results"]: diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 8d5c3a74..f226e763 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -3,7 +3,7 @@ from aixplain.factories.index_factory import IndexFactory from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.modules.model.index_model import IndexModel +from aiXplain.aixplain.modules.model.index_models.index_model import IndexModel from aixplain.utils import config import logging import pytest @@ -208,7 +208,7 @@ def test_record_to_dict(): def test_index_filter(): - from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + from aiXplain.aixplain.modules.model.index_models.index_model import IndexFilter, IndexFilterOperator filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS) assert filter.field == "category" From f4431b67aaefa59f3828bcbdb8c6470be9067cc7 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Mon, 2 Jun 2025 23:20:51 +0000 Subject: [PATCH 02/10] add unit tests - index_model --- tests/unit/index_model_test.py | 296 ++++++++++++++++----------------- 1 file changed, 146 insertions(+), 150 deletions(-) diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index f226e763..f003f013 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,224 +1,220 @@ +""" +Covers +────── +• VectorIndexModel +• KnowledgeGraphIndexModel +• Record utilities +• IndexFilter dataclass +• IndexFactory negative-path rules +""" import requests_mock from aixplain.enums import DataType, Function, ResponseStatus, StorageType, EmbeddingModel from aixplain.factories.index_factory import IndexFactory +from aixplain.factories.index_factory.utils import AirParams from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aiXplain.aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.modules.model.index_models.vector_index_model import ( + VectorIndexModel, + IndexFilter, + IndexFilterOperator, +) +from aixplain.modules.model.index_models.knowledge_graph_index_model import ( + KnowledgeGraphIndexModel, +) from aixplain.utils import config import logging import pytest -data = {"data": "Model Index", "description": "This is a dummy collection for testing."} -index_id = "id" -execute_url = f"{config.MODELS_RUN_URL}/{index_id}".replace("/api/v1/execute", "/api/v2/execute") +logging.basicConfig( + format="%(levelname)s • %(name)s • %(message)s", + level=logging.DEBUG, +) +# ──────────────────────────────────────────────────────────────────────────────── +# VECTOR-BASED INDEX TESTS +# ──────────────────────────────────────────────────────────────────────────────── -def test_text_search_success(mocker): +VEC_ID = "vec-id" +VEC_EXEC_URL = f"{config.MODELS_RUN_URL}/{VEC_ID}".replace("/api/v1/execute", "/api/v2/execute") +logger = logging.getLogger("VectorIndexModelTests") + + +def _make_vec(): + return VectorIndexModel( + id=VEC_ID, + name="vec-name", + description="", + function=Function.SEARCH, + embedding_model=EmbeddingModel.OPENAI_ADA002, + ) + + +def test_vector_text_search(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.TEXT) - mock_response = {"status": "SUCCESS"} - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.search("test query") + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + logger.debug("POST %s • payload={status:SUCCESS}", VEC_EXEC_URL) + resp = _make_vec().search("hello world") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + assert isinstance(resp, ModelResponse) + assert resp.status == ResponseStatus.SUCCESS -def test_image_search_success(mocker): +def test_vector_image_search(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - - mock_response = {"status": "SUCCESS"} - - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel( - id=index_id, - data=data, - name="name", - function=Function.SEARCH, - embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, - ) - response = index_model.search("test.jpg") + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://ex.com/img.jpg") + + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + logger.debug("POST %s • payload={status:SUCCESS}", VEC_EXEC_URL) + resp = _make_vec().search("img.jpg") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + assert resp.status == ResponseStatus.SUCCESS -def test_text_add_success(mocker): +def test_vector_upsert_text(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) - mock_response = {"status": "SUCCESS"} - mock_documents = [ - Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={}), - Record(value="Sample document content 2", value_type="text", id=1, uri="", attributes={}), + docs = [ + Record(value="doc1", value_type="text", id=1, uri="", attributes={}), + Record(value="doc2", value_type="text", id=2, uri="", attributes={}), ] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + resp = _make_vec().upsert(docs) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == [d.to_dict() for d in docs] - response = index_model.upsert(mock_documents) - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_vector_index_filter(): + flt = IndexFilter("category", "news", IndexFilterOperator.EQUALS) + assert flt.to_dict() == {"field": "category", "value": "news", "operator": "=="} -def test_image_add_success(mocker): - mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.FILE] * 4) - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - mock_response = {"status": "SUCCESS"} +# ──────────────────────────────────────────────────────────────────────────────── +# KNOWLEDGE-GRAPH INDEX TESTS +# ──────────────────────────────────────────────────────────────────────────────── - mock_documents = [ - Record(uri="https://example.com/test.jpg", value_type="image", id=0, attributes={}), - ] +KG_ID = "kg-id" +KG_EXEC_URL = f"{config.MODELS_RUN_URL}/{KG_ID}".replace("/api/v1/execute", "/api/v2/execute") - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.upsert(mock_documents) +kg_logger = logging.getLogger("KGIndexModelTests") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def _make_kg(): + return KnowledgeGraphIndexModel( + id=KG_ID, + name="kg-name", + description="", + function=Function.SEARCH, + llm="gpt-4o", + ) -def test_text_update_success(mocker): - mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) - mock_response = {"status": "SUCCESS"} - mock_documents = [ - Record(value="Updated document content 1", value_type="text", id=0, uri="", attributes={}), - Record(value="Updated document content 2", value_type="text", id=1, uri="", attributes={}), - ] +def _kg_mock(m, payload): + m.post(KG_EXEC_URL, json=payload, status_code=200) + kg_logger.debug("POST %s • payload=%s", KG_EXEC_URL, payload) - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - logging.debug(f"Requesting URL: {execute_url}") - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) +def test_kg_get_prompts(): + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS", "data": {"sys": "hi"}}) + resp = _make_kg().get_prompts() - response = index_model.upsert(mock_documents) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == {"sys": "hi"} - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_kg_add_documents(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) -def test_count_success(): - mock_response = {"status": "SUCCESS", "data": 4} + docs = [Record(value="kg-doc", value_type="text", id=0, uri="", attributes={})] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - logging.debug(f"Requesting URL: {execute_url}") + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS"}) + resp = _make_kg().add_documents(docs) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == docs - response = index_model.count() - assert isinstance(response, int) - assert response == 4 +def test_kg_manual_prompt_tune(): + prompts = {"sys": "You are helpful"} + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS"}) + resp = _make_kg().manual_prompt_tune(prompts) -def test_get_document_success(): - mock_response = { - "status": "SUCCESS", - "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}, - } - mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - index_model.upsert(mock_documents) - response = index_model.get_record(0) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == prompts - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_kg_auto_prompt_tune(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) -def test_delete_document_success(): - mock_response = {"status": "SUCCESS"} - mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] + docs = [Record(value="auto", value_type="text", id=1, uri="", attributes={})] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - index_model.upsert(mock_documents) - response = index_model.delete_record("0") + with requests_mock.Mocker() as m: + # upload_documents + _kg_mock(m, {"status": "SUCCESS"}) + # auto_prompt_tune + _kg_mock(m, {"status": "SUCCESS", "data": {"sys": "ok"}}) - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + resp = _make_kg().auto_prompt_tune(docs) + assert resp == {"sys": "ok"} -def test_validate_record_success(mocker): - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - record = Record(uri="test.jpg", value_type="image", id=0, attributes={}) - record.validate() - assert record.value_type == DataType.IMAGE - assert record.uri == "https://example.com/test.jpg" - assert record.value == "" +def test_kg_graph_indexing(): + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS", "data": "started"}) + resp = _make_kg().graph_indexing() + assert resp.status == ResponseStatus.SUCCESS -def test_validate_record_failure(mocker): - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=False) - mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - record = Record(uri="test.mov", value_type="video", id=0, attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: Invalid value type" +def test_kg_upsert_full_roundtrip(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) + + docs = [Record(value="round", value_type="text", id=7, uri="", attributes={})] -def test_validate_record_failure_no_uri(mocker): - record = Record(value="test.jpg", value_type="image", id=0, uri="", attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: URI is required for image records" + with requests_mock.Mocker() as m: + # upload_documents + _kg_mock(m, {"status": "SUCCESS"}) + # ingest + _kg_mock(m, {"status": "SUCCESS", "data": "done"}) + resp = _make_kg().upsert(docs) -def test_validate_record_failure_no_value(mocker): - record = Record(uri="test.jpg", value_type="text", id=0, attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: Value is required for text records" + assert resp.status == ResponseStatus.SUCCESS -def test_record_to_dict(): - record = Record(value="test", value_type=DataType.TEXT, id=0, uri="", attributes={}) - record_dict = record.to_dict() - assert record_dict["dataType"] == "text" - assert record_dict["uri"] == "" - assert record_dict["data"] == "test" - assert record_dict["document_id"] == 0 - assert record_dict["attributes"] == {} +# ──────────────────────────────────────────────────────────────────────────────── +# RECORD UTILITY TESTS +# ──────────────────────────────────────────────────────────────────────────────── +def test_record_validate_and_dict(mocker): + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://ex.com/img.jpg") - record = Record(value="test", value_type=DataType.IMAGE, id=0, uri="https://example.com/test.jpg", attributes={}) - record_dict = record.to_dict() - assert record_dict["dataType"] == "image" - assert record_dict["uri"] == "https://example.com/test.jpg" - assert record_dict["data"] == "test" - assert record_dict["document_id"] == 0 - assert record_dict["attributes"] == {} + rec = Record(uri="img.jpg", value_type="image", id=0, attributes={}) + rec.validate() + d = rec.to_dict() + assert d["dataType"] == DataType.IMAGE + assert d["uri"] == "https://ex.com/img.jpg" -def test_index_filter(): - from aiXplain.aixplain.modules.model.index_models.index_model import IndexFilter, IndexFilterOperator - filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS) - assert filter.field == "category" - assert filter.value == "world" - assert filter.operator == IndexFilterOperator.EQUALS +# ──────────────────────────────────────────────────────────────────────────────── +# INDEX FACTORY TESTS +# ──────────────────────────────────────────────────────────────────────────────── def test_index_factory_create_failure(): - from aixplain.factories.index_factory.utils import AirParams - with pytest.raises(Exception) as e: IndexFactory.create( name="test", From 56ee0d79261f55970b03c531b367761022411183 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Tue, 3 Jun 2025 00:01:53 +0000 Subject: [PATCH 03/10] index rebase --- .../modules/model/index_models/index_model.py | 34 +++++++++++++++++-- .../knowledge_graph_index_model.py | 7 ++-- tests/functional/model/run_model_test.py | 5 +-- tests/unit/index_model_test.py | 2 +- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/aixplain/modules/model/index_models/index_model.py b/aixplain/modules/model/index_models/index_model.py index 876af611..093de3c1 100644 --- a/aixplain/modules/model/index_models/index_model.py +++ b/aixplain/modules/model/index_models/index_model.py @@ -4,6 +4,21 @@ from aixplain.utils import config from aixplain.modules.model.response import ModelResponse from aixplain.modules.model.record import Record +from aixplain.enums.splitting_options import SplittingOptions + + +class Splitter: + def __init__( + self, + split: bool = False, + split_by: SplittingOptions = SplittingOptions.WORD, + split_length: int = 1, + split_overlap: int = 0, + ): + self.split = split + self.split_by = split_by + self.split_length = split_length + self.split_overlap = split_overlap class IndexModel(Model): @@ -94,17 +109,20 @@ def search(self, query: str, top_k: int = 10) -> ModelResponse: } return self.run(data=data) - def upsert(self, documents: List[Record]) -> ModelResponse: + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: """Upsert documents into the index Args: documents (List[Record]): List of documents to be upserted + splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: ModelResponse: Response from the indexing service Example: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + Splitter in the above example is optional and can be used to split the documents into smaller chunks. """ # Validate documents for doc in documents: @@ -112,7 +130,19 @@ def upsert(self, documents: List[Record]) -> ModelResponse: # Convert documents to payloads payloads = [doc.to_dict() for doc in documents] # Build payload - data = {"action": "ingest", "data": payloads} + data = { + "action": "ingest", + "data": payloads, + } + if splitter and splitter.split: + data["additional_params"] = { + "splitter": { + "split": splitter.split, + "split_by": splitter.split_by, + "split_length": splitter.split_length, + "split_overlap": splitter.split_overlap, + } + } # Run the indexing service response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py index 534c8381..2b99ab74 100644 --- a/aixplain/modules/model/index_models/knowledge_graph_index_model.py +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -1,4 +1,4 @@ -from aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.modules.model.index_models.index_model import IndexModel, Splitter from aixplain.enums import ResponseStatus from typing import Text, Optional, Union, Dict, List, Any from aixplain.modules.model.record import Record @@ -87,17 +87,20 @@ def graph_indexing(self): raise Exception(f"Failed to run indexing: {response.error_message}") # Add current documents to the storage and start the indexing process. - def upsert(self, documents: List[Record]) -> ModelResponse: + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: """Upsert documents into the index Args: documents (List[Record]): List of documents to be upserted + splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: ModelResponse: Response from the indexing service Example: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + Splitter in the above example is optional and can be used to split the documents into smaller chunks. """ self.add_documents(documents) return self.graph_indexing() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 61e59c0c..4851e1c1 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -92,7 +92,6 @@ def run_index_model(index_model, retries): except Exception as e: time.sleep(180) - response = index_model.search("Berlin") assert str(response.status) == "SUCCESS" assert "germany" in response.data.lower() @@ -130,7 +129,6 @@ def test_index_model(embedding_model, supplier_params): run_index_model(index_model, retries) - @pytest.mark.parametrize( "embedding_model,supplier_params", [ @@ -146,7 +144,7 @@ def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory - from aiXplain.aixplain.modules.model.index_models.index_model import IndexFilter, IndexFilterOperator + from aixplain.modules.model.index_models.vector_index_model import IndexFilter, IndexFilterOperator for index in IndexFactory.list()["results"]: index.delete() @@ -183,7 +181,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): except Exception: time.sleep(180) - assert index_model.count() == 2 response = index_model.search( "", filters=[IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)] diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index f003f013..7b942afb 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -241,7 +241,7 @@ def test_index_factory_create_failure(): def test_index_model_splitter(): - from aixplain.modules.model.index_model import Splitter + from aixplain.modules.model.index_models.index_model import Splitter splitter = Splitter(split=True, split_by="sentence", split_length=100, split_overlap=0) assert splitter.split == True From 5f1e45f89c9d9d2ce36d2a8d40d9570bedc3126f Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Tue, 3 Jun 2025 00:10:16 +0000 Subject: [PATCH 04/10] index rebase --- .../modules/model/index_models/index_model.py | 46 +----------------- .../knowledge_graph_index_model.py | 7 +-- .../model/index_models/vector_index_model.py | 48 +++++++++++++++++-- 3 files changed, 49 insertions(+), 52 deletions(-) diff --git a/aixplain/modules/model/index_models/index_model.py b/aixplain/modules/model/index_models/index_model.py index 093de3c1..8af7204e 100644 --- a/aixplain/modules/model/index_models/index_model.py +++ b/aixplain/modules/model/index_models/index_model.py @@ -1,9 +1,8 @@ -from typing import Text, Optional, Union, Dict, List -from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus +from typing import Text, Optional, Union, Dict +from aixplain.enums import EmbeddingModel, Function, Supplier from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse -from aixplain.modules.model.record import Record from aixplain.enums.splitting_options import SplittingOptions @@ -109,47 +108,6 @@ def search(self, query: str, top_k: int = 10) -> ModelResponse: } return self.run(data=data) - def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: - """Upsert documents into the index - - Args: - documents (List[Record]): List of documents to be upserted - splitter (Splitter, optional): Splitter to be applied. Defaults to None. - - Returns: - ModelResponse: Response from the indexing service - - Example: - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) - Splitter in the above example is optional and can be used to split the documents into smaller chunks. - """ - # Validate documents - for doc in documents: - doc.validate() - # Convert documents to payloads - payloads = [doc.to_dict() for doc in documents] - # Build payload - data = { - "action": "ingest", - "data": payloads, - } - if splitter and splitter.split: - data["additional_params"] = { - "splitter": { - "split": splitter.split, - "split_by": splitter.split_by, - "split_length": splitter.split_length, - "split_overlap": splitter.split_overlap, - } - } - # Run the indexing service - response = self.run(data=data) - if response.status == ResponseStatus.SUCCESS: - response.data = payloads - return response - raise Exception(f"Failed to upsert documents: {response.error_message}") - def get_record(self, record_id: Text) -> ModelResponse: """ Get a document from the index. diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py index 2b99ab74..534c8381 100644 --- a/aixplain/modules/model/index_models/knowledge_graph_index_model.py +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -1,4 +1,4 @@ -from aixplain.modules.model.index_models.index_model import IndexModel, Splitter +from aixplain.modules.model.index_models.index_model import IndexModel from aixplain.enums import ResponseStatus from typing import Text, Optional, Union, Dict, List, Any from aixplain.modules.model.record import Record @@ -87,20 +87,17 @@ def graph_indexing(self): raise Exception(f"Failed to run indexing: {response.error_message}") # Add current documents to the storage and start the indexing process. - def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + def upsert(self, documents: List[Record]) -> ModelResponse: """Upsert documents into the index Args: documents (List[Record]): List of documents to be upserted - splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: ModelResponse: Response from the indexing service Example: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) - Splitter in the above example is optional and can be used to split the documents into smaller chunks. """ self.add_documents(documents) return self.graph_indexing() diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py index 04ac829d..f7970c9c 100644 --- a/aixplain/modules/model/index_models/vector_index_model.py +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -1,9 +1,10 @@ from typing import Union, List, Text, Optional, Dict from enum import Enum from aixplain.enums import StorageType -from aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.modules.model.index_models.index_model import IndexModel, Splitter +from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.enums import Function, Supplier, EmbeddingModel +from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus class IndexFilterOperator(Enum): @@ -50,7 +51,7 @@ def __init__( is_subscribed: bool = False, cost: Optional[Dict] = None, embedding_model: Optional[EmbeddingModel] = None, - **additional_info + **additional_info, ): super().__init__( id, name, description, api_key, supplier, version, function, is_subscribed, cost, embedding_model, **additional_info @@ -88,3 +89,44 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, } return self.run(data=data) + + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + """Upsert documents into the index + + Args: + documents (List[Record]): List of documents to be upserted + splitter (Splitter, optional): Splitter to be applied. Defaults to None. + + Returns: + ModelResponse: Response from the indexing service + + Example: + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + Splitter in the above example is optional and can be used to split the documents into smaller chunks. + """ + # Validate documents + for doc in documents: + doc.validate() + # Convert documents to payloads + payloads = [doc.to_dict() for doc in documents] + # Build payload + data = { + "action": "ingest", + "data": payloads, + } + if splitter and splitter.split: + data["additional_params"] = { + "splitter": { + "split": splitter.split, + "split_by": splitter.split_by, + "split_length": splitter.split_length, + "split_overlap": splitter.split_overlap, + } + } + # Run the indexing service + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = payloads + return response + raise Exception(f"Failed to upsert documents: {response.error_message}") From 35444fe991dc5c84fc99d3555864353b0208804c Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Tue, 3 Jun 2025 00:13:59 +0000 Subject: [PATCH 05/10] index rebase --- tests/functional/model/run_model_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 4851e1c1..e5ab758e 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -160,9 +160,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): retries = 1 for _ in range(retries): try: - index_model.upsert( - [Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})] - ) index_model.upsert( [Record(value="Hello, aiXplain!", value_type="text", uri="", id="1", attributes={"category": "hello"})] ) @@ -171,9 +168,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): time.sleep(180) for _ in range(retries): try: - index_model.upsert( - [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] - ) index_model.upsert( [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] ) From fdbec4d919f2d512ec8b77b5a452c39528d80ab8 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Tue, 10 Jun 2025 11:04:55 +0000 Subject: [PATCH 06/10] add enums, index type, bug fix --- aixplain/enums/__init__.py | 2 ++ aixplain/enums/index_type.py | 6 ++++++ aixplain/factories/index_factory/__init__.py | 15 +++------------ aixplain/modules/model/__init__.py | 2 +- .../modules/model/index_models/index_model.py | 10 ++++++---- .../index_models/knowledge_graph_index_model.py | 17 ++++++++++++++--- .../model/index_models/vector_index_model.py | 17 ++++++++++++++--- 7 files changed, 46 insertions(+), 23 deletions(-) create mode 100644 aixplain/enums/index_type.py diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 725fdb90..4b93f7c4 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -20,3 +20,5 @@ from .asset_status import AssetStatus from .index_stores import IndexStores from .function_type import FunctionType +from .splitting_options import SplittingOptions +from .index_type import IndexType diff --git a/aixplain/enums/index_type.py b/aixplain/enums/index_type.py new file mode 100644 index 00000000..125578c2 --- /dev/null +++ b/aixplain/enums/index_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class IndexType(str, Enum): + VECTOR = "vector" + KNOWLEDGE_GRAPH = "knowledge_graph" diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 21e45e00..62955dee 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -25,16 +25,11 @@ from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic -from aixplain.factories.index_factory.utils import BaseIndexParams +from aixplain.factories.index_factory.utils import BaseIndexParams, AirParams T = TypeVar("T", bound=BaseIndexParams) -def validate_embedding_model(model_id) -> bool: - model = ModelFactory.get(model_id) - return model.function == Function.TEXT_EMBEDDING - - class IndexFactory(ModelFactory, Generic[T]): @classmethod def create( @@ -64,12 +59,8 @@ def create( assert ( name is not None and description is not None and embedding_model is not None ), "Index Factory Exception: name, description, and embedding_model must be provided when params is not" - if validate_embedding_model(embedding_model): - data = { - "data": name, - "description": description, - "model": embedding_model, - } + params = AirParams(name=name, description=description, embedding_model=embedding_model) + data = params.to_dict() model = cls.get(model_id) response = model.run(data=data) if response.status == ResponseStatus.SUCCESS: diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index f20d4e20..45b65c78 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -136,7 +136,7 @@ def to_dict(self) -> Dict: "additional_info": clean_additional_info, "input_params": self.input_params, "output_params": self.output_params, - "model_params": self.model_params.to_dict(), + "model_params": self.model_params.to_dict() if self.model_params else None, "function": self.function, "status": self.status, } diff --git a/aixplain/modules/model/index_models/index_model.py b/aixplain/modules/model/index_models/index_model.py index 8af7204e..4f4d6403 100644 --- a/aixplain/modules/model/index_models/index_model.py +++ b/aixplain/modules/model/index_models/index_model.py @@ -3,7 +3,7 @@ from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse -from aixplain.enums.splitting_options import SplittingOptions +from aixplain.enums import IndexType, SplittingOptions class Splitter: @@ -25,14 +25,15 @@ def __init__( self, id: Text, name: Text, + version: Text, description: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", - version: Optional[Text] = None, function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - embedding_model: Optional[EmbeddingModel] = None, + embedding_model: Union[EmbeddingModel, Text] = None, + index_type: Optional[IndexType] = None, **additional_info, ) -> None: """Index Init @@ -47,7 +48,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. - embedding_model (EmbeddingModel, optional): embedding model. Defaults to None. + embedding_model (Union[EmbeddingModel, Text], optional): embedding model. Defaults to None. **additional_info: Any additional Model info to be saved """ assert function == Function.SEARCH, "Index only supports search function" @@ -66,6 +67,7 @@ def __init__( self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.embedding_model = embedding_model + self.index_type = index_type self.embedding_size = None if embedding_model: try: diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py index 534c8381..fbe86893 100644 --- a/aixplain/modules/model/index_models/knowledge_graph_index_model.py +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -3,7 +3,7 @@ from typing import Text, Optional, Union, Dict, List, Any from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.enums import Function, Supplier, EmbeddingModel +from aixplain.enums import Function, Supplier, EmbeddingModel, IndexType class KnowledgeGraphIndexModel(IndexModel): @@ -13,10 +13,10 @@ def __init__( self, id: Text, name: Text, + version: Text, description: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", - version: Optional[Text] = None, function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, @@ -25,7 +25,18 @@ def __init__( **additional_info, ): super().__init__( - id, name, description, api_key, supplier, version, function, is_subscribed, cost, embedding_model, **additional_info + id, + name, + version, + description, + api_key, + supplier, + function, + is_subscribed, + cost, + embedding_model, + IndexType.KNOWLEDGE_GRAPH, + **additional_info, ) self.llm = llm diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py index f7970c9c..4ae311e3 100644 --- a/aixplain/modules/model/index_models/vector_index_model.py +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -4,7 +4,7 @@ from aixplain.modules.model.index_models.index_model import IndexModel, Splitter from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus +from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus, IndexType class IndexFilterOperator(Enum): @@ -43,10 +43,10 @@ def __init__( self, id: Text, name: Text, + version: Text, description: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", - version: Optional[Text] = None, function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, @@ -54,7 +54,18 @@ def __init__( **additional_info, ): super().__init__( - id, name, description, api_key, supplier, version, function, is_subscribed, cost, embedding_model, **additional_info + id, + name, + version, + description, + api_key, + supplier, + function, + is_subscribed, + cost, + embedding_model, + IndexType.VECTOR, + **additional_info, ) def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: From 3fb4916ce35d86a6cd7551dc9a598feab2f1870d Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Wed, 11 Jun 2025 12:00:07 +0000 Subject: [PATCH 07/10] add base index model --- aixplain/factories/index_factory/__init__.py | 2 +- aixplain/factories/model_factory/utils.py | 4 ++-- aixplain/modules/__init__.py | 2 +- aixplain/modules/model/index_models/__init__.py | 4 ++-- .../index_models/{index_model.py => base_index_model.py} | 2 +- .../modules/model/index_models/knowledge_graph_index_model.py | 4 ++-- aixplain/modules/model/index_models/vector_index_model.py | 4 ++-- 7 files changed, 11 insertions(+), 11 deletions(-) rename aixplain/modules/model/index_models/{index_model.py => base_index_model.py} (99%) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 62955dee..3b1c4654 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -21,7 +21,7 @@ Index Factory Class """ -from aixplain.modules.model.index_models.index_model import IndexModel +from aixplain.modules.model.index_models import BaseIndexModel from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index b78b894a..ba2480ca 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -2,7 +2,7 @@ import logging from aixplain.modules.model import Model from aixplain.modules.model.llm_model import LLM -from aixplain.modules.model.index_models import IndexModel, VectorIndexModel, KnowledgeGraphIndexModel +from aixplain.modules.model.index_models import BaseIndexModel, VectorIndexModel, KnowledgeGraphIndexModel from aixplain.modules.model.integration import Integration from aixplain.modules.model.connection import ConnectionTool from aixplain.modules.model.utility_model import UtilityModel @@ -70,7 +70,7 @@ def create_model_from_response(response: Dict) -> Model: version = response.get("version", None) if version and version.get("id", None) is not None and "-" in version["id"]: collection_type = version["id"].split("-", 1)[0] - ModelClass = IndexModel + ModelClass = BaseIndexModel if collection_type in VectorIndexModel.supported_indices: ModelClass = VectorIndexModel elif collection_type in KnowledgeGraphIndexModel.supported_indices: diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index 5e8fe8cb..3fc8e0b3 100644 --- a/aixplain/modules/__init__.py +++ b/aixplain/modules/__init__.py @@ -37,4 +37,4 @@ from .agent.tool import Tool from .team_agent import TeamAgent from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit -from .model.index_models.index_model import IndexModel +from .model.index_models import BaseIndexModel diff --git a/aixplain/modules/model/index_models/__init__.py b/aixplain/modules/model/index_models/__init__.py index ff8d6921..46969de2 100644 --- a/aixplain/modules/model/index_models/__init__.py +++ b/aixplain/modules/model/index_models/__init__.py @@ -1,5 +1,5 @@ from .knowledge_graph_index_model import KnowledgeGraphIndexModel from .vector_index_model import VectorIndexModel, IndexFilter, IndexFilterOperator -from .index_model import IndexModel +from .base_index_model import BaseIndexModel -__all__ = ["KnowledgeGraphIndexModel", "VectorIndexModel", "IndexModel", "IndexFilter", "IndexFilterOperator"] +__all__ = ["KnowledgeGraphIndexModel", "VectorIndexModel", "BaseIndexModel", "IndexFilter", "IndexFilterOperator"] diff --git a/aixplain/modules/model/index_models/index_model.py b/aixplain/modules/model/index_models/base_index_model.py similarity index 99% rename from aixplain/modules/model/index_models/index_model.py rename to aixplain/modules/model/index_models/base_index_model.py index 4f4d6403..d864d612 100644 --- a/aixplain/modules/model/index_models/index_model.py +++ b/aixplain/modules/model/index_models/base_index_model.py @@ -20,7 +20,7 @@ def __init__( self.split_overlap = split_overlap -class IndexModel(Model): +class BaseIndexModel(Model): def __init__( self, id: Text, diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py index fbe86893..87178d9d 100644 --- a/aixplain/modules/model/index_models/knowledge_graph_index_model.py +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -1,4 +1,4 @@ -from aixplain.modules.model.index_models.index_model import IndexModel +from aiXplain.aixplain.modules.model.index_models.base_index_model import BaseIndexModel from aixplain.enums import ResponseStatus from typing import Text, Optional, Union, Dict, List, Any from aixplain.modules.model.record import Record @@ -6,7 +6,7 @@ from aixplain.enums import Function, Supplier, EmbeddingModel, IndexType -class KnowledgeGraphIndexModel(IndexModel): +class KnowledgeGraphIndexModel(BaseIndexModel): supported_indices = ["graphrag"] def __init__( diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py index 4ae311e3..c6759832 100644 --- a/aixplain/modules/model/index_models/vector_index_model.py +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -1,7 +1,7 @@ from typing import Union, List, Text, Optional, Dict from enum import Enum from aixplain.enums import StorageType -from aixplain.modules.model.index_models.index_model import IndexModel, Splitter +from aiXplain.aixplain.modules.model.index_models.base_index_model import BaseIndexModel, Splitter from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus, IndexType @@ -36,7 +36,7 @@ def to_dict(self): } -class VectorIndexModel(IndexModel): +class VectorIndexModel(BaseIndexModel): supported_indices = ["airv2", "vectara", "zeroentropy"] def __init__( From 8502bfed93d9fe66fe512b5cbb5cd2fb36d1204d Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Wed, 11 Jun 2025 13:45:31 +0000 Subject: [PATCH 08/10] Update index tests --- aixplain/factories/index_factory/__init__.py | 4 +- aixplain/factories/index_factory/utils.py | 13 ++++-- .../model/index_models/base_index_model.py | 3 +- .../knowledge_graph_index_model.py | 10 ++--- .../model/index_models/vector_index_model.py | 2 +- tests/functional/model/run_model_test.py | 42 +++++++++++++++++-- tests/unit/index_model_test.py | 16 +++---- 7 files changed, 64 insertions(+), 26 deletions(-) diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 3b1c4654..8da5d7f0 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -39,7 +39,7 @@ def create( embedding_model: Union[EmbeddingModel, str] = EmbeddingModel.OPENAI_ADA002, params: Optional[T] = None, **kwargs, - ) -> IndexModel: + ) -> BaseIndexModel: """Create a new index collection""" import warnings @@ -83,7 +83,7 @@ def list( sort_order: SortOrder = SortOrder.ASCENDING, page_number: int = 0, page_size: int = 20, - ) -> List[IndexModel]: + ) -> List[BaseIndexModel]: """List all indexes""" return super().list( function=Function.SEARCH, diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index b09df3c0..a8169f17 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -80,7 +80,12 @@ class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): def id(self) -> str: return self._id - - - - + @field_validator('llm') + def validate_llm(cls, llm) -> Optional[Text]: + if llm is None: + return None + model = ModelFactory.get(llm) + if model.function == Function.TEXT_GENERATION: + return llm + else: + raise ValueError("This is not an LLM model") diff --git a/aixplain/modules/model/index_models/base_index_model.py b/aixplain/modules/model/index_models/base_index_model.py index d864d612..53e57094 100644 --- a/aixplain/modules/model/index_models/base_index_model.py +++ b/aixplain/modules/model/index_models/base_index_model.py @@ -158,7 +158,8 @@ def count(self) -> float: data = {"action": "count", "data": ""} response = self.run(data=data) if response.status == "SUCCESS": - return int(response.data) + n_indexed = response.data.get("data", response.data) + return int(n_indexed) raise Exception(f"Failed to count documents: {response.error_message}") @staticmethod diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py index 87178d9d..d9d1d4bf 100644 --- a/aixplain/modules/model/index_models/knowledge_graph_index_model.py +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -1,4 +1,4 @@ -from aiXplain.aixplain.modules.model.index_models.base_index_model import BaseIndexModel +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel from aixplain.enums import ResponseStatus from typing import Text, Optional, Union, Dict, List, Any from aixplain.modules.model.record import Record @@ -69,11 +69,12 @@ def add_documents(self, documents: List[Record]): raise Exception(f"Failed to add documents: {response.error_message}") # Run prompt auto tuning, while also adding new documents to the storage if provided. - def auto_prompt_tune(self, documents: List[Record]) -> Dict[str, str]: + def auto_prompt_tune(self, documents: List[Record] = None) -> Dict[str, str]: # TODO: Check if any files are already uploaded. If none and documents is also None, then raise an error. - self.add_documents(documents) + if documents: + self.add_documents(documents) # Build payload - data = {"action": "auto_prompt_tune", "data": ""} + data = {"action": "auto_prompt_tune", "data": []} # Run the indexing service response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: @@ -93,7 +94,6 @@ def graph_indexing(self): data = {"action": "ingest", "data": []} response = self.run(data=data) if response.status == ResponseStatus.SUCCESS: - response.data = response.data return response raise Exception(f"Failed to run indexing: {response.error_message}") diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py index c6759832..839acba0 100644 --- a/aixplain/modules/model/index_models/vector_index_model.py +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -1,7 +1,7 @@ from typing import Union, List, Text, Optional, Dict from enum import Enum from aixplain.enums import StorageType -from aiXplain.aixplain.modules.model.index_models.base_index_model import BaseIndexModel, Splitter +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel, Splitter from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus, IndexType diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index e5ab758e..57e7bf33 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -105,7 +105,6 @@ def run_index_model(index_model, retries): [ pytest.param(None, VectaraParams, id="VECTARA"), pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), - pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), @@ -144,7 +143,7 @@ def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory - from aixplain.modules.model.index_models.vector_index_model import IndexFilter, IndexFilterOperator + from aixplain.modules.model.index_models import IndexFilter, IndexFilterOperator for index in IndexFactory.list()["results"]: index.delete() @@ -185,6 +184,43 @@ def test_index_model_with_filter(embedding_model, supplier_params): index_model.delete() +def test_knowledge_graph_index_mode_prompt_tune(): + from uuid import uuid4 + from aixplain.modules.model.record import Record + from aixplain.factories import IndexFactory + + for index in IndexFactory.list()["results"]: + index.delete() + + params = GraphRAGParams(name=str(uuid4()), description=str(uuid4()), llm="6646261c6eb563165658bbb1") + index_model = IndexFactory.create(params=params) + add_response = index_model.add_documents( + [Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})] + ) + response = index_model.auto_prompt_tune() + assert index_model.count() == 0 + assert str(response.status) == "SUCCESS" + index_model.delete() + + +def test_knowledge_graph_index_model(): + from aixplain.factories import IndexFactory + from aixplain.modules.model.record import Record + from uuid import uuid4 + + for index in IndexFactory.list()["results"]: + index.delete() + + params = GraphRAGParams(name=str(uuid4()), description=str(uuid4()), llm="6646261c6eb563165658bbb1") + index_model = IndexFactory.create(params=params) + index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Berlin") + assert str(response.status) == "SUCCESS" + assert "germany" in response.data.lower() + assert index_model.count() == 1 + index_model.delete() + + def test_llm_run_with_file(): """Testing LLM with local file input containing emoji""" @@ -305,7 +341,7 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record from uuid import uuid4 - from aixplain.modules.model.index_models.index_model import Splitter + from aixplain.modules.model.index_models import Splitter from aixplain.enums.splitting_options import SplittingOptions for index in IndexFactory.list()["results"]: diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 7b942afb..034c29f5 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -13,17 +13,11 @@ from aixplain.factories.index_factory.utils import AirParams from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.modules.model.index_models.vector_index_model import ( - VectorIndexModel, - IndexFilter, - IndexFilterOperator, -) -from aixplain.modules.model.index_models.knowledge_graph_index_model import ( - KnowledgeGraphIndexModel, -) +from aixplain.modules.model.index_models import VectorIndexModel, IndexFilter, IndexFilterOperator, KnowledgeGraphIndexModel from aixplain.utils import config import logging import pytest +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel logging.basicConfig( format="%(levelname)s • %(name)s • %(message)s", @@ -43,6 +37,7 @@ def _make_vec(): return VectorIndexModel( id=VEC_ID, name="vec-name", + version="airv2-dev-1-test", description="", function=Function.SEARCH, embedding_model=EmbeddingModel.OPENAI_ADA002, @@ -109,6 +104,7 @@ def _make_kg(): return KnowledgeGraphIndexModel( id=KG_ID, name="kg-name", + version="graphrag-dev-1-test", description="", function=Function.SEARCH, llm="gpt-4o", @@ -177,7 +173,7 @@ def test_kg_graph_indexing(): assert resp.status == ResponseStatus.SUCCESS -def test_kg_upsert_full_roundtrip(mocker): +def test_kg_upsert(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) docs = [Record(value="round", value_type="text", id=7, uri="", attributes={})] @@ -241,7 +237,7 @@ def test_index_factory_create_failure(): def test_index_model_splitter(): - from aixplain.modules.model.index_models.index_model import Splitter + from aixplain.modules.model.index_models.base_index_model import Splitter splitter = Splitter(split=True, split_by="sentence", split_length=100, split_overlap=0) assert splitter.split == True From 4520cd6be82ecfa7c8ec3dca9226b857e79131f3 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Wed, 11 Jun 2025 13:46:58 +0000 Subject: [PATCH 09/10] add llm validation --- aixplain/factories/index_factory/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index a8169f17..bce5ba2b 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -27,7 +27,7 @@ class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): embedding_model: Optional[Union[EmbeddingModel, str]] = EmbeddingModel.OPENAI_ADA002 embedding_size: Optional[int] = None - @field_validator('embedding_model') + @field_validator("embedding_model") def validate_embedding_model(cls, model_id) -> bool: model = ModelFactory.get(model_id) if model.function == Function.TEXT_EMBEDDING: @@ -35,7 +35,6 @@ def validate_embedding_model(cls, model_id) -> bool: else: raise ValueError("This is not an embedding model") - def to_dict(self): data = super().to_dict() data["model"] = data.pop("embedding_model") @@ -44,9 +43,6 @@ def to_dict(self): data["additional_params"] = {"embedding_size": data.pop("embedding_size")} return data - - - class VectaraParams(BaseIndexParams): _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() @@ -80,7 +76,7 @@ class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): def id(self) -> str: return self._id - @field_validator('llm') + @field_validator("llm") def validate_llm(cls, llm) -> Optional[Text]: if llm is None: return None From 52e50d9a66f640db882b48ebaf71118b04efca17 Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Wed, 11 Jun 2025 13:54:59 +0000 Subject: [PATCH 10/10] add upsert abstract method --- aixplain/modules/model/index_models/base_index_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aixplain/modules/model/index_models/base_index_model.py b/aixplain/modules/model/index_models/base_index_model.py index 53e57094..c817d32b 100644 --- a/aixplain/modules/model/index_models/base_index_model.py +++ b/aixplain/modules/model/index_models/base_index_model.py @@ -1,4 +1,5 @@ -from typing import Text, Optional, Union, Dict +from typing import Text, Optional, Union, Dict, List +from aixplain.modules.model.record import Record from aixplain.enums import EmbeddingModel, Function, Supplier from aixplain.modules.model import Model from aixplain.utils import config @@ -162,6 +163,9 @@ def count(self) -> float: return int(n_indexed) raise Exception(f"Failed to count documents: {response.error_message}") + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + raise NotImplementedError("Upsert is not implemented for this index model") + @staticmethod def parse_file(file_path: Text) -> ModelResponse: """