From 8f04f57b502b8b295e33a86f987762e30fcc6807 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Mon, 15 Dec 2025 18:59:06 +0700 Subject: [PATCH 1/2] new: add gemma embed --- .../builtin_pooling_normalized_embedding.py | 68 +++++++++++++++++++ fastembed/text/onnx_text_model.py | 11 ++- fastembed/text/text_embedding.py | 2 + tests/test_text_onnx_embeddings.py | 52 ++++++++++++++ 4 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 fastembed/text/builtin_pooling_normalized_embedding.py diff --git a/fastembed/text/builtin_pooling_normalized_embedding.py b/fastembed/text/builtin_pooling_normalized_embedding.py new file mode 100644 index 00000000..876f853a --- /dev/null +++ b/fastembed/text/builtin_pooling_normalized_embedding.py @@ -0,0 +1,68 @@ +from typing import Any, Iterable, Type + + +from fastembed.common.types import NumpyArray +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.utils import normalize +from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker +from fastembed.common.model_description import DenseModelDescription, ModelSource + + +supported_builtin_pooling_normalized_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="google/embeddinggemma-300m", + dim=768, + description=( + "Text embeddings, Unimodal (text), multilingual, 2048 input tokens truncation, " + "Prefixes for queries/documents: `task: search result | query: {content}` for query, " + "`title: {title | 'none'} | text: {content}` for documents, 2025 year." + ), + license="apache-2.0", + size_in_GB=1.24, + sources=ModelSource( + hf="onnx-community/embeddinggemma-300m-ONNX", + ), + model_file="onnx/model.onnx", + additional_files=["onnx/model.onnx_data"], + ), +] + + +class BuiltinPoolingNormalizedEmbedding(OnnxTextEmbedding): + @classmethod + def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: + return BuiltinPoolingNormalizedEmbeddingWorker + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + """Lists the supported models. + + Returns: + list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. + """ + return supported_builtin_pooling_normalized_models + + def _post_process_onnx_output( + self, output: OnnxOutputContext, **kwargs: Any + ) -> Iterable[NumpyArray]: + return normalize(output.model_output) + + def _run_model( + self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None + ) -> NumpyArray: + return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr] + + +class BuiltinPoolingNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxTextEmbedding: + return BuiltinPoolingNormalizedEmbedding( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index c8001a91..10a4aa17 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -92,14 +92,21 @@ def onnx_embed( [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 ) onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) + model_output = self._run_model( + onnx_input=onnx_input, onnx_output_names=self.ONNX_OUTPUT_NAMES + ) - model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] return OnnxOutputContext( - model_output=model_output[0], + model_output=model_output, attention_mask=onnx_input.get("attention_mask", attention_mask), input_ids=onnx_input.get("input_ids", input_ids), ) + def _run_model( + self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None + ) -> NumpyArray: + return self.model.run(onnx_output_names, onnx_input)[0] # type: ignore[union-attr] + def _embed_documents( self, model_name: str, diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index a4ae48cc..0bda6fcd 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -8,6 +8,7 @@ from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding from fastembed.text.multitask_embedding import JinaEmbeddingV3 +from fastembed.text.builtin_pooling_normalized_embedding import BuiltinPoolingNormalizedEmbedding from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType @@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase): PooledNormalizedEmbedding, PooledEmbedding, JinaEmbeddingV3, + BuiltinPoolingNormalizedEmbedding, CustomTextEmbedding, ] diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index e919faf9..cdd22d79 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -68,6 +68,22 @@ "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]), "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), + "google/embeddinggemma-300m": np.array( + [-0.08181356, 0.0214127, 0.05120273, -0.03690156, -0.0254504] + ), +} + + +DOC_PREFIXES = { + "google/embeddinggemma-300m": "title: none | text: ", +} +QUERY_PREFIXES = { + "google/embeddinggemma-300m": "task: search result | query: ", +} +CANONICAL_QUERY_VECTOR_VALUES = { + "google/embeddinggemma-300m": np.array( + [-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477] + ) } MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] @@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None: with model_cache(model_desc.model) as model: docs = ["hello world", "flag embedding"] + if model_desc.model in DOC_PREFIXES: + docs = [DOC_PREFIXES[model_desc.model] + doc for doc in docs] + embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) assert embeddings.shape == (2, dim) @@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None: ), model_desc.model +def test_query_embedding(model_cache) -> None: + is_ci = os.getenv("CI") + is_mac = platform.system() == "Darwin" + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" + + for model_desc in TextEmbedding._list_supported_models(): + if model_desc.model in MULTI_TASK_MODELS or ( + is_mac and model_desc.model == "nomic-ai/nomic-embed-text-v1.5-Q" + ): + continue + + if model_desc.model not in CANONICAL_QUERY_VECTOR_VALUES: + continue + + if not should_test_model(model_desc, "", is_ci, is_manual): + continue + + dim = model_desc.dim + with model_cache(model_desc.model) as model: + queries = ["hello world", "flag embedding"] + if model_desc.model in QUERY_PREFIXES: + queries = [QUERY_PREFIXES[model_desc.model] + query for query in queries] + + embeddings = list(model.query_embed(queries)) + embeddings = np.stack(embeddings, axis=0) + assert embeddings.shape == (2, dim) + + canonical_vector = CANONICAL_QUERY_VECTOR_VALUES[model_desc.model] + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc.model + + @pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")]) def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None: with model_cache(model_name) as model: From e14914aafc3d67f9e26f304826fc9cb00f59627c Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Mon, 15 Dec 2025 19:29:22 +0700 Subject: [PATCH 2/2] refactor: rename builtin pooling normalized embedding to builtin sentence embedding --- ...bedding.py => builtin_sentence_embedding.py} | 17 +++++++++-------- fastembed/text/text_embedding.py | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) rename fastembed/text/{builtin_pooling_normalized_embedding.py => builtin_sentence_embedding.py} (80%) diff --git a/fastembed/text/builtin_pooling_normalized_embedding.py b/fastembed/text/builtin_sentence_embedding.py similarity index 80% rename from fastembed/text/builtin_pooling_normalized_embedding.py rename to fastembed/text/builtin_sentence_embedding.py index 876f853a..9a4358df 100644 --- a/fastembed/text/builtin_pooling_normalized_embedding.py +++ b/fastembed/text/builtin_sentence_embedding.py @@ -3,12 +3,11 @@ from fastembed.common.types import NumpyArray from fastembed.common.onnx_model import OnnxOutputContext -from fastembed.common.utils import normalize from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker from fastembed.common.model_description import DenseModelDescription, ModelSource -supported_builtin_pooling_normalized_models: list[DenseModelDescription] = [ +supported_builtin_sentence_embedding_models: list[DenseModelDescription] = [ DenseModelDescription( model="google/embeddinggemma-300m", dim=768, @@ -28,10 +27,12 @@ ] -class BuiltinPoolingNormalizedEmbedding(OnnxTextEmbedding): +class BuiltinSentenceEmbedding(OnnxTextEmbedding): + """Builtin Sentence Embedding uses built-in pooling and normalization of underlying onnx models""" + @classmethod def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: - return BuiltinPoolingNormalizedEmbeddingWorker + return BuiltinSentenceEmbeddingWorker @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: @@ -40,12 +41,12 @@ def _list_supported_models(cls) -> list[DenseModelDescription]: Returns: list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. """ - return supported_builtin_pooling_normalized_models + return supported_builtin_sentence_embedding_models def _post_process_onnx_output( self, output: OnnxOutputContext, **kwargs: Any ) -> Iterable[NumpyArray]: - return normalize(output.model_output) + return output.model_output def _run_model( self, onnx_input: dict[str, Any], onnx_output_names: list[str] | None = None @@ -53,14 +54,14 @@ def _run_model( return self.model.run(onnx_output_names, onnx_input)[1] # type: ignore[union-attr] -class BuiltinPoolingNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker): +class BuiltinSentenceEmbeddingWorker(OnnxTextEmbeddingWorker): def init_embedding( self, model_name: str, cache_dir: str, **kwargs: Any, ) -> OnnxTextEmbedding: - return BuiltinPoolingNormalizedEmbedding( + return BuiltinSentenceEmbedding( model_name=model_name, cache_dir=cache_dir, threads=1, diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index 0bda6fcd..5a37e378 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -8,7 +8,7 @@ from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding from fastembed.text.pooled_embedding import PooledEmbedding from fastembed.text.multitask_embedding import JinaEmbeddingV3 -from fastembed.text.builtin_pooling_normalized_embedding import BuiltinPoolingNormalizedEmbedding +from fastembed.text.builtin_sentence_embedding import BuiltinSentenceEmbedding from fastembed.text.onnx_embedding import OnnxTextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType @@ -21,7 +21,7 @@ class TextEmbedding(TextEmbeddingBase): PooledNormalizedEmbedding, PooledEmbedding, JinaEmbeddingV3, - BuiltinPoolingNormalizedEmbedding, + BuiltinSentenceEmbedding, CustomTextEmbedding, ]