diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py
index 4dfc2a05..6545a5ab 100644
--- a/fastembed/late_interaction/colbert.py
+++ b/fastembed/late_interaction/colbert.py
@@ -8,7 +8,7 @@ from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
-from fastembed.common.utils import define_cache_dir
+from fastembed.common.utils import define_cache_dir, iter_batch
 from fastembed.late_interaction.late_interaction_embedding_base import (
     LateInteractionTextEmbeddingBase,
 )
@@ -96,6 +96,37 @@ def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
         encoded = self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
         return encoded
 
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        is_doc: bool = True,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        token_num = 0
+        texts = [texts] if isinstance(texts, str) else texts
+        tokenizer = self.tokenizer if is_doc else self.query_tokenizer
+        assert tokenizer is not None
+        for batch in iter_batch(texts, batch_size):
+            for tokens in tokenizer.encode_batch(batch):
+                if is_doc:
+                    token_num += sum(tokens.attention_mask)
+                else:
+                    attend_count = sum(tokens.attention_mask)
+                    if include_extension:
+                        token_num += max(attend_count, self.MIN_QUERY_LENGTH)
+                    else:
+                        token_num += attend_count
+            if include_extension:
+                token_num += len(
+                    batch
+                )  # one DOC_MARKER_TOKEN_ID or QUERY_MARKER_TOKEN_ID per text
+
+        return token_num
+
     @classmethod
     def _list_supported_models(cls) -> list[DenseModelDescription]:
         """Lists the supported models.
diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_embedding_base.py
index f677ba98..ec37c79b 100644
--- a/fastembed/late_interaction/late_interaction_embedding_base.py
+++ b/fastembed/late_interaction/late_interaction_embedding_base.py
@@ -69,3 +69,12 @@ def get_embedding_size(cls, model_name: str) -> int:
     def embedding_size(self) -> int:
         """Returns embedding size for the current model"""
         raise NotImplementedError("Subclasses must implement this method")
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts."""
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py
index 22833618..482a4331 100644
--- a/fastembed/late_interaction/late_interaction_text_embedding.py
+++ b/fastembed/late_interaction/late_interaction_text_embedding.py
@@ -151,3 +151,35 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
 
         # This is model-specific, so that different models can have specialized implementations
         yield from self.model.query_embed(query, **kwargs)
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        is_doc: bool = True,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The texts to count tokens for.
+            batch_size (int): Batch size for tokenization.
+            is_doc (bool): Whether to tokenize the texts as documents; set to False to tokenize them as queries.
+            include_extension (bool): Whether to also count the DOC/QUERY marker tokens, and the [MASK] padding tokens in query mode.
+
+        Returns:
+            int: Total number of tokens across the texts.
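+
+        Example (illustrative usage; exact counts depend on the model's tokenizer):
+            >>> model = LateInteractionTextEmbedding("answerdotai/answerai-colbert-small-v1")
+            >>> model.token_count("a document")
+            >>> model.token_count(["a query"], is_doc=False, include_extension=True)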
+        """
+        return self.model.token_count(
+            texts,
+            batch_size=batch_size,
+            is_doc=is_doc,
+            include_extension=include_extension,
+            **kwargs,
+        )
diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py
index 059f7971..7d0218fe 100644
--- a/fastembed/late_interaction_multimodal/colpali.py
+++ b/fastembed/late_interaction_multimodal/colpali.py
@@ -6,7 +6,7 @@ from fastembed.common import OnnxProvider, ImageInput
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.types import NumpyArray
-from fastembed.common.utils import define_cache_dir
+from fastembed.common.utils import define_cache_dir, iter_batch
 from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
     LateInteractionMultimodalEmbeddingBase,
 )
@@ -172,6 +172,23 @@ def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
         encoded = self.tokenizer.encode_batch(texts_query)  # type: ignore[union-attr]
         return encoded
 
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        token_num = 0
+        texts = [texts] if isinstance(texts, str) else texts
+        assert self.tokenizer is not None
+        tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch
+        for batch in iter_batch(texts, batch_size):
+            token_num += sum(sum(encoding.attention_mask) for encoding in tokenize_func(batch))
+        return token_num
+
     def _preprocess_onnx_text_input(
         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
     ) -> dict[str, NumpyArray]:
diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py
index 39c1763e..01a57294 100644
--- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py
+++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py
@@ -162,3 +162,28 @@ def embed_image(
             List of embeddings, one per image
         """
         yield from self.model.embed_image(images, batch_size, parallel, **kwargs)
+
+    def token_count(
+        self,
+        texts: Union[str, Iterable[str]],
+        batch_size: int = 1024,
+        include_extension: bool = False,
+        **kwargs: Any,
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The texts to count tokens for.
+            batch_size (int): Batch size for tokenization.
+            include_extension (bool): Whether to include the extra tokens added by query preprocessing.
+
+        Returns:
+            int: Total number of tokens across the texts.
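+
+        Example (illustrative usage; exact counts depend on the model's tokenizer):
+            >>> model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16", lazy_load=True)
+            >>> model.token_count(["a short query"], include_extension=True)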
+ """ + return self.model.token_count( + texts, batch_size=batch_size, include_extension=include_extension, **kwargs + ) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 12e3553c..0d148ce4 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -76,3 +76,11 @@ def get_embedding_size(cls, model_name: str) -> int: def embedding_size(self) -> int: """Returns embedding size for the current model""" raise NotImplementedError("Subclasses must implement this method") + + def token_count( + self, + texts: Union[str, Iterable[str]], + **kwargs: Any, + ) -> int: + """Returns the number of tokens in the texts.""" + raise NotImplementedError("Subclasses must implement this method") diff --git a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py index 56f2b86c..4a91a010 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py @@ -207,6 +207,20 @@ def _post_process_onnx_output( ) -> Iterable[float]: return (float(elem) for elem in output.model_output) + def token_count( + self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any + ) -> int: + """Returns the number of tokens in the pairs. + + Args: + pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized + batch_size: Batch size for tokenizing + + Returns: + token count: overall number of tokens in the pairs + """ + return self._token_count(pairs, batch_size=batch_size, **kwargs) + class TextCrossEncoderWorker(TextRerankerWorker): def init_embedding( diff --git a/fastembed/rerank/cross_encoder/onnx_text_model.py b/fastembed/rerank/cross_encoder/onnx_text_model.py index 5c85d27e..801c60dc 100644 --- a/fastembed/rerank/cross_encoder/onnx_text_model.py +++ b/fastembed/rerank/cross_encoder/onnx_text_model.py @@ -165,6 +165,20 @@ def _preprocess_onnx_input( """ return onnx_input + def _token_count( + self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any + ) -> int: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() # loads the tokenizer as well + + token_num = 0 + assert self.tokenizer is not None + for batch in iter_batch(pairs, batch_size): + for tokens in self.tokenizer.encode_batch(batch): + token_num += sum(tokens.attention_mask) + + return token_num + class TextRerankerWorker(EmbeddingWorker[float]): def __init__( diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index e269570d..c6182084 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -161,3 +161,17 @@ def add_custom_model( additional_files=additional_files or [], ) ) + + def token_count( + self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any + ) -> int: + """Returns the number of tokens in the pairs. 
+        """
+        return self.model.token_count(pairs, batch_size=batch_size, **kwargs)
diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
index 84b44e41..7baffd0e 100644
--- a/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
+++ b/fastembed/rerank/cross_encoder/text_cross_encoder_base.py
@@ -57,3 +57,7 @@ def rerank_pairs(
         Iterable[float]: Scores for each individual pair
     """
     raise NotImplementedError("This method should be overridden by subclasses")
+
+    def token_count(self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> int:
+        """Returns the number of tokens in the pairs."""
+        raise NotImplementedError("This method should be overridden by subclasses")
diff --git a/fastembed/sparse/bm25.py b/fastembed/sparse/bm25.py
index b6ac59fd..8265a621 100644
--- a/fastembed/sparse/bm25.py
+++ b/fastembed/sparse/bm25.py
@@ -268,6 +268,15 @@ def raw_embed(
         embeddings.append(SparseEmbedding.from_dict(token_id2value))
         return embeddings
 
+    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
+        token_num = 0
+        texts = [texts] if isinstance(texts, str) else texts
+        for text in texts:
+            document = remove_non_alphanumeric(text)
+            tokens = self.tokenizer.tokenize(document)
+            token_num += len(tokens)
+        return token_num
+
     def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
         """Calculate the term frequency part of the BM25 formula.
 
diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py
index 848b1753..536ba61e 100644
--- a/fastembed/sparse/bm42.py
+++ b/fastembed/sparse/bm42.py
@@ -352,6 +352,13 @@ def query_embed(
     def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
         return Bm42TextEmbeddingWorker
 
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
 
 class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py
index dde52d90..04d74793 100644
--- a/fastembed/sparse/minicoil.py
+++ b/fastembed/sparse/minicoil.py
@@ -187,6 +187,11 @@ def load_onnx_model(self) -> None:
             avg_len=self.avg_len,
         )
 
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
     def embed(
         self,
         documents: Union[str, Iterable[str]],
diff --git a/fastembed/sparse/sparse_embedding_base.py b/fastembed/sparse/sparse_embedding_base.py
index b153c814..47026f65 100644
--- a/fastembed/sparse/sparse_embedding_base.py
+++ b/fastembed/sparse/sparse_embedding_base.py
@@ -86,3 +86,7 @@ def query_embed(
             yield from self.embed([query], **kwargs)
         else:
             yield from self.embed(query, **kwargs)
+
+    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
+        """Returns the number of tokens in the texts."""
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py
index 3cb14c3e..6f51f69e 100644
--- a/fastembed/sparse/sparse_text_embedding.py
+++ b/fastembed/sparse/sparse_text_embedding.py
@@ -128,3 +128,21 @@ def query_embed(
             Iterable[SparseEmbedding]: The sparse embeddings.
         """
         yield from self.model.query_embed(query, **kwargs)
+
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The texts to count tokens for.
+            batch_size (int): Batch size for tokenization.
+
+        Returns:
+            int: Total number of tokens across the texts.
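+
+        Example (illustrative usage; exact counts depend on the model's tokenizer):
+            >>> model = SparseTextEmbedding("Qdrant/bm25")
+            >>> model.token_count(["hello world"])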
+        """
+        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py
index 95e43bb2..8480cb10 100644
--- a/fastembed/sparse/splade_pp.py
+++ b/fastembed/sparse/splade_pp.py
@@ -53,6 +53,11 @@ def _post_process_onnx_output(
         scores = row_scores[indices]
         yield SparseEmbedding(values=scores, indices=indices)
 
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
     @classmethod
     def _list_supported_models(cls) -> list[SparseModelDescription]:
         """Lists the supported models.
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index d76db8bf..2e3fc7d2 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -331,6 +331,11 @@ def load_onnx_model(self) -> None:
             extra_session_options=self._extra_session_options,
         )
 
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
 
 class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
     def init_embedding(
diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py
index 6cb49178..16dd6946 100644
--- a/fastembed/text/onnx_text_model.py
+++ b/fastembed/text/onnx_text_model.py
@@ -159,6 +159,21 @@ def _embed_documents(
         for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
             yield from self._post_process_onnx_output(batch, **kwargs)  # type: ignore
 
+    def _token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **_: Any
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+
+        token_num = 0
+        assert self.tokenizer is not None
+        texts = [texts] if isinstance(texts, str) else texts
+        for batch in iter_batch(texts, batch_size):
+            for tokens in self.tokenizer.encode_batch(batch):
+                token_num += sum(tokens.attention_mask)
+
+        return token_num
+
 
 class TextEmbeddingWorker(EmbeddingWorker[T]):
     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index 117f5af7..0c58a7f3 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -212,3 +212,21 @@ def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyAr
 
         # This is model-specific, so that different models can have specialized implementations
         yield from self.model.passage_embed(texts, **kwargs)
+
+    def token_count(
+        self, texts: Union[str, Iterable[str]], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        """Returns the number of tokens in the texts.
+
+        Args:
+            texts (str | Iterable[str]): The texts to count tokens for.
+            batch_size (int): Batch size for tokenization.
+
+        Returns:
+            int: Total number of tokens across the texts.
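+
+        Example (illustrative usage; exact counts depend on the model's tokenizer):
+            >>> model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2")
+            >>> model.token_count(["hello world"])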
+        """
+        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
diff --git a/fastembed/text/text_embedding_base.py b/fastembed/text/text_embedding_base.py
index 75df9ac5..a11ecae0 100644
--- a/fastembed/text/text_embedding_base.py
+++ b/fastembed/text/text_embedding_base.py
@@ -69,3 +69,7 @@ def get_embedding_size(cls, model_name: str) -> int:
     def embedding_size(self) -> int:
         """Returns embedding size for the current model"""
         raise NotImplementedError("Subclasses must implement this method")
+
+    def token_count(self, texts: Union[str, Iterable[str]], **kwargs: Any) -> int:
+        """Returns the number of tokens in the texts."""
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py
index f2499db8..ea83e76a 100644
--- a/tests/test_late_interaction_embeddings.py
+++ b/tests/test_late_interaction_embeddings.py
@@ -318,3 +318,27 @@ def test_session_options(model_cache, model_name) -> None:
         model = LateInteractionTextEmbedding(model_name=model_name, enable_cpu_mem_arena=False)
         session_options = model.model.model.get_session_options()
         assert session_options.enable_cpu_mem_arena is False
+
+
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_token_count(model_cache, model_name) -> None:
+    with model_cache(model_name) as model:
+        documents = ["short doc", "it is a long document to check attention mask for padding"]
+        short_doc_token_count = model.token_count(documents[0])
+        long_doc_token_count = model.token_count(documents[1])
+        documents_token_count = model.token_count(documents)
+        assert short_doc_token_count + long_doc_token_count == documents_token_count
+        # +2 accounts for one DOC_MARKER token per document
+        assert short_doc_token_count + long_doc_token_count + 2 == model.token_count(
+            documents, include_extension=True
+        )
+        assert short_doc_token_count + long_doc_token_count == model.token_count(
+            documents, batch_size=1
+        )
+        assert short_doc_token_count + long_doc_token_count == model.token_count(
+            documents, is_doc=False
+        )
+        # each query is padded to the minimum query length of 32 tokens
+        assert model.token_count(documents, is_doc=False, include_extension=True) == 64
+        very_long_query = "It's a very long query which definitely contains more than 32 tokens and we're using it to check whether the method can handle a long query properly without truncating it to 32 tokens"
+        assert model.token_count(very_long_query, is_doc=False, include_extension=True) > 32
diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py
index 80135f3b..8a102ace 100644
--- a/tests/test_late_interaction_multimodal.py
+++ b/tests/test_late_interaction_multimodal.py
@@ -101,3 +101,24 @@ def test_embedding_size():
     model_name = "Qdrant/ColPali-v1.3-fp16"
     model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True)
     assert model.embedding_size == 128
+
+
+def test_token_count() -> None:
+    if os.getenv("CI"):
+        pytest.skip("ColPali is too large to test in CI")
+    model_name = "Qdrant/colpali-v1.3-fp16"
+    model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True)
+
+    documents = ["short doc", "it is a long document to check attention mask for padding"]
+    short_doc_token_count = model.token_count(documents[0])
+    long_doc_token_count = model.token_count(documents[1])
+    documents_token_count = model.token_count(documents)
+    assert short_doc_token_count + long_doc_token_count == documents_token_count
+    assert short_doc_token_count + long_doc_token_count == model.token_count(
+        documents, batch_size=1
+    )
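+    # tokenize() wraps texts in ColPali's query prompt, so counting with
+    # include_extension=True should yield strictly more tokens per text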
+    assert short_doc_token_count + long_doc_token_count < model.token_count(
+        documents, include_extension=True
+    )
diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py
index 4c02a683..c2a7e2ff 100644
--- a/tests/test_sparse_embeddings.py
+++ b/tests/test_sparse_embeddings.py
@@ -298,3 +298,25 @@ def test_session_options(model_cache, model_name) -> None:
         model = SparseTextEmbedding(model_name=model_name, enable_cpu_mem_arena=False)
         session_options = model.model.model.get_session_options()
         assert session_options.enable_cpu_mem_arena is False
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "prithivida/Splade_PP_en_v1",
+        "Qdrant/minicoil-v1",
+        "Qdrant/bm42-all-minilm-l6-v2-attentions",
+        "Qdrant/bm25",
+    ],
+)
+def test_token_count(model_cache, model_name) -> None:
+    with model_cache(model_name) as model:
+        documents = [
+            "Name me a couple of cities that were the capitals of Germany?",
+            "Berlin is the current capital of Germany, Bonn is a former capital of Germany.",
+        ]
+        first_doc_token_count = model.token_count(documents[0])
+        second_doc_token_count = model.token_count(documents[1])
+        doc_token_count = model.token_count(documents)
+        assert first_doc_token_count + second_doc_token_count == doc_token_count
+        assert doc_token_count == model.token_count(documents, batch_size=1)
diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py
index d23ee8ef..4d0d5b7d 100644
--- a/tests/test_text_cross_encoder.py
+++ b/tests/test_text_cross_encoder.py
@@ -124,6 +124,25 @@ def test_rerank_pairs_parallel(model_cache, model_name: str) -> None:
     ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}"
 
 
+@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
+def test_token_count(model_cache, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        pairs = [
+            ("What is the capital of France?", "Paris is the capital of France."),
+            (
+                "Name me a couple of cities that were the capitals of Germany?",
+                "Berlin is the current capital of Germany, Bonn is a former capital of Germany.",
+            ),
+        ]
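+        # each (query, document) pair is tokenized as one joint sequence,
+        # so its count covers both sides of the pair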
+        first_pair_token_count = model.token_count([pairs[0]])
+        second_pair_token_count = model.token_count([pairs[1]])
+        pairs_token_count = model.token_count(pairs)
+        assert first_pair_token_count + second_pair_token_count == pairs_token_count
+        assert pairs_token_count == model.token_count(pairs, batch_size=1)
+
+
 @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"])
 def test_session_options(model_cache, model_name) -> None:
     with model_cache(model_name) as default_model:
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index 43e88ca8..e919faf9 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -203,3 +203,17 @@ def test_session_options(model_cache, model_name) -> None:
         model = TextEmbedding(model_name=model_name, enable_cpu_mem_arena=False)
         session_options = model.model.model.get_session_options()
         assert session_options.enable_cpu_mem_arena is False
+
+
+@pytest.mark.parametrize("model_name", ["sentence-transformers/all-MiniLM-L6-v2"])
+def test_token_count(model_cache, model_name) -> None:
+    with model_cache(model_name) as model:
+        documents = [
+            "Name me a couple of cities that were the capitals of Germany?",
+            "Berlin is the current capital of Germany, Bonn is a former capital of Germany.",
+        ]
+        first_doc_token_count = model.token_count(documents[0])
+        second_doc_token_count = model.token_count(documents[1])
+        doc_token_count = model.token_count(documents)
+        assert first_doc_token_count + second_doc_token_count == doc_token_count
+        assert doc_token_count == model.token_count(documents, batch_size=1)