diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 4cc892f5..18adb052 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,5 +1,9 @@ +import json +from pathlib import Path from typing import Any, Iterable, Optional, Sequence, Type, Union +import numpy as np + from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -180,6 +184,32 @@ sources=ModelSource(hf="jinaai/jina-clip-v1"), model_file="onnx/text_model.onnx", ), + DenseModelDescription( + model="jinaai/jina-embeddings-v3", + dim=1024, + description=( + "Text embeddings, Unimodal (text), Multilingual (89+ languages), 8192 input tokens truncation, " + "Task-specific LoRA adapters (retrieval, classification, text-matching, clustering), " + "Matryoshka dimensions: 32-1024, 2024 year." + ), + license="apache-2.0", + size_in_GB=2.29, + sources=ModelSource(hf="jinaai/jina-embeddings-v3"), + model_file="onnx/model.onnx", + additional_files=["onnx/model.onnx_data"], + tasks={ + "query_task": "retrieval.query", + "passage_task": "retrieval.passage", + "default_task": "text-matching", + "available_tasks": [ + "retrieval.query", + "retrieval.passage", + "separation", + "classification", + "text-matching", + ], + }, + ), ] @@ -255,6 +285,51 @@ def __init__( specific_model_path=self._specific_model_path, ) + # Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3) + self.lora_adaptations: Optional[list[str]] = None + config_path = Path(self._model_dir) / "config.json" + if config_path.exists(): + with open(config_path, "r") as f: + config = json.load(f) + lora_adaptations = config.get("lora_adaptations") + + # Validate lora_adaptations if present or required + if lora_adaptations is not None: + # Validate it's a list + if not isinstance(lora_adaptations, list): + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations' must be a list, got {type(lora_adaptations).__name__}" + ) + + # Validate it's non-empty + if len(lora_adaptations) == 0: + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations' must be a non-empty list" + ) + + # Validate each item is a string + for idx, item in enumerate(lora_adaptations): + if not isinstance(item, str): + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations[{idx}]' must be a string, got {type(item).__name__}" + ) + + self.lora_adaptations = lora_adaptations + + # Check if model requires LoRA but config is missing or invalid + elif self.model_description.tasks and any( + key in self.model_description.tasks + for key in ["query_task", "passage_task", "available_tasks"] + ): + raise ValueError( + f"Model '{model_name}' requires task-specific LoRA adapters, " + f"but 'lora_adaptations' is missing from config.json. " + f"Expected a non-empty list of task names (e.g., ['retrieval.query', 'text-matching'])." + ) + if not self.lazy_load: self.load_onnx_model() @@ -303,7 +378,24 @@ def _preprocess_onnx_input( ) -> dict[str, NumpyArray]: """ Preprocess the onnx input. + Adds task_id for models with LoRA adapters (e.g., Jina v3). """ + # Handle task-specific embeddings for models with LoRA adapters + if self.lora_adaptations: + task_type = kwargs.get("task_type") + + # If no task specified, use default from model description or text-matching + if not task_type: + if self.model_description.tasks and "default_task" in self.model_description.tasks: + task_type = self.model_description.tasks["default_task"] + elif "text-matching" in self.lora_adaptations: + task_type = "text-matching" + else: + task_type = self.lora_adaptations[0] + + if task_type in self.lora_adaptations: + task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64) + onnx_input["task_id"] = task_id return onnx_input def _post_process_onnx_output( @@ -329,6 +421,46 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) + def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: + """ + Embeds queries with task-specific handling for models that support it. + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + **kwargs: Additional keyword arguments. + + Returns: + Iterable[NumpyArray]: The embeddings. + """ + # Use task-specific embedding for models with LoRA adapters + if self.model_description.tasks and "query_task" in self.model_description.tasks: + kwargs["task_type"] = self.model_description.tasks["query_task"] + + if isinstance(query, str): + yield from self.embed([query], **kwargs) + else: + yield from self.embed(query, **kwargs) + + def passage_embed(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: + """ + Embeds passages with task-specific handling for models that support it. + + Args: + texts (Union[str, Iterable[str]]): The text(s) to embed. + **kwargs: Additional keyword arguments. + + Returns: + Iterable[NumpyArray]: The embeddings. + """ + # Use task-specific embedding for models with LoRA adapters + if self.model_description.tasks and "passage_task" in self.model_description.tasks: + kwargs["task_type"] = self.model_description.tasks["passage_task"] + + if isinstance(texts, str): + yield from self.embed([texts], **kwargs) + else: + yield from self.embed(texts, **kwargs) + class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]): def init_embedding( diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 6b25d900..e007621a 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -67,6 +67,7 @@ "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]), "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), + "jinaai/jina-embeddings-v3": np.array([0.07257809, -0.08073004, 0.09241360, -0.01755937, 0.06534681]), } MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] @@ -175,3 +176,64 @@ def test_embedding_size() -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", MULTI_TASK_MODELS) +def test_multi_task_embedding(model_name: str) -> None: + """Test models that support task-specific embeddings (query vs passage).""" + is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + + # Skip in CI unless manual + if is_ci and not is_manual: + pytest.skip("Skipping multi-task model tests in CI (large models)") + + model_desc = None + for desc in TextEmbedding._list_supported_models(): + if desc.model == model_name: + model_desc = desc + break + + assert model_desc is not None, f"Model {model_name} not found in supported models" + + dim = model_desc.dim + model = TextEmbedding(model_name=model_name) + + # Test query embedding + queries = ["What is the capital of France?", "How does photosynthesis work?"] + query_embeddings = list(model.query_embed(queries)) + query_embeddings = np.stack(query_embeddings, axis=0) + assert query_embeddings.shape == (2, dim), f"Query embeddings shape mismatch for {model_name}" + + # Test passage embedding + passages = ["Paris is the capital of France.", "Photosynthesis is a process used by plants."] + passage_embeddings = list(model.passage_embed(passages)) + passage_embeddings = np.stack(passage_embeddings, axis=0) + assert passage_embeddings.shape == (2, dim), f"Passage embeddings shape mismatch for {model_name}" + + # Test regular embed (should work without task specification) + docs = ["hello world", "flag embedding"] + embeddings = list(model.embed(docs)) + embeddings = np.stack(embeddings, axis=0) + assert embeddings.shape == (2, dim), f"Regular embeddings shape mismatch for {model_name}" + + # Verify that query and passage embeddings are different (due to different LoRA adapters) + # Using the same text should produce different embeddings for query vs passage + test_text = "This is a test sentence" + query_emb = np.array(list(model.query_embed([test_text]))) + passage_emb = np.array(list(model.passage_embed([test_text]))) + + # They should not be identical (different task adapters) + assert not np.allclose(query_emb, passage_emb, atol=1e-6), \ + f"Query and passage embeddings should differ for {model_name}" + + # Optional: Check canonical vectors if available + if model_name in CANONICAL_VECTOR_VALUES: + canonical_vector = CANONICAL_VECTOR_VALUES[model_name] + # Check against regular embeddings[0] which is "hello world" + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), f"Canonical vector mismatch for {model_name}" + + if is_ci: + delete_model_cache(model.model._model_dir)