diff --git a/fastembed/late_interaction_multimodal/colpali.py b/fastembed/late_interaction_multimodal/colpali.py index 85fbcd06..3a5c5faf 100644 --- a/fastembed/late_interaction_multimodal/colpali.py +++ b/fastembed/late_interaction_multimodal/colpali.py @@ -36,9 +36,10 @@ class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyA BOS_TOKEN = "" PAD_TOKEN = "" QUERY_MARKER_TOKEN_ID = [2, 5098] + IMAGE_TOKEN_ID = 257152 # The '' special token IMAGE_PLACEHOLDER_SIZE = (3, 448, 448) EMPTY_TEXT_PLACEHOLDER = np.array( - [257152] * 1024 + [2, 50721, 573, 2416, 235265, 108] + [IMAGE_TOKEN_ID] * 1024 + [2, 50721, 573, 2416, 235265, 108] ) # This is a tokenization of '' * 1024 + 'Describe the image.\n' line which is used as placeholder # while processing an image EVEN_ATTENTION_MASK = np.array([1] * 1030) @@ -298,6 +299,39 @@ def embed_image( **kwargs, ) + def get_image_mask( + self, + images: ImageInput | Iterable[ImageInput], + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate image token masks for ColPali embeddings. + + For ColPali, image embeddings use 1030 tokens: + - Tokens 0-1023: Image tokens (token ID 257152) + - Tokens 1024-1029: Text tokens from prompt "Describe the image.\\n" + + Args: + images: Single image or iterable of images + **kwargs: Additional processing arguments (reserved for future use) + + Returns: + List of binary masks (dtype=bool) where True = image token (ID 257152), False = other tokens. + """ + from pathlib import Path + + # Ensure images is iterable + is_single = isinstance(images, (str, bytes, Path)) or hasattr(images, "read") + images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[assignment, list-item] + + # Generate masks - all images get the same mask based on fixed tokenization pattern + masks: list[NumpyArray] = [] + for _ in images_to_process: + mask: NumpyArray = self.EMPTY_TEXT_PLACEHOLDER == self.IMAGE_TOKEN_ID + masks.append(mask) + + return masks + @classmethod def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]: return ColPaliTextEmbeddingWorker diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index afe839d4..5106737e 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -183,3 +183,36 @@ def token_count( return self.model.token_count( texts, batch_size=batch_size, include_extension=include_extension, **kwargs ) + + def get_image_mask( + self, + images: ImageInput | Iterable[ImageInput], + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate binary masks identifying image tokens in processed image sequences. + + This method processes images and returns masks indicating which tokens in the + resulting sequence correspond to image content (value=1) vs text/special tokens (value=0). + + Args: + images: Single image or iterable of images (file paths, bytes, or PIL Image objects) + **kwargs: Additional keyword arguments (reserved for future use) + + Returns: + List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,) + where sequence_length is the number of tokens in the processed image representation. + Values are True for image tokens, False for non-image tokens (text, special tokens, etc.). + + Raises: + NotImplementedError: If the underlying model doesn't support image mask generation. + + Example: + ```python + model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16") + masks = model.get_image_mask(["image1.jpg", "image2.jpg"]) + # masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali + # First 1024 values are True (image tokens), last 6 are False (text tokens) + ``` + """ + return self.model.get_image_mask(images, **kwargs) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py index 72a87fe5..f883ceef 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py @@ -84,3 +84,39 @@ def token_count( ) -> int: """Returns the number of tokens in the texts.""" raise NotImplementedError("Subclasses must implement this method") + + def get_image_mask( + self, + images: ImageInput | Iterable[ImageInput], + **kwargs: Any, + ) -> list[NumpyArray]: + """ + Generate binary masks identifying image tokens in processed image sequences. + + This method processes images and returns masks indicating which tokens in the + resulting sequence correspond to image content (value=1) vs text/special tokens (value=0). + + Args: + images: Single image or iterable of images (file paths, bytes, or PIL Image objects) + **kwargs: Additional keyword arguments (reserved for future use) + + Returns: + List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,) + where sequence_length is the number of tokens in the processed image representation. + Values are True for image tokens, False for non-image tokens (text, special tokens, etc.). + + Raises: + NotImplementedError: If the model doesn't support image mask generation. + + Example: + ```python + model = ColPali(model_name="Qdrant/colpali-v1.3-fp16") + masks = model.get_image_mask(["image1.jpg", "image2.jpg"]) + # masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali + # First 1024 values are True (image tokens), last 6 are False (text tokens) + ``` + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support image mask generation. " + "Override this method in subclasses to provide model-specific implementation." + ) diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 8a102ace..fea4afd3 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -120,3 +120,64 @@ def test_token_count() -> None: assert short_doc_token_count + long_doc_token_count < model.token_count( documents, include_extension=True ) + + +def test_colpali_image_mask(): + """Test that get_image_mask returns correct masks for image tokens.""" + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16") + + # Get mask for single image + masks = model.get_image_mask([images[0]]) + + assert len(masks) == 1, "Should return one mask per image" + mask = masks[0] + + # ColPali uses 1030 tokens total: 1024 image + 6 text + assert mask.shape == (1030,), f"Expected shape (1030,), got {mask.shape}" + assert mask.dtype == np.bool_, f"Expected bool dtype, got {mask.dtype}" + + # First 1024 tokens should be image tokens (value=True) + assert np.all(mask[:1024]), "First 1024 tokens should be image tokens (True)" + + # Last 6 tokens should be text tokens (value=False) + assert np.all(~mask[1024:]), "Last 6 tokens should be text tokens (False)" + + # Test with multiple images + masks = model.get_image_mask([images[0], images[1]]) + assert len(masks) == 2, "Should return two masks for two images" + assert all(m.shape == (1030,) for m in masks), "All masks should have same shape" + + +def test_colpali_image_mask_single_image(): + """Test get_image_mask with a single image (not in a list).""" + if os.getenv("CI"): + pytest.skip("Colpali is too large to test in CI") + + model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16") + + # Pass single image without list + masks = model.get_image_mask(images[0]) + + assert len(masks) == 1, "Should return one mask for single image" + assert masks[0].shape == (1030,), "Mask should have correct shape" + + +def test_base_class_raises_not_implemented(): + """Test that base class raises NotImplementedError.""" + from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( + LateInteractionMultimodalEmbeddingBase, + ) + + # Create a minimal subclass that doesn't implement get_image_mask + class MinimalModel(LateInteractionMultimodalEmbeddingBase): + pass + + model = MinimalModel(model_name="test", cache_dir="/tmp") + + with pytest.raises(NotImplementedError) as exc_info: + model.get_image_mask(["dummy.jpg"]) + + assert "does not support image mask generation" in str(exc_info.value)