From 5abd25a9128fad5ff607a4455694e84690f33209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 15 Dec 2025 11:53:26 +0100 Subject: [PATCH 1/5] feat: implement get_image_mask for LateInteractionMultimodalEmbeddingBase --- .../late_interaction_multimodal/colpali.py | 48 ++++++++++++++- .../late_interaction_multimodal_embedding.py | 35 +++++++++++ ...e_interaction_multimodal_embedding_base.py | 38 ++++++++++++ tests/test_late_interaction_multimodal.py | 61 +++++++++++++++++++ 4 files changed, 181 insertions(+), 1 deletion(-) diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 85fbcd06..418cba22 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -36,9 +36,10 @@ class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyA BOS_TOKEN = "" PAD_TOKEN = "" QUERY_MARKER_TOKEN_ID = [2, 5098] + IMAGE_TOKEN_ID = 257152 # The '' special token IMAGE_PLACEHOLDER_SIZE = (3, 448, 448) EMPTY_TEXT_PLACEHOLDER = np.array( - [257152] * 1024 + [2, 50721, 573, 2416, 235265, 108] + [IMAGE_TOKEN_ID] * 1024 + [2, 50721, 573, 2416, 235265, 108] ) # This is a tokenization of '' * 1024 + 'Describe the image.\n' line which is used as placeholder # while processing an image EVEN_ATTENTION_MASK = np.array([1] * 1030) @@ -298,6 +299,51 @@ def embed_image( **kwargs, ) + def get_image_mask( + self, + images: Union[ImageInput, Iterable[ImageInput]], + batch_size: int = 16, + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate image token masks for ColPali embeddings. 
+ + For ColPali, image embeddings use 1030 tokens: + - Tokens 0-1023: Image tokens (token ID 257152) + - Tokens 1024-1029: Text tokens from prompt "Describe the image.\\n" + + Args: + images: Single image or iterable of images + batch_size: Batch size for processing + **kwargs: Additional processing arguments + + Returns: + List of binary masks (dtype=bool) where True = image token (ID 257152), False = other tokens. + """ + from pathlib import Path + + # Ensure images is iterable + is_single = isinstance(images, (str, bytes, Path)) or hasattr(images, "read") + images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[list-item] + + # Process images in batches to get input_ids + masks: list[NumpyArray] = [] + images_list = list(images_to_process) + for batch_start in range(0, len(images_list), batch_size): + batch = images_list[batch_start : batch_start + batch_size] + + # Load the model if not already loaded + if self.model is None: + self.load_onnx_model() + + # For ColPali images, input_ids follow EMPTY_TEXT_PLACEHOLDER pattern + # Generate mask: True for image tokens (ID 257152), False for others + for _ in batch: + mask: NumpyArray = self.EMPTY_TEXT_PLACEHOLDER == self.IMAGE_TOKEN_ID + masks.append(mask) + + return masks + @classmethod def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]: return ColPaliTextEmbeddingWorker diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index afe839d4..9dab0b71 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -183,3 +183,38 @@ def token_count( return self.model.token_count( texts, batch_size=batch_size, include_extension=include_extension, **kwargs ) + + def get_image_mask( + self, + images: Union[ImageInput, 
Iterable[ImageInput]], + batch_size: int = 16, + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate binary masks identifying image tokens in processed image sequences. + + This method processes images and returns masks indicating which tokens in the + resulting sequence correspond to image content (value=1) vs text/special tokens (value=0). + + Args: + images: Single image or iterable of images (file paths, bytes, or PIL Image objects) + batch_size: Number of images to process in each batch. Defaults to 16. + **kwargs: Additional keyword arguments for image processing. + + Returns: + List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,) + where sequence_length is the number of tokens in the processed image representation. + Values are True for image tokens, False for non-image tokens (text, special tokens, etc.). + + Raises: + NotImplementedError: If the underlying model doesn't support image mask generation. + + Example: + ```python + model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16") + masks = model.get_image_mask(["image1.jpg", "image2.jpg"]) + # masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali + # First 1024 values are True (image tokens), last 6 are False (text tokens) + ``` + """ + return self.model.get_image_mask(images, batch_size, **kwargs) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 72a87fe5..1de6e8b4 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -84,3 +84,41 @@ def token_count( ) -> int: """Returns the number of tokens in the texts.""" raise NotImplementedError("Subclasses must implement this method") + + def get_image_mask( + self, + images: Union[ImageInput, 
Iterable[ImageInput]], + batch_size: int = 16, + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate binary masks identifying image tokens in processed image sequences. + + This method processes images and returns masks indicating which tokens in the + resulting sequence correspond to image content (value=1) vs text/special tokens (value=0). + + Args: + images: Single image or iterable of images (file paths, bytes, or PIL Image objects) + batch_size: Number of images to process in each batch. Defaults to 16. + **kwargs: Additional keyword arguments for image processing. + + Returns: + List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,) + where sequence_length is the number of tokens in the processed image representation. + Values are True for image tokens, False for non-image tokens (text, special tokens, etc.). + + Raises: + NotImplementedError: If the model doesn't support image mask generation. + + Example: + ```python + model = ColPali.load("Qdrant/colpali-v1.3-fp16") + masks = model.get_image_mask(["image1.jpg", "image2.jpg"]) + # masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali + # First 1024 values are True (image tokens), last 6 are False (text tokens) + ``` + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support image mask generation. " + "Override this method in subclasses to provide model-specific implementation." 
+ ) diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 8a102ace..fea4afd3 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -120,3 +120,64 @@ def test_token_count() -> None: assert short_doc_token_count + long_doc_token_count < model.token_count( documents, include_extension=True ) + + +def test_colpali_image_mask(): + """Test that get_image_mask returns correct masks for image tokens.""" + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16") + + # Get mask for single image + masks = model.get_image_mask([images[0]]) + + assert len(masks) == 1, "Should return one mask per image" + mask = masks[0] + + # ColPali uses 1030 tokens total: 1024 image + 6 text + assert mask.shape == (1030,), f"Expected shape (1030,), got {mask.shape}" + assert mask.dtype == np.bool_, f"Expected bool dtype, got {mask.dtype}" + + # First 1024 tokens should be image tokens (value=True) + assert np.all(mask[:1024]), "First 1024 tokens should be image tokens (True)" + + # Last 6 tokens should be text tokens (value=False) + assert np.all(~mask[1024:]), "Last 6 tokens should be text tokens (False)" + + # Test with multiple images + masks = model.get_image_mask([images[0], images[1]]) + assert len(masks) == 2, "Should return two masks for two images" + assert all(m.shape == (1030,) for m in masks), "All masks should have same shape" + + +def test_colpali_image_mask_single_image(): + """Test get_image_mask with a single image (not in a list).""" + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16") + + # Pass single image without list + masks = model.get_image_mask(images[0]) + + assert len(masks) == 1, "Should return one mask for single image" + assert masks[0].shape == (1030,), 
"Mask should have correct shape" + + +def test_base_class_raises_not_implemented(): + """Test that base class raises NotImplementedError.""" + from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, + ) + + # Create a minimal subclass that doesn't implement get_image_mask + class MinimalModel(LateInteractionMultimodalEmbeddingBase): + pass + + model = MinimalModel(model_name="test", cache_dir="/tmp") + + with pytest.raises(NotImplementedError) as exc_info: + model.get_image_mask(["dummy.jpg"]) + + assert "does not support image mask generation" in str(exc_info.value) From 5e0af4700ae77ac6ca889463e70f5937900322fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 15 Dec 2025 12:02:17 +0100 Subject: [PATCH 2/5] fix: do not load the model for mask generation in colpali.py --- fastembed/late_interaction_multimodal/colpali.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 418cba22..b05a1e94 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -332,10 +332,6 @@ def get_image_mask( for batch_start in range(0, len(images_list), batch_size): batch = images_list[batch_start : batch_start + batch_size] - # Load the model if not already loaded - if self.model is None: - self.load_onnx_model() - # For ColPali images, input_ids follow EMPTY_TEXT_PLACEHOLDER pattern # Generate mask: True for image tokens (ID 257152), False for others for _ in batch: From 986c6ba639b1f3af0ca823d9e40e0f4913991482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 15 Dec 2025 12:22:48 +0100 Subject: [PATCH 3/5] remove Union usage --- fastembed/late_interaction_multimodal/colpali.py | 2 +- .../late_interaction_multimodal_embedding.py | 2 +- .../late_interaction_multimodal_embedding_base.py | 2 
+- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index b05a1e94..bd75b9fe 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -301,7 +301,7 @@ def embed_image( def get_image_mask( self, - images: Union[ImageInput, Iterable[ImageInput]], + images: ImageInput | Iterable[ImageInput], batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index 9dab0b71..8332bc42 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -186,7 +186,7 @@ def token_count( def get_image_mask( self, - images: Union[ImageInput, Iterable[ImageInput]], + images: ImageInput | Iterable[ImageInput], batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 1de6e8b4..bc542139 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -87,7 +87,7 @@ def token_count( def get_image_mask( self, - images: Union[ImageInput, Iterable[ImageInput]], + images: ImageInput | Iterable[ImageInput], batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: From 06f1829f3bc818913a7ece82f7014f2af0f6dfe8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 15 Dec 2025 12:30:20 +0100 Subject: [PATCH 4/5] suppress mypy warnings --- fastembed/late_interaction_multimodal/colpali.py | 2 +- 1 file
changed, 1 insertion(+), 1 deletion(-) diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index bd75b9fe..675f46d3 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -324,7 +324,7 @@ def get_image_mask( # Ensure images is iterable is_single = isinstance(images, (str, bytes, Path)) or hasattr(images, "read") - images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[list-item] + images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[assignment, list-item] # Process images in batches to get input_ids masks: list[NumpyArray] = [] From c13576a51a765f67c6ce1170eda3bc15c7b3a4fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Mon, 15 Dec 2025 13:06:24 +0100 Subject: [PATCH 5/5] chore: remove batch_size --- .../late_interaction_multimodal/colpali.py | 18 +++++------------- .../late_interaction_multimodal_embedding.py | 6 ++---- ...te_interaction_multimodal_embedding_base.py | 6 ++---- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 675f46d3..3a5c5faf 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -302,7 +302,6 @@ def embed_image( def get_image_mask( self, images: ImageInput | Iterable[ImageInput], - batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: """ @@ -314,8 +313,7 @@ def get_image_mask( Args: images: Single image or iterable of images - batch_size: Batch size for processing - **kwargs: Additional processing arguments + **kwargs: Additional processing arguments (reserved for future use) Returns: List of binary masks (dtype=bool) where True = image token (ID 257152), False = other tokens. 
@@ -326,17 +324,11 @@ def get_image_mask( is_single = isinstance(images, (str, bytes, Path)) or hasattr(images, "read") images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[assignment, list-item] - # Process images in batches to get input_ids + # Generate masks - all images get the same mask based on fixed tokenization pattern masks: list[NumpyArray] = [] - images_list = list(images_to_process) - for batch_start in range(0, len(images_list), batch_size): - batch = images_list[batch_start : batch_start + batch_size] - - # For ColPali images, input_ids follow EMPTY_TEXT_PLACEHOLDER pattern - # Generate mask: True for image tokens (ID 257152), False for others - for _ in batch: - mask: NumpyArray = self.EMPTY_TEXT_PLACEHOLDER == self.IMAGE_TOKEN_ID - masks.append(mask) + for _ in images_to_process: + mask: NumpyArray = self.EMPTY_TEXT_PLACEHOLDER == self.IMAGE_TOKEN_ID + masks.append(mask) return masks diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index 8332bc42..5106737e 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -187,7 +187,6 @@ def token_count( def get_image_mask( self, images: ImageInput | Iterable[ImageInput], - batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: """ @@ -198,8 +197,7 @@ def get_image_mask( Args: images: Single image or iterable of images (file paths, bytes, or PIL Image objects) - batch_size: Number of images to process in each batch. Defaults to 16. - **kwargs: Additional keyword arguments for image processing. + **kwargs: Additional keyword arguments (reserved for future use) Returns: List of binary masks (numpy arrays with dtype=bool), one per image. 
Each mask has shape (sequence_length,) @@ -217,4 +215,4 @@ def get_image_mask( # First 1024 values are True (image tokens), last 6 are False (text tokens) ``` """ - return self.model.get_image_mask(images, batch_size, **kwargs) + return self.model.get_image_mask(images, **kwargs) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index bc542139..f883ceef 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -88,7 +88,6 @@ def token_count( def get_image_mask( self, images: ImageInput | Iterable[ImageInput], - batch_size: int = 16, **kwargs: Any, ) -> list[NumpyArray]: """ @@ -99,8 +98,7 @@ def get_image_mask( Args: images: Single image or iterable of images (file paths, bytes, or PIL Image objects) - batch_size: Number of images to process in each batch. Defaults to 16. - **kwargs: Additional keyword arguments for image processing. + **kwargs: Additional keyword arguments (reserved for future use) Returns: List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,) @@ -112,7 +110,7 @@ def get_image_mask( Example: ```python - model = ColPali.load("Qdrant/colpali-v1.3-fp16") + model = ColPali(model_name="Qdrant/colpali-v1.3-fp16") masks = model.get_image_mask(["image1.jpg", "image2.jpg"]) # masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali # First 1024 values are True (image tokens), last 6 are False (text tokens)