diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 725fdb90..4b93f7c4 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -20,3 +20,5 @@ from .asset_status import AssetStatus from .index_stores import IndexStores from .function_type import FunctionType +from .splitting_options import SplittingOptions +from .index_type import IndexType diff --git a/aixplain/enums/index_type.py b/aixplain/enums/index_type.py new file mode 100644 index 00000000..125578c2 --- /dev/null +++ b/aixplain/enums/index_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class IndexType(str, Enum): + VECTOR = "vector" + KNOWLEDGE_GRAPH = "knowledge_graph" diff --git a/aixplain/factories/index_factory/__init__.py b/aixplain/factories/index_factory/__init__.py index 189f9417..8da5d7f0 100644 --- a/aixplain/factories/index_factory/__init__.py +++ b/aixplain/factories/index_factory/__init__.py @@ -21,20 +21,15 @@ Index Factory Class """ -from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.index_models import BaseIndexModel from aixplain.factories import ModelFactory from aixplain.enums import Function, ResponseStatus, SortBy, SortOrder, OwnershipType, Supplier, IndexStores, EmbeddingModel from typing import Text, Union, List, Tuple, Optional, TypeVar, Generic -from aixplain.factories.index_factory.utils import BaseIndexParams +from aixplain.factories.index_factory.utils import BaseIndexParams, AirParams T = TypeVar("T", bound=BaseIndexParams) -def validate_embedding_model(model_id) -> bool: - model = ModelFactory.get(model_id) - return model.function == Function.TEXT_EMBEDDING - - class IndexFactory(ModelFactory, Generic[T]): @classmethod def create( @@ -44,7 +39,7 @@ def create( embedding_model: Union[EmbeddingModel, str] = EmbeddingModel.OPENAI_ADA002, params: Optional[T] = None, **kwargs, - ) -> IndexModel: + ) -> BaseIndexModel: """Create a new index collection""" import warnings @@ -64,12 +59,8 @@ def create( 
assert ( name is not None and description is not None and embedding_model is not None ), "Index Factory Exception: name, description, and embedding_model must be provided when params is not" - if validate_embedding_model(embedding_model): - data = { - "data": name, - "description": description, - "model": embedding_model, - } + params = AirParams(name=name, description=description, embedding_model=embedding_model) + data = params.to_dict() model = cls.get(model_id) response = model.run(data=data) if response.status == ResponseStatus.SUCCESS: @@ -92,7 +83,7 @@ def list( sort_order: SortOrder = SortOrder.ASCENDING, page_number: int = 0, page_size: int = 20, - ) -> List[IndexModel]: + ) -> List[BaseIndexModel]: """List all indexes""" return super().list( function=Function.SEARCH, diff --git a/aixplain/factories/index_factory/utils.py b/aixplain/factories/index_factory/utils.py index b09df3c0..bce5ba2b 100644 --- a/aixplain/factories/index_factory/utils.py +++ b/aixplain/factories/index_factory/utils.py @@ -27,7 +27,7 @@ class BaseIndexParamsWithEmbeddingModel(BaseIndexParams, ABC): embedding_model: Optional[Union[EmbeddingModel, str]] = EmbeddingModel.OPENAI_ADA002 embedding_size: Optional[int] = None - @field_validator('embedding_model') + @field_validator("embedding_model") def validate_embedding_model(cls, model_id) -> bool: model = ModelFactory.get(model_id) if model.function == Function.TEXT_EMBEDDING: @@ -35,7 +35,6 @@ def validate_embedding_model(cls, model_id) -> bool: else: raise ValueError("This is not an embedding model") - def to_dict(self): data = super().to_dict() data["model"] = data.pop("embedding_model") @@ -44,9 +43,6 @@ def to_dict(self): data["additional_params"] = {"embedding_size": data.pop("embedding_size")} return data - - - class VectaraParams(BaseIndexParams): _id: ClassVar[str] = IndexStores.VECTARA.get_model_id() @@ -80,7 +76,12 @@ class GraphRAGParams(BaseIndexParamsWithEmbeddingModel): def id(self) -> str: return self._id - - - - + 
@field_validator("llm") + def validate_llm(cls, llm) -> Optional[Text]: + if llm is None: + return None + model = ModelFactory.get(llm) + if model.function == Function.TEXT_GENERATION: + return llm + else: + raise ValueError("This is not an LLM model") diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index ca3e1eec..ba2480ca 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -2,7 +2,7 @@ import logging from aixplain.modules.model import Model from aixplain.modules.model.llm_model import LLM -from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.index_models import BaseIndexModel, VectorIndexModel, KnowledgeGraphIndexModel from aixplain.modules.model.integration import Integration from aixplain.modules.model.connection import ConnectionTool from aixplain.modules.model.utility_model import UtilityModel @@ -67,7 +67,15 @@ def create_model_from_response(response: Dict) -> Model: if len(f) > 0 and len(f[0].get("defaultValues", [])) > 0: temperature = float(f[0]["defaultValues"][0]["value"]) elif function == Function.SEARCH: - ModelClass = IndexModel + version = response.get("version", None) + if version and version.get("id", None) is not None and "-" in version["id"]: + collection_type = version["id"].split("-", 1)[0] + ModelClass = BaseIndexModel + if collection_type in VectorIndexModel.supported_indices: + ModelClass = VectorIndexModel + elif collection_type in KnowledgeGraphIndexModel.supported_indices: + ModelClass = KnowledgeGraphIndexModel + additional_kwargs["llm"] = next((item["code"] for item in attributes if item["name"] == "llm"), None) elif function_type == FunctionType.INTEGRATION: ModelClass = Integration elif function_type == FunctionType.CONNECTION: diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index f8a64650..3fc8e0b3 100644 --- a/aixplain/modules/__init__.py +++ 
b/aixplain/modules/__init__.py @@ -37,4 +37,4 @@ from .agent.tool import Tool from .team_agent import TeamAgent from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit -from .model.index_model import IndexModel +from .model.index_models import BaseIndexModel diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index f20d4e20..45b65c78 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -136,7 +136,7 @@ def to_dict(self) -> Dict: "additional_info": clean_additional_info, "input_params": self.input_params, "output_params": self.output_params, - "model_params": self.model_params.to_dict(), + "model_params": self.model_params.to_dict() if self.model_params else None, "function": self.function, "status": self.status, } diff --git a/aixplain/modules/model/index_models/__init__.py b/aixplain/modules/model/index_models/__init__.py new file mode 100644 index 00000000..46969de2 --- /dev/null +++ b/aixplain/modules/model/index_models/__init__.py @@ -0,0 +1,5 @@ +from .knowledge_graph_index_model import KnowledgeGraphIndexModel +from .vector_index_model import VectorIndexModel, IndexFilter, IndexFilterOperator +from .base_index_model import BaseIndexModel + +__all__ = ["KnowledgeGraphIndexModel", "VectorIndexModel", "BaseIndexModel", "IndexFilter", "IndexFilterOperator"] diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_models/base_index_model.py similarity index 56% rename from aixplain/modules/model/index_model.py rename to aixplain/modules/model/index_models/base_index_model.py index 04ed3df1..c817d32b 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_models/base_index_model.py @@ -1,41 +1,10 @@ -from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType, FunctionType +from typing import Text, Optional, Union, Dict, List +from aixplain.modules.model.record import Record +from aixplain.enums import 
EmbeddingModel, Function, Supplier from aixplain.modules.model import Model from aixplain.utils import config from aixplain.modules.model.response import ModelResponse -from typing import Text, Optional, Union, Dict -from aixplain.modules.model.record import Record -from enum import Enum -from typing import List -from aixplain.enums.splitting_options import SplittingOptions - - -class IndexFilterOperator(Enum): - EQUALS = "==" - NOT_EQUALS = "!=" - CONTAINS = "in" - NOT_CONTAINS = "not in" - GREATER_THAN = ">" - LESS_THAN = "<" - GREATER_THAN_OR_EQUALS = ">=" - LESS_THAN_OR_EQUALS = "<=" - - -class IndexFilter: - field: str - value: str - operator: Union[IndexFilterOperator, str] - - def __init__(self, field: str, value: str, operator: Union[IndexFilterOperator, str]): - self.field = field - self.value = value - self.operator = operator - - def to_dict(self): - return { - "field": self.field, - "value": self.value, - "operator": self.operator.value if isinstance(self.operator, IndexFilterOperator) else self.operator, - } +from aixplain.enums import IndexType, SplittingOptions class Splitter: @@ -52,20 +21,20 @@ def __init__( self.split_overlap = split_overlap -class IndexModel(Model): +class BaseIndexModel(Model): def __init__( self, id: Text, name: Text, + version: Text, description: Text = "", api_key: Optional[Text] = None, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", - version: Optional[Text] = None, function: Optional[Function] = None, is_subscribed: bool = False, cost: Optional[Dict] = None, - embedding_model: Union[EmbeddingModel, str] = None, - function_type: Optional[FunctionType] = FunctionType.SEARCH, + embedding_model: Union[EmbeddingModel, Text] = None, + index_type: Optional[IndexType] = None, **additional_info, ) -> None: """Index Init @@ -80,7 +49,7 @@ def __init__( function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. 
cost (Dict, optional): model price. Defaults to None. - embedding_model (Union[EmbeddingModel, str], optional): embedding model. Defaults to None. + embedding_model (Union[EmbeddingModel, Text], optional): embedding model. Defaults to None. **additional_info: Any additional Model info to be saved """ assert function == Function.SEARCH, "Index only supports search function" @@ -94,12 +63,13 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, - function_type=function_type, **additional_info, ) self.url = config.MODELS_RUN_URL self.backend_url = config.BACKEND_URL self.embedding_model = embedding_model + self.index_type = index_type + self.embedding_size = None if embedding_model: try: from aixplain.factories import ModelFactory @@ -110,7 +80,6 @@ def __init__( import warnings warnings.warn(f"Failed to get embedding size for embedding model {embedding_model}: {e}") - self.embedding_size = None def to_dict(self) -> Dict: data = super().to_dict() @@ -119,87 +88,29 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + def search(self, query: str, top_k: int = 10) -> ModelResponse: """Search for documents in the index Args: query (str): Query to be searched top_k (int, optional): Number of results to be returned. Defaults to 10. - filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. 
Returns: ModelResponse: Response from the indexing service Example: - index_model.search("Hello") - - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) """ - from aixplain.factories import FileFactory - - uri, value_type = "", "text" - storage_type = FileFactory.check_storage_type(query) - if storage_type in [StorageType.FILE, StorageType.URL]: - uri = FileFactory.to_link(query) - query = "" - value_type = "image" data = { "action": "search", - "data": query or uri, - "dataType": value_type, - "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + "data": query, + "dataType": "text", + "filters": [], + "payload": {"top_k": top_k}, } return self.run(data=data) - def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: - """Upsert documents into the index - - Args: - documents (List[Record]): List of documents to be upserted - splitter (Splitter, optional): Splitter to be applied. Defaults to None. - - Returns: - ModelResponse: Response from the indexing service - - Examples: - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) - index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) - Splitter in the above example is optional and can be used to split the documents into smaller chunks. 
- """ - # Validate documents - for doc in documents: - doc.validate() - # Convert documents to payloads - payloads = [doc.to_dict() for doc in documents] - # Build payload - data = { - "action": "ingest", - "data": payloads, - } - if splitter and splitter.split: - data["additional_params"] = { - "splitter": { - "split": splitter.split, - "split_by": splitter.split_by, - "split_length": splitter.split_length, - "split_overlap": splitter.split_overlap, - } - } - # Run the indexing service - response = self.run(data=data) - if response.status == ResponseStatus.SUCCESS: - response.data = payloads - return response - raise Exception(f"Failed to upsert documents: {response.error_message}") - - def count(self) -> float: - data = {"action": "count", "data": ""} - response = self.run(data=data) - if response.status == "SUCCESS": - return int(response.data) - raise Exception(f"Failed to count documents: {response.error_message}") - def get_record(self, record_id: Text) -> ModelResponse: """ Get a document from the index. @@ -243,3 +154,29 @@ def delete_record(self, record_id: Text) -> ModelResponse: if response.status == "SUCCESS": return response raise Exception(f"Failed to delete record: {response.error_message}") + + def count(self) -> float: + data = {"action": "count", "data": ""} + response = self.run(data=data) + if response.status == "SUCCESS": + n_indexed = response.data.get("data", response.data) + return int(n_indexed) + raise Exception(f"Failed to count documents: {response.error_message}") + + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + raise NotImplementedError("Upsert is not implemented for this index model") + + @staticmethod + def parse_file(file_path: Text) -> ModelResponse: + """ + Parse a file using the Docling model. 
+ """ + try: + from aixplain.factories import ModelFactory + + docling_model_id = "677bee6c6eb56331f9192a91" + model = ModelFactory.get(docling_model_id) + response = model.run(file_path) + return response + except Exception as e: + raise Exception(f"Failed to parse file: {e}") diff --git a/aixplain/modules/model/index_models/knowledge_graph_index_model.py b/aixplain/modules/model/index_models/knowledge_graph_index_model.py new file mode 100644 index 00000000..d9d1d4bf --- /dev/null +++ b/aixplain/modules/model/index_models/knowledge_graph_index_model.py @@ -0,0 +1,114 @@ +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel +from aixplain.enums import ResponseStatus +from typing import Text, Optional, Union, Dict, List, Any +from aixplain.modules.model.record import Record +from aixplain.modules.model.response import ModelResponse +from aixplain.enums import Function, Supplier, EmbeddingModel, IndexType + + +class KnowledgeGraphIndexModel(BaseIndexModel): + supported_indices = ["graphrag"] + + def __init__( + self, + id: Text, + name: Text, + version: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + embedding_model: Optional[EmbeddingModel] = None, + llm: Optional[Text] = None, + **additional_info, + ): + super().__init__( + id, + name, + version, + description, + api_key, + supplier, + function, + is_subscribed, + cost, + embedding_model, + IndexType.KNOWLEDGE_GRAPH, + **additional_info, + ) + self.llm = llm + + def to_dict(self) -> Dict[str, Any]: + data = super().to_dict() + data["llm"] = self.llm + return data + + def get_prompts(self) -> Dict[str, str]: + data = {"action": "get_prompts", "data": ""} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = response.data + return response + raise Exception(f"Failed to 
get prompts: {response.error_message}") + + # Add documents to the storage. Later can run indexing process incrementally based on the new documents added. + def add_documents(self, documents: List[Record]): + # Validate documents + for doc in documents: + doc.validate() + # Convert documents to payloads + payloads = [doc.to_dict() for doc in documents] + # Build payload + data = {"action": "upload_documents", "data": payloads} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = documents + return response + raise Exception(f"Failed to add documents: {response.error_message}") + + # Run prompt auto tuning, while also adding new documents to the storage if provided. + def auto_prompt_tune(self, documents: List[Record] = None) -> ModelResponse: + # TODO: Check if any files are already uploaded. If none and documents is also None, then raise an error. + if documents: + self.add_documents(documents) + # Build payload + data = {"action": "auto_prompt_tune", "data": []} + # Run the indexing service + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + return response + raise Exception(f"Failed to auto prompt tune: {response.error_message}") + + def manual_prompt_tune(self, prompts: Dict[str, str]) -> Dict[str, str]: + data = {"action": "manual_prompt_tune", "data": prompts} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = prompts + return response + raise Exception(f"Failed to prompt tune: {response.error_message}") + + # Start the indexing process based on files already uploaded. + def graph_indexing(self): + data = {"action": "ingest", "data": []} + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + return response + raise Exception(f"Failed to run indexing: {response.error_message}") + + # Add current documents to the storage and start the indexing process. 
+ def upsert(self, documents: List[Record]) -> ModelResponse: + """Upsert documents into the index + + Args: + documents (List[Record]): List of documents to be upserted + + Returns: + ModelResponse: Response from the indexing service + + Example: + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + """ + self.add_documents(documents) + return self.graph_indexing() diff --git a/aixplain/modules/model/index_models/vector_index_model.py b/aixplain/modules/model/index_models/vector_index_model.py new file mode 100644 index 00000000..839acba0 --- /dev/null +++ b/aixplain/modules/model/index_models/vector_index_model.py @@ -0,0 +1,143 @@ +from typing import Union, List, Text, Optional, Dict +from enum import Enum +from aixplain.enums import StorageType +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel, Splitter +from aixplain.modules.model.record import Record +from aixplain.modules.model.response import ModelResponse +from aixplain.enums import Function, Supplier, EmbeddingModel, ResponseStatus, IndexType + + +class IndexFilterOperator(Enum): + EQUALS = "==" + NOT_EQUALS = "!=" + CONTAINS = "in" + NOT_CONTAINS = "not in" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUALS = ">=" + LESS_THAN_OR_EQUALS = "<=" + + +class IndexFilter: + field: str + value: str + operator: Union[IndexFilterOperator, str] + + def __init__(self, field: str, value: str, operator: Union[IndexFilterOperator, str]): + self.field = field + self.value = value + self.operator = operator + + def to_dict(self): + return { + "field": self.field, + "value": self.value, + "operator": self.operator.value if isinstance(self.operator, IndexFilterOperator) else self.operator, + } + + +class VectorIndexModel(BaseIndexModel): + supported_indices = ["airv2", "vectara", "zeroentropy"] + + def __init__( + self, + id: Text, + name: Text, + version: Text, + description: Text = "", + api_key: Optional[Text] = None, + 
supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + embedding_model: Optional[EmbeddingModel] = None, + **additional_info, + ): + super().__init__( + id, + name, + version, + description, + api_key, + supplier, + function, + is_subscribed, + cost, + embedding_model, + IndexType.VECTOR, + **additional_info, + ) + + def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + """Search for documents in the index + + Args: + query (str): Query to be searched + top_k (int, optional): Number of results to be returned. Defaults to 10. + filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. + + Returns: + ModelResponse: Response from the indexing service + + Example: + - index_model.search("Hello") + - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) + """ + from aixplain.factories import FileFactory + + uri, value_type = "", "text" + storage_type = FileFactory.check_storage_type(query) + if storage_type in [StorageType.FILE, StorageType.URL]: + uri = FileFactory.to_link(query) + query = "" + value_type = "image" + + data = { + "action": "search", + "data": query or uri, + "dataType": value_type, + "filters": [filter.to_dict() for filter in filters], + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + } + return self.run(data=data) + + def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + """Upsert documents into the index + + Args: + documents (List[Record]): List of documents to be upserted + splitter (Splitter, optional): Splitter to be applied. Defaults to None. 
+ + Returns: + ModelResponse: Response from the indexing service + + Example: + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) + index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + Splitter in the above example is optional and can be used to split the documents into smaller chunks. + """ + # Validate documents + for doc in documents: + doc.validate() + # Convert documents to payloads + payloads = [doc.to_dict() for doc in documents] + # Build payload + data = { + "action": "ingest", + "data": payloads, + } + if splitter and splitter.split: + data["additional_params"] = { + "splitter": { + "split": splitter.split, + "split_by": splitter.split_by, + "split_length": splitter.split_length, + "split_overlap": splitter.split_overlap, + } + } + # Run the indexing service + response = self.run(data=data) + if response.status == ResponseStatus.SUCCESS: + response.data = payloads + return response + raise Exception(f"Failed to upsert documents: {response.error_message}") diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 71838034..57e7bf33 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -11,6 +11,7 @@ from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams import time import os +import json def pytest_generate_tests(metafunc): @@ -104,7 +105,6 @@ def run_index_model(index_model, retries): [ pytest.param(None, VectaraParams, id="VECTARA"), pytest.param(None, ZeroEntropyParams, id="ZERO_ENTROPY"), - pytest.param(EmbeddingModel.OPENAI_ADA002, GraphRAGParams, id="GRAPHRAG"), pytest.param(EmbeddingModel.OPENAI_ADA002, AirParams, id="AIR - OpenAI Ada 002"), pytest.param("6658d40729985c2cf72f42ec", AirParams, id="AIR - 
Snowflake Arctic Embed M Long"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), @@ -143,7 +143,7 @@ def test_index_model_with_filter(embedding_model, supplier_params): from uuid import uuid4 from aixplain.modules.model.record import Record from aixplain.factories import IndexFactory - from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + from aixplain.modules.model.index_models import IndexFilter, IndexFilterOperator for index in IndexFactory.list()["results"]: index.delete() @@ -184,6 +184,43 @@ def test_index_model_with_filter(embedding_model, supplier_params): index_model.delete() +def test_knowledge_graph_index_model_prompt_tune(): + from uuid import uuid4 + from aixplain.modules.model.record import Record + from aixplain.factories import IndexFactory + + for index in IndexFactory.list()["results"]: + index.delete() + + params = GraphRAGParams(name=str(uuid4()), description=str(uuid4()), llm="6646261c6eb563165658bbb1") + index_model = IndexFactory.create(params=params) + add_response = index_model.add_documents( + [Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})] + ) + response = index_model.auto_prompt_tune() + assert index_model.count() == 0 + assert str(response.status) == "SUCCESS" + index_model.delete() + + +def test_knowledge_graph_index_model(): + from aixplain.factories import IndexFactory + from aixplain.modules.model.record import Record + from uuid import uuid4 + + for index in IndexFactory.list()["results"]: + index.delete() + + params = GraphRAGParams(name=str(uuid4()), description=str(uuid4()), llm="6646261c6eb563165658bbb1") + index_model = IndexFactory.create(params=params) + index_model.upsert([Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})]) + response = index_model.search("Berlin") + assert str(response.status) == "SUCCESS" + assert "germany" in 
response.data.lower() + assert index_model.count() == 1 + index_model.delete() + + def test_llm_run_with_file(): """Testing LLM with local file input containing emoji""" @@ -304,7 +341,7 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): from aixplain.factories import IndexFactory from aixplain.modules.model.record import Record from uuid import uuid4 - from aixplain.modules.model.index_model import Splitter + from aixplain.modules.model.index_models import Splitter from aixplain.enums.splitting_options import SplittingOptions for index in IndexFactory.list()["results"]: diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 8d5c3a74..034c29f5 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -1,224 +1,216 @@ +""" +Covers +────── +• VectorIndexModel +• KnowledgeGraphIndexModel +• Record utilities +• IndexFilter dataclass +• IndexFactory negative-path rules +""" import requests_mock from aixplain.enums import DataType, Function, ResponseStatus, StorageType, EmbeddingModel from aixplain.factories.index_factory import IndexFactory +from aixplain.factories.index_factory.utils import AirParams from aixplain.modules.model.record import Record from aixplain.modules.model.response import ModelResponse -from aixplain.modules.model.index_model import IndexModel +from aixplain.modules.model.index_models import VectorIndexModel, IndexFilter, IndexFilterOperator, KnowledgeGraphIndexModel from aixplain.utils import config import logging import pytest - -data = {"data": "Model Index", "description": "This is a dummy collection for testing."} -index_id = "id" -execute_url = f"{config.MODELS_RUN_URL}/{index_id}".replace("/api/v1/execute", "/api/v2/execute") +from aixplain.modules.model.index_models.base_index_model import BaseIndexModel + +logging.basicConfig( + format="%(levelname)s • %(name)s • %(message)s", + level=logging.DEBUG, +) + +# 
──────────────────────────────────────────────────────────────────────────────── +# VECTOR-BASED INDEX TESTS +# ──────────────────────────────────────────────────────────────────────────────── + +VEC_ID = "vec-id" +VEC_EXEC_URL = f"{config.MODELS_RUN_URL}/{VEC_ID}".replace("/api/v1/execute", "/api/v2/execute") +logger = logging.getLogger("VectorIndexModelTests") + + +def _make_vec(): + return VectorIndexModel( + id=VEC_ID, + name="vec-name", + version="airv2-dev-1-test", + description="", + function=Function.SEARCH, + embedding_model=EmbeddingModel.OPENAI_ADA002, + ) -def test_text_search_success(mocker): +def test_vector_text_search(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.TEXT) - mock_response = {"status": "SUCCESS"} - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.search("test query") + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + logger.debug("POST %s • payload={status:SUCCESS}", VEC_EXEC_URL) + resp = _make_vec().search("hello world") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + assert isinstance(resp, ModelResponse) + assert resp.status == ResponseStatus.SUCCESS -def test_image_search_success(mocker): +def test_vector_image_search(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - - mock_response = {"status": "SUCCESS"} - - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel( - id=index_id, - data=data, - 
name="name", - function=Function.SEARCH, - embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, - ) - response = index_model.search("test.jpg") + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://ex.com/img.jpg") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + logger.debug("POST %s • payload={status:SUCCESS}", VEC_EXEC_URL) + resp = _make_vec().search("img.jpg") + assert resp.status == ResponseStatus.SUCCESS -def test_text_add_success(mocker): + +def test_vector_upsert_text(mocker): mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) - mock_response = {"status": "SUCCESS"} - mock_documents = [ - Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={}), - Record(value="Sample document content 2", value_type="text", id=1, uri="", attributes={}), + docs = [ + Record(value="doc1", value_type="text", id=1, uri="", attributes={}), + Record(value="doc2", value_type="text", id=2, uri="", attributes={}), ] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) + with requests_mock.Mocker() as m: + m.post(VEC_EXEC_URL, json={"status": "SUCCESS"}, status_code=200) + resp = _make_vec().upsert(docs) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == [d.to_dict() for d in docs] - response = index_model.upsert(mock_documents) - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_vector_index_filter(): + flt = IndexFilter("category", "news", IndexFilterOperator.EQUALS) + assert flt.to_dict() == {"field": "category", "value": "news", "operator": "=="} -def test_image_add_success(mocker): - 
mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.FILE] * 4) - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - mock_response = {"status": "SUCCESS"} +# ──────────────────────────────────────────────────────────────────────────────── +# KNOWLEDGE-GRAPH INDEX TESTS +# ──────────────────────────────────────────────────────────────────────────────── - mock_documents = [ - Record(uri="https://example.com/test.jpg", value_type="image", id=0, attributes={}), - ] +KG_ID = "kg-id" +KG_EXEC_URL = f"{config.MODELS_RUN_URL}/{KG_ID}".replace("/api/v1/execute", "/api/v2/execute") - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.upsert(mock_documents) +kg_logger = logging.getLogger("KGIndexModelTests") - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def _make_kg(): + return KnowledgeGraphIndexModel( + id=KG_ID, + name="kg-name", + version="graphrag-dev-1-test", + description="", + function=Function.SEARCH, + llm="gpt-4o", + ) -def test_text_update_success(mocker): - mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) - mock_response = {"status": "SUCCESS"} - mock_documents = [ - Record(value="Updated document content 1", value_type="text", id=0, uri="", attributes={}), - Record(value="Updated document content 2", value_type="text", id=1, uri="", attributes={}), - ] +def _kg_mock(m, payload): + m.post(KG_EXEC_URL, json=payload, status_code=200) + kg_logger.debug("POST %s • payload=%s", KG_EXEC_URL, payload) - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - logging.debug(f"Requesting 
URL: {execute_url}") - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) +def test_kg_get_prompts(): + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS", "data": {"sys": "hi"}}) + resp = _make_kg().get_prompts() - response = index_model.upsert(mock_documents) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == {"sys": "hi"} - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_kg_add_documents(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) + + docs = [Record(value="kg-doc", value_type="text", id=0, uri="", attributes={})] -def test_count_success(): - mock_response = {"status": "SUCCESS", "data": 4} + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS"}) + resp = _make_kg().add_documents(docs) - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - logging.debug(f"Requesting URL: {execute_url}") + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == docs - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - response = index_model.count() +def test_kg_manual_prompt_tune(): + prompts = {"sys": "You are helpful"} - assert isinstance(response, int) - assert response == 4 + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS"}) + resp = _make_kg().manual_prompt_tune(prompts) + assert resp.status == ResponseStatus.SUCCESS + assert resp.data == prompts -def test_get_document_success(): - mock_response = { - "status": "SUCCESS", - "data": {"value": "Sample document content 1", "value_type": "text", "id": 0, "uri": "", "attributes": {}}, - } - mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - 
index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - index_model.upsert(mock_documents) - response = index_model.get_record(0) - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS +def test_kg_auto_prompt_tune(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) + docs = [Record(value="auto", value_type="text", id=1, uri="", attributes={})] -def test_delete_document_success(): - mock_response = {"status": "SUCCESS"} - mock_documents = [Record(value="Sample document content 1", value_type="text", id=0, uri="", attributes={})] + with requests_mock.Mocker() as m: + # upload_documents + _kg_mock(m, {"status": "SUCCESS"}) + # auto_prompt_tune + _kg_mock(m, {"status": "SUCCESS", "data": {"sys": "ok"}}) - with requests_mock.Mocker() as mock: - mock.post(execute_url, json=mock_response, status_code=200) - index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) - index_model.upsert(mock_documents) - response = index_model.delete_record("0") + resp = _make_kg().auto_prompt_tune(docs) - assert isinstance(response, ModelResponse) - assert response.status == ResponseStatus.SUCCESS + assert resp == {"sys": "ok"} -def test_validate_record_success(mocker): - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) - mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") +def test_kg_graph_indexing(): + with requests_mock.Mocker() as m: + _kg_mock(m, {"status": "SUCCESS", "data": "started"}) + resp = _make_kg().graph_indexing() - record = Record(uri="test.jpg", value_type="image", id=0, attributes={}) - record.validate() - assert record.value_type == DataType.IMAGE - assert record.uri == "https://example.com/test.jpg" - assert record.value 
== "" + assert resp.status == ResponseStatus.SUCCESS -def test_validate_record_failure(mocker): - mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=False) - mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) - mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://example.com/test.jpg") - record = Record(uri="test.mov", value_type="video", id=0, attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: Invalid value type" +def test_kg_upsert(mocker): + mocker.patch("aixplain.factories.FileFactory.check_storage_type", side_effect=[StorageType.TEXT] * 4) + docs = [Record(value="round", value_type="text", id=7, uri="", attributes={})] -def test_validate_record_failure_no_uri(mocker): - record = Record(value="test.jpg", value_type="image", id=0, uri="", attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: URI is required for image records" + with requests_mock.Mocker() as m: + # upload_documents + _kg_mock(m, {"status": "SUCCESS"}) + # ingest + _kg_mock(m, {"status": "SUCCESS", "data": "done"}) + resp = _make_kg().upsert(docs) -def test_validate_record_failure_no_value(mocker): - record = Record(uri="test.jpg", value_type="text", id=0, attributes={}) - with pytest.raises(Exception) as e: - record.validate() - assert str(e.value) == "Index Upsert Error: Value is required for text records" + assert resp.status == ResponseStatus.SUCCESS -def test_record_to_dict(): - record = Record(value="test", value_type=DataType.TEXT, id=0, uri="", attributes={}) - record_dict = record.to_dict() - assert record_dict["dataType"] == "text" - assert record_dict["uri"] == "" - assert record_dict["data"] == "test" - assert record_dict["document_id"] == 0 - assert record_dict["attributes"] == {} +# 
──────────────────────────────────────────────────────────────────────────────── +# RECORD UTILITY TESTS +# ──────────────────────────────────────────────────────────────────────────────── +def test_record_validate_and_dict(mocker): + mocker.patch("aixplain.modules.model.utils.is_supported_image_type", return_value=True) + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.FILE) + mocker.patch("aixplain.factories.FileFactory.to_link", return_value="https://ex.com/img.jpg") - record = Record(value="test", value_type=DataType.IMAGE, id=0, uri="https://example.com/test.jpg", attributes={}) - record_dict = record.to_dict() - assert record_dict["dataType"] == "image" - assert record_dict["uri"] == "https://example.com/test.jpg" - assert record_dict["data"] == "test" - assert record_dict["document_id"] == 0 - assert record_dict["attributes"] == {} + rec = Record(uri="img.jpg", value_type="image", id=0, attributes={}) + rec.validate() + d = rec.to_dict() + assert d["dataType"] == DataType.IMAGE + assert d["uri"] == "https://ex.com/img.jpg" -def test_index_filter(): - from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator - filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS) - assert filter.field == "category" - assert filter.value == "world" - assert filter.operator == IndexFilterOperator.EQUALS +# ──────────────────────────────────────────────────────────────────────────────── +# INDEX FACTORY TESTS +# ──────────────────────────────────────────────────────────────────────────────── def test_index_factory_create_failure(): - from aixplain.factories.index_factory.utils import AirParams - with pytest.raises(Exception) as e: IndexFactory.create( name="test", @@ -245,7 +237,7 @@ def test_index_factory_create_failure(): def test_index_model_splitter(): - from aixplain.modules.model.index_model import Splitter + from aixplain.modules.model.index_models.base_index_model import 
Splitter splitter = Splitter(split=True, split_by="sentence", split_length=100, split_overlap=0) assert splitter.split == True