diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 4cc892f5..2e37308f 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -107,6 +107,28 @@
         size_in_GB=0.64,
         sources=ModelSource(hf="mixedbread-ai/mxbai-embed-large-v1"),
         model_file="onnx/model.onnx",
+        # Prefixes from https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1#usage
+        tasks={
+            "query_prefix": "Represent this sentence for searching relevant passages: ",
+            "passage_prefix": "",
+        },
+    ),
+    DenseModelDescription(
+        model="mixedbread-ai/deepset-mxbai-embed-de-large-v1",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), German/English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.94,
+        sources=ModelSource(hf="mixedbread-ai/deepset-mxbai-embed-de-large-v1"),
+        model_file="onnx/model.onnx",
+        # Prefixes from https://huggingface.co/mixedbread-ai/deepset-mxbai-embed-de-large-v1#usage
+        tasks={
+            "query_prefix": "query: ",
+            "passage_prefix": "passage: ",
+        },
     ),
     DenseModelDescription(
         model="snowflake/snowflake-arctic-embed-xs",
@@ -294,6 +316,51 @@ def embed(
             **kwargs,
         )
 
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
+        """
+        Embeds queries, prepending the model's query prefix when one is configured.
+
+        Args:
+            query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
+            **kwargs: Additional keyword arguments to pass to the embed method.
+
+        Yields:
+            Iterable[NumpyArray]: The embeddings.
+        """
+        # `tasks` may be unset for a model; treat that as "no prefix configured"
+        query_prefix = (self.model_description.tasks or {}).get("query_prefix", "")
+
+        # Wrap a single query so that one code path handles both input shapes
+        queries = [query] if isinstance(query, str) else query
+        if query_prefix:
+            queries = [query_prefix + q for q in queries]
+
+        yield from self.embed(queries, **kwargs)
+
+    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+        """
+        Embeds passages, prepending the model's passage prefix when one is configured.
+
+        Args:
+            texts (Iterable[str]): The list of texts to embed. If a bare string is passed,
+                the prefix is applied to it as a whole.
+            **kwargs: Additional keyword arguments to pass to the embed method.
+
+        Yields:
+            Iterable[NumpyArray]: The embeddings.
+        """
+        # `tasks` may be unset for a model; treat that as "no prefix configured"
+        passage_prefix = (self.model_description.tasks or {}).get("passage_prefix", "")
+
+        if passage_prefix:
+            # A bare string iterates character by character; wrap it first so the
+            # prefix is applied once to the whole passage, not to every character
+            if isinstance(texts, str):
+                texts = [texts]
+            texts = [passage_prefix + text for text in texts]
+
+        yield from self.embed(texts, **kwargs)
+
     @classmethod
     def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
         return OnnxTextEmbeddingWorker
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index 6b25d900..43decffd 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -57,6 +57,9 @@
     "mixedbread-ai/mxbai-embed-large-v1": np.array(
         [0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]
     ),
+    "mixedbread-ai/deepset-mxbai-embed-de-large-v1": np.array(
+        [0.00574683, 0.00185086, 0.00910093, -0.03800965, 0.00805963]
+    ),
     "snowflake/snowflake-arctic-embed-xs": np.array([0.0092, 0.0619, 0.0196, 0.009, -0.0114]),
     "snowflake/snowflake-arctic-embed-s": np.array([-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]),
     "snowflake/snowflake-arctic-embed-m": np.array([-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]),