34 changes: 33 additions & 1 deletion fastembed/late_interaction/colbert.py
@@ -8,7 +8,7 @@
from fastembed.common.types import NumpyArray
from fastembed.common import OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir
from fastembed.common.utils import define_cache_dir, iter_batch
from fastembed.late_interaction.late_interaction_embedding_base import (
LateInteractionTextEmbeddingBase,
)
@@ -96,6 +96,38 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
encoded = self.tokenizer.encode_batch(documents) # type: ignore[union-attr]
return encoded

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        batch_size: int = 1024,
        is_doc: bool = True,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        tokenizer = self.tokenizer if is_doc else self.query_tokenizer
        assert tokenizer is not None
        for batch in iter_batch(texts, batch_size):
            for tokens in tokenizer.encode_batch(batch):
                if is_doc:
                    token_num += sum(tokens.attention_mask)
                else:
                    attend_count = sum(tokens.attention_mask)
                    if include_extension:
                        token_num += max(attend_count, self.MIN_QUERY_LENGTH)
                    else:
                        token_num += attend_count
            if include_extension:
                # add one marker token (DOC_MARKER_TOKEN_ID or QUERY_MARKER_TOKEN_ID) per text
                token_num += len(batch)

        return token_num

@classmethod
def _list_supported_models(cls) -> list[DenseModelDescription]:
"""Lists the supported models.
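A minimal sketch of the counting rule implemented above (not part of the diff): document counts sum the attention mask, query counts are optionally padded up to the model's minimum query length, and one marker token is added per text when include_extension is enabled. The MIN_QUERY_LENGTH value and the Encoding-like objects are assumptions for illustration.

MIN_QUERY_LENGTH = 31  # assumed value; marker + padded query then totals 32, matching the test at the end of this PR

def sketch_token_count(encodings: list, is_doc: bool = True, include_extension: bool = False) -> int:
    # `encodings` is a list of objects exposing an `attention_mask` list,
    # as produced by tokenizers.Tokenizer.encode_batch
    total = 0
    for enc in encodings:
        attended = sum(enc.attention_mask)  # non-padding tokens only
        if is_doc:
            total += attended
        elif include_extension:
            total += max(attended, MIN_QUERY_LENGTH)  # queries are [MASK]-padded to a minimum length
        else:
            total += attended
    if include_extension:
        total += len(encodings)  # one [D] / [Q] marker token per text
    return total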
9 changes: 9 additions & 0 deletions fastembed/late_interaction/late_interaction_embedding_base.py
@@ -69,3 +69,12 @@ def get_embedding_size(cls, model_name: str) -> int:
def embedding_size(self) -> int:
"""Returns embedding size for the current model"""
raise NotImplementedError("Subclasses must implement this method")

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        batch_size: int = 1024,
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
27 changes: 27 additions & 0 deletions fastembed/late_interaction/late_interaction_text_embedding.py
@@ -151,3 +151,30 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab

# This is model-specific, so that different models can have specialized implementations
yield from self.model.query_embed(query, **kwargs)

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        batch_size: int = 1024,
        is_doc: bool = True,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The texts to count tokens for.
            batch_size (int): Batch size for encoding.
            is_doc (bool): Whether the texts are tokenized as documents (True) or as queries (False).
            include_extension (bool): Whether to also count the DOC / QUERY marker tokens and,
                in query mode, the [MASK] padding tokens.

        Returns:
            int: Sum of the number of tokens in the texts.
        """
        return self.model.token_count(
            texts,
            batch_size=batch_size,
            is_doc=is_doc,
            include_extension=include_extension,
            **kwargs,
        )
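A usage sketch for the wrapper method above (not part of the diff); it assumes the top-level LateInteractionTextEmbedding export and reuses the model name from the test at the end of this PR.

from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding("answerdotai/answerai-colbert-small-v1")
docs = ["short doc", "a somewhat longer document"]

doc_tokens = model.token_count(docs)                                           # attended tokens only
doc_tokens_full = model.token_count(docs, include_extension=True)              # plus one document marker per text
query_tokens = model.token_count(docs, is_doc=False, include_extension=True)   # padded query lengths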
19 changes: 18 additions & 1 deletion fastembed/late_interaction_multimodal/colpali.py
@@ -6,7 +6,7 @@
from fastembed.common import OnnxProvider, ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.common.utils import define_cache_dir
from fastembed.common.utils import define_cache_dir, iter_batch
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
LateInteractionMultimodalEmbeddingBase,
)
@@ -172,6 +172,23 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
encoded = self.tokenizer.encode_batch(texts_query) # type: ignore[union-attr]
return encoded

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        batch_size: int = 1024,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        assert self.tokenizer is not None
        tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch
        for batch in iter_batch(texts, batch_size):
            token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)])
        return token_num

def _preprocess_onnx_text_input(
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
) -> dict[str, NumpyArray]:
@@ -162,3 +162,24 @@ def embed_image(
List of embeddings, one per image
"""
yield from self.model.embed_image(images, batch_size, parallel, **kwargs)

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        batch_size: int = 1024,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The texts to count tokens for.
            batch_size (int): Batch size for encoding.
            include_extension (bool): Whether to include tokens added by query preprocessing.

        Returns:
            int: Sum of the number of tokens in the texts.
        """
        return self.model.token_count(
            texts, batch_size=batch_size, include_extension=include_extension, **kwargs
        )
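A usage sketch for the multimodal wrapper above (not part of the diff); both the top-level export and the model name are assumptions for illustration.

from fastembed import LateInteractionMultimodalEmbedding

model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16")  # hypothetical model name
queries = ["what does the chart on page 3 show?"]

plain = model.token_count(queries)                              # raw tokenizer output
extended = model.token_count(queries, include_extension=True)   # includes tokens added by query preprocessing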
@@ -76,3 +76,11 @@ def get_embedding_size(cls, model_name: str) -> int:
def embedding_size(self) -> int:
"""Returns embedding size for the current model"""
raise NotImplementedError("Subclasses must implement this method")

    def token_count(
        self,
        texts: Union[str, Iterable[str]],
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
14 changes: 14 additions & 0 deletions fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py
@@ -207,6 +207,20 @@ def _post_process_onnx_output(
) -> Iterable[float]:
return (float(elem) for elem in output.model_output)

    def token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the pairs.

        Args:
            pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized.
            batch_size: Batch size for tokenizing.

        Returns:
            int: Overall number of tokens in the pairs.
        """
        return self._token_count(pairs, batch_size=batch_size, **kwargs)


class TextCrossEncoderWorker(TextRerankerWorker):
def init_embedding(
14 changes: 14 additions & 0 deletions fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -165,6 +165,20 @@ def _preprocess_onnx_input(
"""
return onnx_input

    def _token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well

        token_num = 0
        assert self.tokenizer is not None
        for batch in iter_batch(pairs, batch_size):
            for tokens in self.tokenizer.encode_batch(batch):
                token_num += sum(tokens.attention_mask)

        return token_num


class TextRerankerWorker(EmbeddingWorker[float]):
def __init__(
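A minimal sketch of the pair handling above (not part of the diff): the tokenizers library encodes a (query, document) tuple into a single sequence, so summing its attention mask counts the tokens of both texts at once. The tokenizer name is illustrative.

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative tokenizer
encoding = tokenizer.encode("what is fastembed?", "FastEmbed is a lightweight embedding library.")
pair_token_count = sum(encoding.attention_mask)  # tokens of query and document combined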
14 changes: 14 additions & 0 deletions fastembed/rerank/cross_encoder/text_cross_encoder.py
@@ -161,3 +161,17 @@ def add_custom_model(
additional_files=additional_files or [],
)
)

    def token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the pairs.

        Args:
            pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized.
            batch_size: Batch size for tokenizing.

        Returns:
            int: Overall number of tokens in the pairs.
        """
        return self.model.token_count(pairs, batch_size=batch_size, **kwargs)
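A usage sketch for the public reranker method above (not part of the diff); the import path matches fastembed's cross-encoder module and the model name is illustrative.

from fastembed.rerank.cross_encoder import TextCrossEncoder

encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")  # illustrative model name
pairs = [
    ("what is fastembed?", "FastEmbed is a lightweight embedding library."),
    ("what is qdrant?", "Qdrant is a vector search engine."),
]
total_tokens = encoder.token_count(pairs)  # attended tokens summed over all (query, document) pairs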
4 changes: 4 additions & 0 deletions fastembed/rerank/cross_encoder/text_cross_encoder_base.py
@@ -57,3 +57,7 @@ def rerank_pairs(
Iterable[float]: Scores for each individual pair
"""
raise NotImplementedError("This method should be overridden by subclasses")

    def token_count(self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> int:
        """Returns the number of tokens in the pairs."""
        raise NotImplementedError("This method should be overridden by subclasses")
9 changes: 9 additions & 0 deletions fastembed/sparse/bm25.py
@@ -268,6 +268,15 @@ def raw_embed(
embeddings.append(SparseEmbedding.from_dict(token_id2value))
return embeddings

    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        for text in texts:
            document = remove_non_alphanumeric(text)
            tokens = self.tokenizer.tokenize(document)
            token_num += len(tokens)
        return token_num

def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
"""Calculate the term frequency part of the BM25 formula.

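A usage sketch for the BM25 counter above (not part of the diff): BM25 counts its own word-level tokens after stripping non-alphanumeric characters, so the numbers are not comparable to the subword counts of the ONNX-based models. The model name is the one fastembed commonly uses for BM25.

from fastembed import SparseTextEmbedding

bm25 = SparseTextEmbedding("Qdrant/bm25")
print(bm25.token_count("Hello, BM25 world!"))  # punctuation removed before word-level tokenization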
7 changes: 7 additions & 0 deletions fastembed/sparse/bm42.py
@@ -352,6 +352,13 @@ def query_embed(
def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
return Bm42TextEmbeddingWorker

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        return self._token_count(texts, batch_size=batch_size, **kwargs)


class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
5 changes: 5 additions & 0 deletions fastembed/sparse/minicoil.py
@@ -187,6 +187,11 @@ def load_onnx_model(self) -> None:
avg_len=self.avg_len,
)

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        return self._token_count(texts, batch_size=batch_size, **kwargs)

def embed(
self,
documents: Union[str, Iterable[str]],
4 changes: 4 additions & 0 deletions fastembed/sparse/sparse_embedding_base.py
@@ -86,3 +86,7 @@ def query_embed(
yield from self.embed([query], **kwargs)
else:
yield from self.embed(query, **kwargs)

    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
14 changes: 14 additions & 0 deletions fastembed/sparse/sparse_text_embedding.py
@@ -128,3 +128,17 @@ def query_embed(
Iterable[SparseEmbedding]: The sparse embeddings.
"""
yield from self.model.query_embed(query, **kwargs)

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The texts to count tokens for.
            batch_size (int): Batch size for encoding.

        Returns:
            int: Sum of the number of tokens in the texts.
        """
        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
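A usage sketch for the sparse wrapper above (not part of the diff); the SPLADE model name is illustrative, and the call is delegated to whichever sparse implementation backs the chosen model.

from fastembed import SparseTextEmbedding

model = SparseTextEmbedding("prithivida/Splade_PP_en_v1")  # illustrative model name
texts = ["a first text", "a second, slightly longer text"]
print(model.token_count(texts, batch_size=2))  # total attended tokens over both texts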
5 changes: 5 additions & 0 deletions fastembed/sparse/splade_pp.py
@@ -53,6 +53,11 @@ def _post_process_onnx_output(
scores = row_scores[indices]
yield SparseEmbedding(values=scores, indices=indices)

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        return self._token_count(texts, batch_size=batch_size, **kwargs)

@classmethod
def _list_supported_models(cls) -> list[SparseModelDescription]:
"""Lists the supported models.
5 changes: 5 additions & 0 deletions fastembed/text/onnx_embedding.py
@@ -331,6 +331,11 @@ def load_onnx_model(self) -> None:
extra_session_options=self._extra_session_options,
)

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        return self._token_count(texts, batch_size=batch_size, **kwargs)


class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
def init_embedding(
15 changes: 15 additions & 0 deletions fastembed/text/onnx_text_model.py
@@ -159,6 +159,21 @@ def _embed_documents(
for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
yield from self._post_process_onnx_output(batch, **kwargs) # type: ignore

    def _token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **_: Any
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well

        token_num = 0
        assert self.tokenizer is not None
        texts = [texts] if isinstance(texts, str) else texts
        for batch in iter_batch(texts, batch_size):
            for tokens in self.tokenizer.encode_batch(batch):
                token_num += sum(tokens.attention_mask)

        return token_num


class TextEmbeddingWorker(EmbeddingWorker[T]):
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
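A minimal sketch of the shared counting rule above (not part of the diff): with the tokenizers library, padding positions get a zero in the attention mask, so summing the mask counts only real tokens. The tokenizer name and padding length are illustrative.

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # illustrative tokenizer
tokenizer.enable_padding(length=16)  # padded positions receive attention_mask == 0
encodings = tokenizer.encode_batch(["short text", "a slightly longer example text"])
token_num = sum(sum(enc.attention_mask) for enc in encodings)  # counts non-padding tokens only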
14 changes: 14 additions & 0 deletions fastembed/text/text_embedding.py
@@ -212,3 +212,17 @@ def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyAr
"""
# This is model-specific, so that different models can have specialized implementations
yield from self.model.passage_embed(texts, **kwargs)

    def token_count(
        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The texts to count tokens for.
            batch_size (int): Batch size for encoding.

        Returns:
            int: Sum of the number of tokens in the texts.
        """
        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
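A usage sketch for the dense wrapper above (not part of the diff); the model name is fastembed's usual default and is used here for illustration.

from fastembed import TextEmbedding

model = TextEmbedding("BAAI/bge-small-en-v1.5")  # illustrative model name
print(model.token_count(["hello world", "another document"]))  # total over both texts
print(model.token_count("hello world"))                        # a single string works as well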
4 changes: 4 additions & 0 deletions fastembed/text/text_embedding_base.py
@@ -69,3 +69,7 @@ def get_embedding_size(cls, model_name: str) -> int:
def embedding_size(self) -> int:
"""Returns embedding size for the current model"""
raise NotImplementedError("Subclasses must implement this method")

    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
24 changes: 24 additions & 0 deletions tests/test_late_interaction_embeddings.py
@@ -318,3 +318,27 @@ def test_session_options(model_cache, model_name) -> None:
model = LateInteractionTextEmbedding(model_name=model_name, enable_cpu_mem_arena=False)
session_options = model.model.model.get_session_options()
assert session_options.enable_cpu_mem_arena is False


@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
def test_token_count(model_cache, model_name) -> None:
    with model_cache(model_name) as model:
        documents = ["short doc", "it is a long document to check attention mask for paddings"]
        short_doc_token_count = model.token_count(documents[0])
        long_doc_token_count = model.token_count(documents[1])
        documents_token_count = model.token_count(documents)
        assert short_doc_token_count + long_doc_token_count == documents_token_count
        # +2 accounts for one DOC_MARKER token per document
        assert short_doc_token_count + long_doc_token_count + 2 == model.token_count(
            documents, include_extension=True
        )
        assert short_doc_token_count + long_doc_token_count == model.token_count(
            documents, batch_size=1
        )
        assert short_doc_token_count + long_doc_token_count == model.token_count(
            documents, is_doc=False
        )
        # query min length is 32
        assert model.token_count(documents, is_doc=False, include_extension=True) == 64
        very_long_query = "It's a very long query which definitely contains more than 32 tokens and we're using it to check whether the method can handle large query properly without cutting it to 32 tokens"
        assert model.token_count(very_long_query, is_doc=False, include_extension=True) > 32