Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import json
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence, Type, Union

import numpy as np

from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir, normalize
Expand Down Expand Up @@ -180,6 +184,32 @@
sources=ModelSource(hf="jinaai/jina-clip-v1"),
model_file="onnx/text_model.onnx",
),
DenseModelDescription(
model="jinaai/jina-embeddings-v3",
dim=1024,
description=(
"Text embeddings, Unimodal (text), Multilingual (89+ languages), 8192 input tokens truncation, "
"Task-specific LoRA adapters (retrieval, classification, text-matching, clustering), "
"Matryoshka dimensions: 32-1024, 2024 year."
),
license="apache-2.0",
size_in_GB=2.29,
sources=ModelSource(hf="jinaai/jina-embeddings-v3"),
model_file="onnx/model.onnx",
additional_files=["onnx/model.onnx_data"],
tasks={
"query_task": "retrieval.query",
"passage_task": "retrieval.passage",
"default_task": "text-matching",
"available_tasks": [
"retrieval.query",
"retrieval.passage",
"separation",
"classification",
"text-matching",
],
},
),
]


Expand Down Expand Up @@ -255,6 +285,51 @@ def __init__(
specific_model_path=self._specific_model_path,
)

# Load LoRA adaptations for models that support task-specific embeddings (e.g., Jina v3)
self.lora_adaptations: Optional[list[str]] = None
config_path = Path(self._model_dir) / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
lora_adaptations = config.get("lora_adaptations")

# Validate lora_adaptations if present or required
if lora_adaptations is not None:
# Validate it's a list
if not isinstance(lora_adaptations, list):
raise ValueError(
f"Invalid config for model '{model_name}': "
f"'lora_adaptations' must be a list, got {type(lora_adaptations).__name__}"
)

# Validate it's non-empty
if len(lora_adaptations) == 0:
raise ValueError(
f"Invalid config for model '{model_name}': "
f"'lora_adaptations' must be a non-empty list"
)

# Validate each item is a string
for idx, item in enumerate(lora_adaptations):
if not isinstance(item, str):
raise ValueError(
f"Invalid config for model '{model_name}': "
f"'lora_adaptations[{idx}]' must be a string, got {type(item).__name__}"
)

self.lora_adaptations = lora_adaptations

# Check if model requires LoRA but config is missing or invalid
elif self.model_description.tasks and any(
key in self.model_description.tasks
for key in ["query_task", "passage_task", "available_tasks"]
):
raise ValueError(
f"Model '{model_name}' requires task-specific LoRA adapters, "
f"but 'lora_adaptations' is missing from config.json. "
f"Expected a non-empty list of task names (e.g., ['retrieval.query', 'text-matching'])."
)

if not self.lazy_load:
self.load_onnx_model()

Expand Down Expand Up @@ -303,7 +378,24 @@ def _preprocess_onnx_input(
) -> dict[str, NumpyArray]:
"""
Preprocess the onnx input.
Adds task_id for models with LoRA adapters (e.g., Jina v3).
"""
# Handle task-specific embeddings for models with LoRA adapters
if self.lora_adaptations:
task_type = kwargs.get("task_type")

# If no task specified, use default from model description or text-matching
if not task_type:
if self.model_description.tasks and "default_task" in self.model_description.tasks:
task_type = self.model_description.tasks["default_task"]
elif "text-matching" in self.lora_adaptations:
task_type = "text-matching"
else:
task_type = self.lora_adaptations[0]

if task_type in self.lora_adaptations:
task_id = np.array(self.lora_adaptations.index(task_type), dtype=np.int64)
onnx_input["task_id"] = task_id
return onnx_input

def _post_process_onnx_output(
Expand All @@ -329,6 +421,46 @@ def load_onnx_model(self) -> None:
device_id=self.device_id,
)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
"""
Embeds queries with task-specific handling for models that support it.

Args:
query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
**kwargs: Additional keyword arguments.

Returns:
Iterable[NumpyArray]: The embeddings.
"""
# Use task-specific embedding for models with LoRA adapters
if self.model_description.tasks and "query_task" in self.model_description.tasks:
kwargs["task_type"] = self.model_description.tasks["query_task"]

if isinstance(query, str):
yield from self.embed([query], **kwargs)
else:
yield from self.embed(query, **kwargs)

def passage_embed(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
"""
Embeds passages with task-specific handling for models that support it.

Args:
texts (Union[str, Iterable[str]]): The text(s) to embed.
**kwargs: Additional keyword arguments.

Returns:
Iterable[NumpyArray]: The embeddings.
"""
# Use task-specific embedding for models with LoRA adapters
if self.model_description.tasks and "passage_task" in self.model_description.tasks:
kwargs["task_type"] = self.model_description.tasks["passage_task"]

if isinstance(texts, str):
yield from self.embed([texts], **kwargs)
else:
yield from self.embed(texts, **kwargs)


class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
def init_embedding(
Expand Down
62 changes: 62 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
"jinaai/jina-embeddings-v3": np.array([0.07257809, -0.08073004, 0.09241360, -0.01755937, 0.06534681]),
}

MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
Expand Down Expand Up @@ -175,3 +176,64 @@ def test_embedding_size() -> None:

if is_ci:
delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize("model_name", MULTI_TASK_MODELS)
def test_multi_task_embedding(model_name: str) -> None:
"""Test models that support task-specific embeddings (query vs passage)."""
is_ci = os.getenv("CI")
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

# Skip in CI unless manual
if is_ci and not is_manual:
pytest.skip("Skipping multi-task model tests in CI (large models)")

model_desc = None
for desc in TextEmbedding._list_supported_models():
if desc.model == model_name:
model_desc = desc
break

assert model_desc is not None, f"Model {model_name} not found in supported models"

dim = model_desc.dim
model = TextEmbedding(model_name=model_name)

# Test query embedding
queries = ["What is the capital of France?", "How does photosynthesis work?"]
query_embeddings = list(model.query_embed(queries))
query_embeddings = np.stack(query_embeddings, axis=0)
assert query_embeddings.shape == (2, dim), f"Query embeddings shape mismatch for {model_name}"

# Test passage embedding
passages = ["Paris is the capital of France.", "Photosynthesis is a process used by plants."]
passage_embeddings = list(model.passage_embed(passages))
passage_embeddings = np.stack(passage_embeddings, axis=0)
assert passage_embeddings.shape == (2, dim), f"Passage embeddings shape mismatch for {model_name}"

# Test regular embed (should work without task specification)
docs = ["hello world", "flag embedding"]
embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim), f"Regular embeddings shape mismatch for {model_name}"

# Verify that query and passage embeddings are different (due to different LoRA adapters)
# Using the same text should produce different embeddings for query vs passage
test_text = "This is a test sentence"
query_emb = np.array(list(model.query_embed([test_text])))
passage_emb = np.array(list(model.passage_embed([test_text])))

# They should not be identical (different task adapters)
assert not np.allclose(query_emb, passage_emb, atol=1e-6), \
f"Query and passage embeddings should differ for {model_name}"

# Optional: Check canonical vectors if available
if model_name in CANONICAL_VECTOR_VALUES:
canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
# Check against regular embeddings[0] which is "hello world"
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), f"Canonical vector mismatch for {model_name}"

if is_ci:
delete_model_cache(model.model._model_dir)