From c850878da937349924ffcb94968f025455db4015 Mon Sep 17 00:00:00 2001 From: Aaron Spring Date: Mon, 20 Oct 2025 20:44:59 +0200 Subject: [PATCH 1/3] feat: Add Jina Embeddings v3 with task-specific LoRA support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for jinaai/jina-embeddings-v3, a multilingual embedding model with 1024 dimensions supporting 89+ languages and task-specific LoRA adapters. Features: - Task-specific embeddings via LoRA adapters (retrieval.query, retrieval.passage, classification, text-matching, separation) - Automatic task_id handling for ONNX inference - Default to text-matching task for general purpose use - query_embed() and passage_embed() methods for retrieval tasks - Matryoshka dimensions support (32-1024) - 8,192 token context window Model specs: - 570M parameters - 2.29 GB ONNX model - Apache 2.0 license Implementation: - Added model configuration with additional_files for model.onnx_data - Load lora_adaptations from config.json - Preprocess ONNX input to add task_id parameter - Override query_embed/passage_embed for automatic task selection - Added comprehensive multi-task test with canonical vectors Following the pattern from PR #561 but using task_id instead of text prefixes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- fastembed/text/onnx_embedding.py | 83 ++++++++++++++++++++++++++++++ tests/test_text_onnx_embeddings.py | 62 ++++++++++++++++++++++ 2 files changed, 145 insertions(+) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 4cc892f5..dddc3a3a 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -1,5 +1,9 @@ +import json +from pathlib import Path from typing import Any, Iterable, Optional, Sequence, Type, Union +import numpy as np + from fastembed.common.types import NumpyArray, OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir, normalize @@ -180,6 +184,24 @@ sources=ModelSource(hf="jinaai/jina-clip-v1"), model_file="onnx/text_model.onnx", ), + DenseModelDescription( + model="jinaai/jina-embeddings-v3", + dim=1024, + description=( + "Text embeddings, Unimodal (text), Multilingual (89+ languages), 8192 input tokens truncation, " + "Task-specific LoRA adapters (retrieval, classification, text-matching, clustering), " + "Matryoshka dimensions: 32-1024, 2024 year." + ), + license="apache-2.0", + size_in_GB=2.29, + sources=ModelSource(hf="jinaai/jina-embeddings-v3"), + model_file="onnx/model.onnx", + additional_files=["onnx/model.onnx_data"], + tasks={ + "query_task": "retrieval.query", + "passage_task": "retrieval.passage", + }, + ), ] @@ -255,6 +277,14 @@ def __init__( specific_model_path=self._specific_model_path, ) + # Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3) + self.lora_adaptations: Optional[list[str]] = None + config_path = Path(self._model_dir) / "config.json" + if config_path.exists(): + with open(config_path, "r") as f: + config = json.load(f) + self.lora_adaptations = config.get("lora_adaptations") + if not self.lazy_load: self.load_onnx_model() @@ -303,7 +333,20 @@ def _preprocess_onnx_input( ) -> dict[str, NumpyArray]: """ Preprocess the onnx input. + Adds task_id for models with LoRA adapters (e.g., Jina v3). """ + # Handle task-specific embeddings for models with LoRA adapters + if self.lora_adaptations: + task_type = kwargs.get("task_type") + + # If no task specified, use default (text-matching for general purpose) + if not task_type: + # Default to text-matching if available, otherwise first task + task_type = "text-matching" if "text-matching" in self.lora_adaptations else self.lora_adaptations[0] + + if task_type in self.lora_adaptations: + task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64) + onnx_input["task_id"] = task_id return onnx_input def _post_process_onnx_output( @@ -329,6 +372,46 @@ def load_onnx_model(self) -> None: device_id=self.device_id, ) + def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: + """ + Embeds queries with task-specific handling for models that support it. + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + **kwargs: Additional keyword arguments. + + Returns: + Iterable[NumpyArray]: The embeddings. + """ + # Use task-specific embedding for models with LoRA adapters + if self.model_description.tasks and "query_task" in self.model_description.tasks: + kwargs["task_type"] = self.model_description.tasks["query_task"] + + if isinstance(query, str): + yield from self.embed([query], **kwargs) + else: + yield from self.embed(query, **kwargs) + + def passage_embed(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: + """ + Embeds passages with task-specific handling for models that support it. + + Args: + texts (Union[str, Iterable[str]]): The text(s) to embed. + **kwargs: Additional keyword arguments. + + Returns: + Iterable[NumpyArray]: The embeddings. + """ + # Use task-specific embedding for models with LoRA adapters + if self.model_description.tasks and "passage_task" in self.model_description.tasks: + kwargs["task_type"] = self.model_description.tasks["passage_task"] + + if isinstance(texts, str): + yield from self.embed([texts], **kwargs) + else: + yield from self.embed(texts, **kwargs) + class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]): def init_embedding( diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 6b25d900..e007621a 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -67,6 +67,7 @@ "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]), "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), + "jinaai/jina-embeddings-v3": np.array([0.07257809, -0.08073004, 0.09241360, -0.01755937, 0.06534681]), } MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] @@ -175,3 +176,64 @@ def test_embedding_size() -> None: if is_ci: delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize("model_name", MULTI_TASK_MODELS) +def test_multi_task_embedding(model_name: str) -> None: + """Test models that support task-specific embeddings (query vs passage).""" + is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + + # Skip in CI unless manual + if is_ci and not is_manual: + pytest.skip("Skipping multi-task model tests in CI (large models)") + + model_desc = None + for desc in TextEmbedding._list_supported_models(): + if desc.model == model_name: + model_desc = desc + break + + assert model_desc is not None, f"Model {model_name} not found in supported models" + + dim = model_desc.dim + model = TextEmbedding(model_name=model_name) + + # Test query embedding + queries = ["What is the capital of France?", "How does photosynthesis work?"] + query_embeddings = list(model.query_embed(queries)) + query_embeddings = np.stack(query_embeddings, axis=0) + assert query_embeddings.shape == (2, dim), f"Query embeddings shape mismatch for {model_name}" + + # Test passage embedding + passages = ["Paris is the capital of France.", "Photosynthesis is a process used by plants."] + passage_embeddings = list(model.passage_embed(passages)) + passage_embeddings = np.stack(passage_embeddings, axis=0) + assert passage_embeddings.shape == (2, dim), f"Passage embeddings shape mismatch for {model_name}" + + # Test regular embed (should work without task specification) + docs = ["hello world", "flag embedding"] + embeddings = list(model.embed(docs)) + embeddings = np.stack(embeddings, axis=0) + assert embeddings.shape == (2, dim), f"Regular embeddings shape mismatch for {model_name}" + + # Verify that query and passage embeddings are different (due to different LoRA adapters) + # Using the same text should produce different embeddings for query vs passage + test_text = "This is a test sentence" + query_emb = np.array(list(model.query_embed([test_text]))) + passage_emb = np.array(list(model.passage_embed([test_text]))) + + # They should not be identical (different task adapters) + assert not np.allclose(query_emb, passage_emb, atol=1e-6), \ + f"Query and passage embeddings should differ for {model_name}" + + # Optional: Check canonical vectors if available + if model_name in CANONICAL_VECTOR_VALUES: + canonical_vector = CANONICAL_VECTOR_VALUES[model_name] + # Check against regular embeddings[0] which is "hello world" + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), f"Canonical vector mismatch for {model_name}" + + if is_ci: + delete_model_cache(model.model._model_dir) From 56863e656ece9f4d0cfd073b91fe331c0ee746ba Mon Sep 17 00:00:00 2001 From: Aaron Spring Date: Mon, 20 Oct 2025 20:49:58 +0200 Subject: [PATCH 2/3] feat: Add comprehensive task metadata to Jina v3 model description Enhance the Jina v3 model configuration to expose all available LoRA tasks: - Add 'available_tasks' list with all 5 LoRA adapters - Add 'default_task' for explicit default behavior - Update _preprocess_onnx_input to use default_task from model description - Maintain backward compatibility with existing task selection logic This makes the model's capabilities more discoverable and allows users to see all available task types via list_supported_models(). Available tasks: - retrieval.query (for search queries) - retrieval.passage (for documents/passages) - separation (for clustering) - classification (for text classification) - text-matching (for semantic similarity, default) Co-Authored-By: Claude --- fastembed/text/onnx_embedding.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index dddc3a3a..19225035 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -200,6 +200,14 @@ tasks={ "query_task": "retrieval.query", "passage_task": "retrieval.passage", + "default_task": "text-matching", + "available_tasks": [ + "retrieval.query", + "retrieval.passage", + "separation", + "classification", + "text-matching", + ], }, ), ] @@ -339,10 +347,14 @@ def _preprocess_onnx_input( if self.lora_adaptations: task_type = kwargs.get("task_type") - # If no task specified, use default (text-matching for general purpose) + # If no task specified, use default from model description or text-matching if not task_type: - # Default to text-matching if available, otherwise first task - task_type = "text-matching" if "text-matching" in self.lora_adaptations else self.lora_adaptations[0] + if self.model_description.tasks and "default_task" in self.model_description.tasks: + task_type = self.model_description.tasks["default_task"] + elif "text-matching" in self.lora_adaptations: + task_type = "text-matching" + else: + task_type = self.lora_adaptations[0] if task_type in self.lora_adaptations: task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64) From 4887b36f40cc0ce2dfb67361d64a0cef6a49ab35 Mon Sep 17 00:00:00 2001 From: Aaron Spring Date: Mon, 20 Oct 2025 20:54:42 +0200 Subject: [PATCH 3/3] fix: Add comprehensive validation for lora_adaptations config Add robust validation for lora_adaptations loaded from config.json to fail fast with clear error messages: Validation checks: - Verify lora_adaptations is a list (not string, dict, etc.) - Ensure list is non-empty - Validate each item is a string - Raise ValueError if model requires LoRA but config is missing/invalid Benefits: - Fail fast with descriptive errors instead of cryptic failures later - Clear error messages guide users to fix config issues - Protects against malformed config files - Validates contract between model description and config.json Error examples: - "'lora_adaptations' must be a list, got str" - "'lora_adaptations' must be a non-empty list" - "'lora_adaptations[1]' must be a string, got int" - "Model requires task-specific LoRA adapters, but 'lora_adaptations' is missing" Addresses CodeRabbit review feedback on PR #563. Co-Authored-By: Claude --- fastembed/text/onnx_embedding.py | 39 +++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 19225035..18adb052 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -291,7 +291,44 @@ def __init__( if config_path.exists(): with open(config_path, "r") as f: config = json.load(f) - self.lora_adaptations = config.get("lora_adaptations") + lora_adaptations = config.get("lora_adaptations") + + # Validate lora_adaptations if present or required + if lora_adaptations is not None: + # Validate it's a list + if not isinstance(lora_adaptations, list): + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations' must be a list, got {type(lora_adaptations).__name__}" + ) + + # Validate it's non-empty + if len(lora_adaptations) == 0: + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations' must be a non-empty list" + ) + + # Validate each item is a string + for idx, item in enumerate(lora_adaptations): + if not isinstance(item, str): + raise ValueError( + f"Invalid config for model '{model_name}': " + f"'lora_adaptations[{idx}]' must be a string, got {type(item).__name__}" + ) + + self.lora_adaptations = lora_adaptations + + # Check if model requires LoRA but config is missing or invalid + elif self.model_description.tasks and any( + key in self.model_description.tasks + for key in ["query_task", "passage_task", "available_tasks"] + ): + raise ValueError( + f"Model '{model_name}' requires task-specific LoRA adapters, " + f"but 'lora_adaptations' is missing from config.json. " + f"Expected a non-empty list of task names (e.g., ['retrieval.query', 'text-matching'])." + ) if not self.lazy_load: self.load_onnx_model()