diff --git a/fastembed/text/builtin_sentence_embedding.py b/fastembed/text/builtin_sentence_embedding.py
new file mode 100644
index 00000000..9a4358df
--- /dev/null
+++ b/fastembed/text/builtin_sentence_embedding.py
@@ -0,0 +1,69 @@
+from typing import Any, Iterable, Type
+
+
+from fastembed.common.types import NumpyArray
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+
+supported_builtin_sentence_embedding_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="google/embeddinggemma-300m",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), multilingual, 2048 input tokens truncation, "
+            "Prefixes for queries/documents: `task: search result | query: {content}` for query, "
+            "`title: {title | 'none'} | text: {content}` for documents, 2025 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.24,
+        sources=ModelSource(
+            hf="onnx-community/embeddinggemma-300m-ONNX",
+        ),
+        model_file="onnx/model.onnx",
+        additional_files=["onnx/model.onnx_data"],
+    ),
+]
+
+
+class BuiltinSentenceEmbedding(OnnxTextEmbedding):
+    """Builtin Sentence Embedding uses built-in pooling and normalization of underlying onnx models"""
+
+    @classmethod
+    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+        return BuiltinSentenceEmbeddingWorker
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        """Lists the supported models.
+
+        Returns:
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+        """
+        return supported_builtin_sentence_embedding_models
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        return output.model_output
+
+    def _run_model(
+        self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
+    ) -> NumpyArray:
+        return self.model.run(onnx_output_names, onnx_input)[1]  # type: ignore[union-attr]
+
+
+class BuiltinSentenceEmbeddingWorker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxTextEmbedding:
+        return BuiltinSentenceEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py
index c8001a91..10a4aa17 100644
--- a/fastembed/text/onnx_text_model.py
+++ b/fastembed/text/onnx_text_model.py
@@ -92,14 +92,21 @@ def onnx_embed(
                 [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
             )
         onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
+        model_output = self._run_model(
+            onnx_input=onnx_input, onnx_output_names=self.ONNX_OUTPUT_NAMES
+        )
 
-        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
         return OnnxOutputContext(
-            model_output=model_output[0],
+            model_output=model_output,
             attention_mask=onnx_input.get("attention_mask", attention_mask),
             input_ids=onnx_input.get("input_ids", input_ids),
         )
 
+    def _run_model(
+        self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None
+    ) -> NumpyArray:
+        return self.model.run(onnx_output_names, onnx_input)[0]  # type: ignore[union-attr]
+
     def _embed_documents(
         self,
         model_name: str,
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index a4ae48cc..5a37e378 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -8,6 +8,7 @@
 from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
 from fastembed.text.pooled_embedding import PooledEmbedding
 from fastembed.text.multitask_embedding import JinaEmbeddingV3
+from fastembed.text.builtin_sentence_embedding import BuiltinSentenceEmbedding
 from fastembed.text.onnx_embedding import OnnxTextEmbedding
 from fastembed.text.text_embedding_base import TextEmbeddingBase
 from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
@@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase):
         PooledNormalizedEmbedding,
         PooledEmbedding,
         JinaEmbeddingV3,
+        BuiltinSentenceEmbedding,
         CustomTextEmbedding,
     ]
 
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index e919faf9..cdd22d79 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -68,6 +68,22 @@
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
     "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
     "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
+    "google/embeddinggemma-300m": np.array(
+        [-0.08181356, 0.0214127, 0.05120273, -0.03690156, -0.0254504]
+    ),
+}
+
+
+DOC_PREFIXES = {
+    "google/embeddinggemma-300m": "title: none | text: ",
+}
+QUERY_PREFIXES = {
+    "google/embeddinggemma-300m": "task: search result | query: ",
+}
+CANONICAL_QUERY_VECTOR_VALUES = {
+    "google/embeddinggemma-300m": np.array(
+        [-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477]
+    )
 }
 
 MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"]
@@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None:
         with model_cache(model_desc.model) as model:
             docs = ["hello world", "flag embedding"]
+            if model_desc.model in DOC_PREFIXES:
+                docs = [DOC_PREFIXES[model_desc.model] + doc for doc in docs]
+
             embeddings = list(model.embed(docs))
             embeddings = np.stack(embeddings, axis=0)
             assert embeddings.shape == (2, dim)
 
@@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None:
         ), model_desc.model
 
 
+def test_query_embedding(model_cache) -> None:
+    is_ci = os.getenv("CI")
+    is_mac = platform.system() == "Darwin"
+    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
+
+    for model_desc in TextEmbedding._list_supported_models():
+        if model_desc.model in MULTI_TASK_MODELS or (
+            is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q"
+        ):
+            continue
+
+        if model_desc.model not in CANONICAL_QUERY_VECTOR_VALUES:
+            continue
+
+        if not should_test_model(model_desc, "", is_ci, is_manual):
+            continue
+
+        dim = model_desc.dim
+        with model_cache(model_desc.model) as model:
+            queries = ["hello world", "flag embedding"]
+            if model_desc.model in QUERY_PREFIXES:
+                queries = [QUERY_PREFIXES[model_desc.model] + query for query in queries]
+
+            embeddings = list(model.query_embed(queries))
+            embeddings = np.stack(embeddings, axis=0)
+            assert embeddings.shape == (2, dim)
+
+            canonical_vector = CANONICAL_QUERY_VECTOR_VALUES[model_desc.model]
+            assert np.allclose(
+                embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
+            ), model_desc.model
+
+
 @pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
 def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
     with model_cache(model_name) as model:
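
Usage note for reviewers: below is a minimal sketch of how the new model is exercised, mirroring the added tests. It assumes the public fastembed.TextEmbedding entry point; everything else (model name, prefixes, dim=768) is taken from this diff. Note that BuiltinSentenceEmbedding does not prepend the EmbeddingGemma prefixes itself; the caller applies them, exactly as DOC_PREFIXES/QUERY_PREFIXES do in tests/test_text_onnx_embeddings.py.

    import numpy as np

    from fastembed import TextEmbedding

    # "google/embeddinggemma-300m" resolves to BuiltinSentenceEmbedding through
    # TextEmbedding's registry; pooling and normalization run inside the ONNX
    # graph, so _run_model returns output index 1 (the final sentence embedding).
    model = TextEmbedding(model_name="google/embeddinggemma-300m")

    # Prefixes from the model description: documents and queries use different templates.
    docs = ["title: none | text: " + t for t in ["hello world", "flag embedding"]]
    queries = ["task: search result | query: " + q for q in ["hello world"]]

    doc_vecs = np.stack(list(model.embed(docs)))        # shape (2, 768)
    query_vec = next(iter(model.query_embed(queries)))  # shape (768,)
    print(doc_vecs.shape, query_vec.shape)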