diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index d465e870..d357f2c1 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -21,6 +21,7 @@ class OnnxOutputContext: model_output: NumpyArray attention_mask: NDArray[np.int64] | None = None input_ids: NDArray[np.int64] | None = None + metadata: dict[str, Any] | None = None class OnnxModel(Generic[T]): diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index efbc3e25..3b702f79 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -50,9 +50,10 @@ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]: tokenizer = Tokenizer.from_file(str(tokenizer_path)) tokenizer.enable_truncation(max_length=max_context) - tokenizer.enable_padding( - pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"] - ) + if not tokenizer.padding: + tokenizer.enable_padding( + pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"] + ) for token in tokens_map.values(): if isinstance(token, str): diff --git a/fastembed/image/onnx_image_model.py b/fastembed/image/onnx_image_model.py index 86326da9..deddcf73 100644 --- a/fastembed/image/onnx_image_model.py +++ b/fastembed/image/onnx_image_model.py @@ -76,9 +76,11 @@ def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]: return {input_name: encoded} def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: - with contextlib.ExitStack(): + with contextlib.ExitStack() as stack: image_files = [ - Image.open(image) if not isinstance(image, Image.Image) else image + stack.enter_context(Image.open(image)) + if not isinstance(image, Image.Image) + else image for image in images ] assert self.processor is not None, "Processor is not initialized" diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index b06ef46c..9d9e2197 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -145,3 +145,77 @@ def pad2square( new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image) return new_image + + +def resize_longest_edge( + image: Image.Image, + max_size: int, + resample: int | Image.Resampling = Image.Resampling.LANCZOS, +) -> Image.Image: + height, width = image.height, image.width + aspect_ratio = width / height + + if width >= height: + # Width is longer + new_width = max_size + new_height = int(new_width / aspect_ratio) + else: + # Height is longer + new_height = max_size + new_width = int(new_height * aspect_ratio) + + # Ensure even dimensions + if new_height % 2 != 0: + new_height += 1 + if new_width % 2 != 0: + new_width += 1 + + return image.resize((new_width, new_height), resample) + + +def crop_ndarray( + image: NumpyArray, + x1: int, + y1: int, + x2: int, + y2: int, + channel_first: bool = True, +) -> NumpyArray: + if channel_first: + # (C, H, W) format + return image[:, y1:y2, x1:x2] + else: + # (H, W, C) format + return image[y1:y2, x1:x2, :] + + +def resize_ndarray( + image: NumpyArray, + size: tuple[int, int], + resample: int | Image.Resampling = Image.Resampling.LANCZOS, + channel_first: bool = True, +) -> NumpyArray: + # Convert to PIL-friendly format (H, W, C) + if channel_first: + img_hwc = image.transpose((1, 2, 0)) + else: + img_hwc = image + + # Handle different dtypes + if img_hwc.dtype 
== np.float32 or img_hwc.dtype == np.float64: + # Assume normalized, scale to 0-255 for PIL + img_hwc_scaled = (img_hwc * 255).astype(np.uint8) + pil_img = Image.fromarray(img_hwc_scaled, mode="RGB") + resized = pil_img.resize(size, resample) + result = np.array(resized).astype(np.float32) / 255.0 + else: + # uint8 or similar + pil_img = Image.fromarray(img_hwc.astype(np.uint8), mode="RGB") + resized = pil_img.resize(size, resample) + result = np.array(resized) + + # Convert back to original format + if channel_first: + result = result.transpose((2, 0, 1)) + + return result diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 857b1999..e6ba4d95 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -1,4 +1,5 @@ from typing import Any +import math from PIL import Image @@ -6,10 +7,13 @@ from fastembed.image.transform.functional import ( center_crop, convert_to_rgb, + crop_ndarray, normalize, pil2ndarray, rescale, resize, + resize_longest_edge, + resize_ndarray, pad2square, ) @@ -37,8 +41,18 @@ def __init__(self, mean: float | list[float], std: float | list[float]): self.mean = mean self.std = std - def __call__(self, images: list[NumpyArray]) -> list[NumpyArray]: - return [normalize(image, mean=self.mean, std=self.std) for image in images] + def __call__( # type: ignore[override] + self, images: list[NumpyArray] | list[list[NumpyArray]] + ) -> list[NumpyArray] | list[list[NumpyArray]]: + if images and isinstance(images[0], list): + # Nested structure from ImageSplitter + return [ + [normalize(image, mean=self.mean, std=self.std) for image in img_patches] # type: ignore[arg-type] + for img_patches in images + ] + else: + # Flat structure (backward compatibility) + return [normalize(image, mean=self.mean, std=self.std) for image in images] # type: ignore[arg-type] class Resize(Transform): @@ -58,8 +72,18 @@ class Rescale(Transform): def __init__(self, scale: float = 1 / 255): self.scale = scale - def __call__(self, images: list[NumpyArray]) -> list[NumpyArray]: - return [rescale(image, scale=self.scale) for image in images] + def __call__( # type: ignore[override] + self, images: list[NumpyArray] | list[list[NumpyArray]] + ) -> list[NumpyArray] | list[list[NumpyArray]]: + if images and isinstance(images[0], list): + # Nested structure from ImageSplitter + return [ + [rescale(image, scale=self.scale) for image in img_patches] # type: ignore[arg-type] + for img_patches in images + ] + else: + # Flat structure (backward compatibility) + return [rescale(image, scale=self.scale) for image in images] # type: ignore[arg-type] class PILtoNDarray(Transform): @@ -82,6 +106,167 @@ def __call__(self, images: list[Image.Image]) -> list[Image.Image]: ] +class ResizeLongestEdge(Transform): + """Resize images so the longest edge equals target size, preserving aspect ratio.""" + + def __init__( + self, + size: int, + resample: Image.Resampling = Image.Resampling.LANCZOS, + ): + self.size = size + self.resample = resample + + def __call__(self, images: list[Image.Image]) -> list[Image.Image]: + return [resize_longest_edge(image, self.size, self.resample) for image in images] + + +class ResizeForVisionEncoder(Transform): + """ + Resize both dimensions to be multiples of vision_encoder_max_size. + Preserves aspect ratio approximately. + Works on numpy arrays in (C, H, W) format. 
+ """ + + def __init__( + self, + max_size: int, + resample: Image.Resampling = Image.Resampling.LANCZOS, + ): + self.max_size = max_size + self.resample = resample + + def __call__(self, images: list[NumpyArray]) -> list[NumpyArray]: + result = [] + for image in images: + # Assume (C, H, W) format + _, height, width = image.shape + + aspect_ratio = width / height + + if width >= height: + # Calculate new width as multiple of max_size + new_width = math.ceil(width / self.max_size) * self.max_size + new_height = int(new_width / aspect_ratio) + new_height = math.ceil(new_height / self.max_size) * self.max_size + else: + # Calculate new height as multiple of max_size + new_height = math.ceil(height / self.max_size) * self.max_size + new_width = int(new_height * aspect_ratio) + new_width = math.ceil(new_width / self.max_size) * self.max_size + + # Resize using the ndarray resize function + resized = resize_ndarray( + image, + size=(new_width, new_height), # PIL expects (width, height) + resample=self.resample, + channel_first=True, + ) + result.append(resized) + + return result + + +class ImageSplitter(Transform): + """ + Split images into grid of patches plus a global view. + + If image dimensions exceed max_size: + - Divide into ceil(H/max_size) x ceil(W/max_size) patches + - Each patch is cropped from the image + - Add a global view (original resized to max_size x max_size) + + If image is smaller than max_size: + - Return single image unchanged + + Works on numpy arrays in (C, H, W) format. + """ + + def __init__( + self, + max_size: int, + resample: Image.Resampling = Image.Resampling.LANCZOS, + ): + self.max_size = max_size + self.resample = resample + + def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]: # type: ignore[override] + result = [] + + for image in images: + # Assume (C, H, W) format + _, height, width = image.shape + max_height = max_width = self.max_size + + frames = [] + + if height > max_height or width > max_width: + # Calculate the number of splits needed + num_splits_h = math.ceil(height / max_height) + num_splits_w = math.ceil(width / max_width) + + # Calculate optimal patch dimensions + optimal_height = math.ceil(height / num_splits_h) + optimal_width = math.ceil(width / num_splits_w) + + # Generate patches in grid order (row by row) + for r in range(num_splits_h): + for c in range(num_splits_w): + # Calculate crop coordinates + start_x = c * optimal_width + start_y = r * optimal_height + end_x = min(start_x + optimal_width, width) + end_y = min(start_y + optimal_height, height) + + # Crop the patch + cropped = crop_ndarray( + image, x1=start_x, y1=start_y, x2=end_x, y2=end_y, channel_first=True + ) + frames.append(cropped) + + # Add global view (resized to max_size x max_size) + global_view = resize_ndarray( + image, + size=(max_width, max_height), # PIL expects (width, height) + resample=self.resample, + channel_first=True, + ) + frames.append(global_view) + else: + # Image is small enough, no splitting needed + frames.append(image) + + # Append (not extend) to preserve per-image grouping + result.append(frames) + + return result + + +class SquareResize(Transform): + """ + Resize images to square dimensions (max_size x max_size). + Works on numpy arrays in (C, H, W) format. 
+ """ + + def __init__( + self, + size: int, + resample: Image.Resampling = Image.Resampling.LANCZOS, + ): + self.size = size + self.resample = resample + + def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]: # type: ignore[override] + return [ + [ + resize_ndarray( + image, size=(self.size, self.size), resample=self.resample, channel_first=True + ) + ] + for image in images + ] + + class Compose: def __init__(self, transforms: list[Transform]): self.transforms = transforms @@ -118,6 +303,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": Valid size keys (nested): - {"height", "width"} - {"shortest_edge"} + - {"longest_edge"} Returns: Compose: Image processor. @@ -128,6 +314,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": cls._get_pad2square(transforms, config) cls._get_center_crop(transforms, config) cls._get_pil2ndarray(transforms, config) + cls._get_image_splitting(transforms, config) cls._get_rescale(transforms, config) cls._get_normalize(transforms, config) return cls(transforms=transforms) @@ -196,6 +383,25 @@ def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> Non resample=resample, ) ) + elif mode == "Idefics3ImageProcessor": + if config.get("do_resize", False): + size = config.get("size", {}) + if "longest_edge" not in size: + raise ValueError( + "Size dictionary must contain 'longest_edge' key for Idefics3ImageProcessor" + ) + + # Handle resample parameter - can be int enum or PIL.Image.Resampling + resample = config.get("resample", Image.Resampling.LANCZOS) + if isinstance(resample, int): + resample = Image.Resampling(resample) + + transforms.append( + ResizeLongestEdge( + size=size["longest_edge"], + resample=resample, + ) + ) else: raise ValueError(f"Preprocessor {mode} is not supported") @@ -217,6 +423,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> Non pass elif mode == "JinaCLIPImageProcessor": pass + elif mode == "Idefics3ImageProcessor": + pass else: raise ValueError(f"Preprocessor {mode} is not supported") @@ -224,6 +432,28 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> Non def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]) -> None: transforms.append(PILtoNDarray()) + @classmethod + def _get_image_splitting(cls, transforms: list[Transform], config: dict[str, Any]) -> None: + """ + Add image splitting transforms for Idefics3. + Handles conditional logic: splitting vs square resize. + Must be called AFTER PILtoNDarray. 
+        """
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+
+        if mode == "Idefics3ImageProcessor":
+            do_splitting = config.get("do_image_splitting", False)
+            max_size = config.get("max_image_size", {}).get("longest_edge", 512)
+            resample = config.get("resample", Image.Resampling.LANCZOS)
+            if isinstance(resample, int):
+                resample = Image.Resampling(resample)
+
+            if do_splitting:
+                transforms.append(ResizeForVisionEncoder(max_size, resample))
+                transforms.append(ImageSplitter(max_size, resample))
+            else:
+                transforms.append(SquareResize(max_size, resample))
+
     @staticmethod
     def _get_rescale(transforms: list[Transform], config: dict[str, Any]) -> None:
         if config.get("do_rescale", True):
diff --git a/fastembed/late_interaction_multimodal/colmodernvbert.py b/fastembed/late_interaction_multimodal/colmodernvbert.py
new file mode 100644
index 00000000..20b8e4f7
--- /dev/null
+++ b/fastembed/late_interaction_multimodal/colmodernvbert.py
@@ -0,0 +1,532 @@
+import contextlib
+from typing import Any, Iterable, Type, Optional, Sequence
+import json
+
+import numpy as np
+from tokenizers import Encoding
+from PIL import Image
+
+from fastembed.common import ImageInput
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.types import NumpyArray, OnnxProvider
+from fastembed.common.utils import define_cache_dir, iter_batch
+from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
+    LateInteractionMultimodalEmbeddingBase,
+)
+from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
+    OnnxMultimodalModel,
+    TextEmbeddingWorker,
+    ImageEmbeddingWorker,
+)
+
+supported_colmodernvbert_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="Qdrant/colmodernvbert",
+        dim=128,
+        description="The late-interaction version of ModernVBERT, CPU friendly, English, 2025.",
+        license="mit",
+        size_in_GB=1.0,
+        sources=ModelSource(hf="Qdrant/colmodernvbert"),
+        additional_files=["processor_config.json"],
+        model_file="model.onnx",
+    ),
+]
+
+
+class ColModernVBERT(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]):
+    """
+    The ModernVBERT/colmodernvbert model implementation. This model uses
+    bidirectional attention, which has proven to work better for retrieval.
+
+    See: https://huggingface.co/ModernVBERT/colmodernvbert
+    """
+
+    VISUAL_PROMPT_PREFIX = (
+        "<|begin_of_text|>User:<image>Describe the image.\nAssistant:"
+    )
+    QUERY_AUGMENTATION_TOKEN = ""
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[list[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
+        specific_model_path: Optional[str] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                Defaults to `fastembed_cache` in the system's temp directory.
+            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. 
Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        """
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+        self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load the model in the current process
+        self.device_id: Optional[int] = None
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+
+        self.model_description = self._get_model_description(model_name)
+        self.cache_dir = str(define_cache_dir(cache_dir))
+
+        self._specific_model_path = specific_model_path
+        self._model_dir = self.download_model(
+            self.model_description,
+            self.cache_dir,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+        )
+        self.mask_token_id = None
+        self.pad_token_id = None
+        self.image_seq_len: Optional[int] = None
+        self.max_image_size: Optional[int] = None
+        self.image_size: Optional[int] = None
+
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        """Lists the supported models.
+
+        Returns:
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+        """
+        return supported_colmodernvbert_models
+
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
+            model_dir=self._model_dir,
+            model_file=self.model_description.model_file,
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
+            extra_session_options=self._extra_session_options,
+        )
+
+        # Load image processing configuration
+        processor_config_path = self._model_dir / "processor_config.json"
+        with open(processor_config_path) as f:
+            processor_config = json.load(f)
+        self.image_seq_len = processor_config.get("image_seq_len", 64)
+
+        preprocessor_config_path = self._model_dir / "preprocessor_config.json"
+        with open(preprocessor_config_path) as f:
+            preprocessor_config = json.load(f)
+        self.max_image_size = preprocessor_config.get("max_image_size", {}).get(
+            "longest_edge", 512
+        )
+
+        # Load model configuration
+        config_path = self._model_dir / "config.json"
+        with open(config_path) as f:
+            model_config = json.load(f)
+        vision_config = model_config.get("vision_config", {})
+        self.image_size = vision_config.get("image_size", 512)
+
+    def _preprocess_onnx_text_input(
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray]:
+        """
+        Add placeholder image inputs for text-only batches.
+
+        The exported ONNX graph always expects `pixel_values`, so text-only input
+        is padded with an all-zeros image tensor of the expected shape.
+
+        Args:
+            onnx_input (dict[str, NumpyArray]): Tokenized text inputs for the ONNX model.
+
+        Returns:
+            dict[str, NumpyArray]: The same inputs with a placeholder `pixel_values` entry. 
+ """ + batch_size, seq_length = onnx_input["input_ids"].shape + empty_image_placeholder: NumpyArray = np.zeros( + (batch_size, seq_length, 3, self.image_size, self.image_size), + dtype=np.float32, # type: ignore[type-var,arg-type,assignment] + ) + onnx_input["pixel_values"] = empty_image_placeholder + return onnx_input + + def _post_process_onnx_text_output( + self, + output: OnnxOutputContext, + ) -> Iterable[NumpyArray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[NumpyArray]: Post-processed output as NumPy arrays. + """ + return output.model_output + + def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: + # Add query augmentation tokens (matching process_queries logic from colpali-engine) + augmented_queries = [doc + self.QUERY_AUGMENTATION_TOKEN * 10 for doc in documents] + encoded = self.tokenizer.encode_batch(augmented_queries) # type: ignore[union-attr] + return encoded + + def token_count( + self, + texts: str | Iterable[str], + batch_size: int = 1024, + include_extension: bool = False, + **kwargs: Any, + ) -> int: + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() # loads the tokenizer as well + token_num = 0 + texts = [texts] if isinstance(texts, str) else texts + assert self.tokenizer is not None + tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch + for batch in iter_batch(texts, batch_size): + token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)]) + return token_num + + def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: + with contextlib.ExitStack() as stack: + image_files = [ + stack.enter_context(Image.open(image)) + if not isinstance(image, Image.Image) + else image + for image in images + ] + assert self.processor is not None, "Processor is not initialized" + processed = self.processor(image_files) + encoded, attention_mask, metadata = self._process_nested_patches(processed) # type: ignore[arg-type] + + onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask} + onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs) + model_output = self.model.run(None, onnx_input) # type: ignore[union-attr] + + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=attention_mask, # type: ignore[arg-type] + metadata=metadata, + ) + + @staticmethod + def _process_nested_patches( + processed: list[list[NumpyArray]], + ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]: + """ + Process nested image patches (from ImageSplitter). + + Args: + processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...] 
+ + Returns: + tuple: (encoded array, attention_mask, metadata) + - encoded: (batch_size, max_patches, C, H, W) + - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding + - metadata: Dict with 'patch_counts' key + """ + patch_counts = [len(patches) for patches in processed] + max_patches = max(patch_counts) + + # Get dimensions from first patch + channels, height, width = processed[0][0].shape + batch_size = len(processed) + + # Create padded array + encoded = np.zeros( + (batch_size, max_patches, channels, height, width), dtype=processed[0][0].dtype + ) + + # Create attention mask (1 for real patches, 0 for padding) + attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64) + + # Fill in patches and attention mask + for i, patches in enumerate(processed): + for j, patch in enumerate(patches): + encoded[i, j] = patch + attention_mask[i, j] = 1 + + metadata = {"patch_counts": patch_counts} + return encoded, attention_mask, metadata # type: ignore[return-value] + + def _preprocess_onnx_image_input( + self, onnx_input: dict[str, np.ndarray], **kwargs: Any + ) -> dict[str, NumpyArray]: + """ + Add text input placeholders for image data, following Idefics3 processing logic. + + Constructs input_ids dynamically based on the actual number of image patches, + using the same token expansion logic as Idefics3Processor. + + Args: + onnx_input: Dict with 'pixel_values' (batch, num_patches, C, H, W) + and 'attention_mask' (batch, num_patches) indicating real patches + **kwargs: Additional arguments + + Returns: + Updated onnx_input with 'input_ids' and updated 'attention_mask' for token sequence + """ + # The attention_mask in onnx_input has a shape of (batch_size, num_patches), + # and should be used to create an attention mask matching the input_ids shape. 
+        patch_attention_mask = onnx_input["attention_mask"]
+        pixel_values = onnx_input["pixel_values"]
+
+        batch_size = pixel_values.shape[0]
+        batch_input_ids = []
+
+        # Build input_ids for each image based on its actual patch count
+        for i in range(batch_size):
+            # Count real patches (non-padded) from the attention mask
+            patch_count = int(np.sum(patch_attention_mask[i]))
+
+            # Compute rows/cols from the patch count
+            rows, cols = self._compute_rows_cols_from_patches(patch_count)
+
+            # Build input_ids for this image
+            input_ids = self._build_input_ids_for_image(rows, cols)
+            batch_input_ids.append(input_ids)
+
+        # Pad sequences to the max length in the batch
+        max_len = max(len(ids) for ids in batch_input_ids)
+
+        # Get padding config from the tokenizer
+        padding_direction = self.tokenizer.padding["direction"]  # type: ignore[index,union-attr]
+        pad_token_id = self.tokenizer.padding["pad_id"]  # type: ignore[index,union-attr]
+
+        # Initialize with the pad token
+        padded_input_ids = np.full((batch_size, max_len), pad_token_id, dtype=np.int64)
+        attention_mask = np.zeros((batch_size, max_len), dtype=np.int64)
+
+        for i, input_ids in enumerate(batch_input_ids):
+            seq_len = len(input_ids)
+            if padding_direction == "left":
+                # Left padding: place tokens at the END of the array
+                start_idx = max_len - seq_len
+                padded_input_ids[i, start_idx:] = input_ids
+                attention_mask[i, start_idx:] = 1
+            else:
+                # Right padding: place tokens at the START of the array
+                padded_input_ids[i, :seq_len] = input_ids
+                attention_mask[i, :seq_len] = 1
+
+        onnx_input["input_ids"] = padded_input_ids
+        # Replace the patch-level attention_mask with the token-level one
+        onnx_input["attention_mask"] = attention_mask
+        return onnx_input
+
+    @staticmethod
+    def _compute_rows_cols_from_patches(patch_count: int) -> tuple[int, int]:
+        if patch_count <= 1:
+            return 0, 0
+
+        # Subtract 1 for the global image
+        grid_patches = patch_count - 1
+
+        # Find rows and cols (assume a square or near-square grid)
+        rows = int(grid_patches**0.5)
+        cols = grid_patches // rows
+
+        # Verify the calculation
+        if rows * cols + 1 != patch_count:
+            # Handle non-square grids
+            for r in range(1, grid_patches + 1):
+                if grid_patches % r == 0:
+                    c = grid_patches // r
+                    if r * c + 1 == patch_count:
+                        return r, c
+            # Fallback: treat as unsplit
+            return 0, 0
+
+        return rows, cols
+
+    def _create_single_image_prompt_string(self) -> str:
+        return (
+            "<fake_token_around_image>"
+            + "<global-img>"
+            + "<image>" * self.image_seq_len  # type: ignore[operator]
+            + "<fake_token_around_image>"
+        )
+
+    def _create_split_image_prompt_string(self, rows: int, cols: int) -> str:
+        text_split_images = ""
+
+        # Add tokens for each patch in the grid
+        for n_h in range(rows):
+            for n_w in range(cols):
+                text_split_images += (
+                    "<fake_token_around_image>"
+                    + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                    + "<image>" * self.image_seq_len  # type: ignore[operator]
+                )
+            text_split_images += "\n"
+
+        # Add the global image at the end
+        text_split_images += (
+            "\n<fake_token_around_image>"
+            + "<global-img>"
+            + "<image>" * self.image_seq_len  # type: ignore[operator]
+            + "<fake_token_around_image>"
+        )
+
+        return text_split_images
+
+    def _build_input_ids_for_image(self, rows: int, cols: int) -> np.ndarray:
+        # Create the appropriate image prompt string
+        if rows == 0 and cols == 0:
+            image_prompt_tokens = self._create_single_image_prompt_string()
+        else:
+            image_prompt_tokens = self._create_split_image_prompt_string(rows, cols)
+
+        # Replace <image> in the visual prompt with the expanded image tokens
+        # The visual prompt is: "<|begin_of_text|>User:<image>Describe the image.\nAssistant:"
+        expanded_prompt = self.VISUAL_PROMPT_PREFIX.replace("<image>", image_prompt_tokens)
+
+        # Tokenize the complete prompt
+        encoded = self.tokenizer.encode(expanded_prompt)  # type: 
ignore[union-attr] + + # Convert to numpy array + return np.array(encoded.ids, dtype=np.int64) + + def _post_process_onnx_image_output( + self, + output: OnnxOutputContext, + ) -> Iterable[NumpyArray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[NumpyArray]: Post-processed output as NumPy arrays. + """ + assert self.model_description.dim is not None, "Model dim is not defined" + return output.model_output.reshape( + output.model_output.shape[0], -1, self.model_description.dim + ) + + def embed_text( + self, + documents: str | Iterable[str], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """ + Encode a list of documents into list of embeddings. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=documents, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + local_files_only=self._local_files_only, + specific_model_path=self._specific_model_path, + extra_session_options=self._extra_session_options, + **kwargs, + ) + + def embed_image( + self, + images: ImageInput | Iterable[ImageInput], + batch_size: int = 16, + parallel: Optional[int] = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """ + Encode a list of images into list of embeddings. + + Args: + images: Iterator of image paths or single image path to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. 
+ + Returns: + List of embeddings, one per document + """ + yield from self._embed_images( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + images=images, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + local_files_only=self._local_files_only, + specific_model_path=self._specific_model_path, + extra_session_options=self._extra_session_options, + **kwargs, + ) + + @classmethod + def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]: + return ColModernVBERTTextEmbeddingWorker + + @classmethod + def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]: + return ColModernVBERTImageEmbeddingWorker + + +class ColModernVBERTTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT: + return ColModernVBERT( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) + + +class ColModernVBERTImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT: + return ColModernVBERT( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py index afe839d4..10d426d0 100644 --- a/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +++ b/fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py @@ -4,6 +4,7 @@ from fastembed.common import OnnxProvider, ImageInput from fastembed.common.types import NumpyArray, Device from fastembed.late_interaction_multimodal.colpali import ColPali +from fastembed.late_interaction_multimodal.colmodernvbert import ColModernVBERT from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( LateInteractionMultimodalEmbeddingBase, @@ -12,7 +13,10 @@ class LateInteractionMultimodalEmbedding(LateInteractionMultimodalEmbeddingBase): - EMBEDDINGS_REGISTRY: list[Type[LateInteractionMultimodalEmbeddingBase]] = [ColPali] + EMBEDDINGS_REGISTRY: list[Type[LateInteractionMultimodalEmbeddingBase]] = [ + ColPali, + ColModernVBERT, + ] @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: diff --git a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py index 18b36338..93436895 100644 --- a/fastembed/late_interaction_multimodal/onnx_multimodal_model.py +++ b/fastembed/late_interaction_multimodal/onnx_multimodal_model.py @@ -170,9 +170,11 @@ def _embed_documents( yield from self._post_process_onnx_text_output(batch) # type: ignore def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: - with contextlib.ExitStack(): + with contextlib.ExitStack() as stack: image_files = [ - Image.open(image) if not isinstance(image, Image.Image) else image + stack.enter_context(Image.open(image)) + if not isinstance(image, Image.Image) + else image for image in images ] assert self.processor is not None, "Processor is not initialized" diff --git a/tests/test_late_interaction_multimodal.py b/tests/test_late_interaction_multimodal.py index 8a102ace..94ae47e7 100644 --- a/tests/test_late_interaction_multimodal.py +++ b/tests/test_late_interaction_multimodal.py @@ -1,4 +1,5 @@ import os +from contextlib 
import contextmanager import pytest from PIL import Image @@ -6,7 +7,7 @@ from fastembed import LateInteractionMultimodalEmbedding from tests.config import TEST_MISC_DIR - +from tests.utils import delete_model_cache # vectors are abridged and rounded for brevity CANONICAL_IMAGE_VALUES = { @@ -21,6 +22,17 @@ [-0.1299, -0.0691, 0.1097, 0.0728, 0.0123, 0.0519, 0.0122], ] ), + "Qdrant/colmodernvbert": np.array( + [ + [0.11614, -0.15793, -0.11194, 0.0688, 0.08001, 0.10575, -0.07871], + [0.10094, -0.13301, -0.12069, 0.10932, 0.04645, 0.09884, 0.04048], + [0.13106, -0.18613, -0.13469, 0.10566, 0.03659, 0.07712, -0.03916], + [0.09754, -0.09596, -0.04839, 0.14991, 0.05692, 0.10569, -0.08349], + [0.02576, -0.15651, -0.09977, 0.09707, 0.13412, 0.09994, -0.09931], + [-0.06741, -0.1787, -0.19677, -0.07618, 0.13102, -0.02131, -0.02437], + [-0.02776, -0.10187, -0.13793, 0.03835, 0.04766, 0.04701, -0.15635], + ] + ), } CANONICAL_QUERY_VALUES = { @@ -35,6 +47,17 @@ [-0.0165, -0.0106, 0.1672, -0.0768, 0.0389, -0.0038, 0.1137], ] ), + "Qdrant/colmodernvbert": np.array( + [ + [0.05, 0.06557, 0.04026, 0.14981, 0.1842, 0.0263, -0.18706], + [-0.05664, -0.14028, 0.00649, -0.02849, 0.09034, -0.01494, 0.10693], + [-0.10147, -0.00716, 0.09084, -0.08236, -0.01849, -0.00972, -0.00461], + [-0.1233, -0.10814, -0.02337, -0.00329, 0.05984, 0.09934, 0.09846], + [-0.07053, -0.13119, -0.06487, 0.01508, 0.07459, 0.07655, 0.14821], + [0.00526, -0.13842, -0.05837, -0.02721, 0.13009, 0.05076, 0.17962], + [0.00924, -0.14383, -0.03057, -0.03691, 0.11718, 0.037, 0.13344], + ] + ), } queries = ["hello world", "flag embedding"] @@ -44,43 +67,69 @@ Image.open((TEST_MISC_DIR / "image.jpeg")), ] +_MODELS_TO_CACHE = ("Qdrant/colmodernvbert",) +MODELS_TO_CACHE = tuple(model_name.lower() for model_name in _MODELS_TO_CACHE) -def test_batch_embedding(): - if os.getenv("CI"): - pytest.skip("Colpali is too large to test in CI") - for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): - print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = list(model.embed_image(images, batch_size=2)) +@pytest.fixture(scope="module") +def model_cache(): + is_ci = os.getenv("CI") + cache = {} - for value in result: - token_num, abridged_dim = expected_result.shape - assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3) + @contextmanager + def get_model(model_name: str): + lowercase_model_name = model_name.lower() + if lowercase_model_name not in cache: + cache[lowercase_model_name] = LateInteractionMultimodalEmbedding(lowercase_model_name) + yield cache[lowercase_model_name] + if lowercase_model_name not in MODELS_TO_CACHE: + model_inst = cache.pop(lowercase_model_name) + if is_ci: + delete_model_cache(model_inst.model._model_dir) + del model_inst + + yield get_model + if is_ci: + for _, model in cache.items(): + delete_model_cache(model.model._model_dir) + cache.clear() -def test_single_embedding(): - if os.getenv("CI"): - pytest.skip("Colpali is too large to test in CI") +def test_batch_embedding(model_cache): for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + if model_name.lower() == "Qdrant/colpali-v1.3-fp16".lower() and os.getenv("CI"): + continue # colpali is too large for ci + print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = next(iter(model.embed_image(images, batch_size=6))) - token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:token_num, :abridged_dim], 
expected_result, atol=2e-3) + with model_cache(model_name) as model: + result = list(model.embed_image(images, batch_size=2)) + + for value in result: + token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3) + +def test_single_embedding(model_cache): + for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): + if model_name.lower() == "Qdrant/colpali-v1.3-fp16".lower() and os.getenv("CI"): + continue # colpali is too large for ci + print("evaluating", model_name) + with model_cache(model_name) as model: + result = next(iter(model.embed_image(images, batch_size=6))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) -def test_single_embedding_query(): - if os.getenv("CI"): - pytest.skip("Colpali is too large to test in CI") +def test_single_embedding_query(model_cache): for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + if model_name.lower() == "Qdrant/colpali-v1.3-fp16".lower() and os.getenv("CI"): + continue # colpali is too large for ci print("evaluating", model_name) - model = LateInteractionMultimodalEmbedding(model_name=model_name) - result = next(iter(model.embed_text(queries))) - token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + with model_cache(model_name) as model: + result = next(iter(model.embed_text(queries))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) def test_get_embedding_size(): @@ -90,33 +139,27 @@ def test_get_embedding_size(): model_name = "Qdrant/ColPali-v1.3-fp16" assert LateInteractionMultimodalEmbedding.get_embedding_size(model_name) == 128 + model_name = "Qdrant/colmodernvbert" + assert LateInteractionMultimodalEmbedding.get_embedding_size(model_name) == 128 -def test_embedding_size(): - if os.getenv("CI"): - pytest.skip("Colpali is too large to test in CI") - model_name = "Qdrant/colpali-v1.3-fp16" - model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) - assert model.embedding_size == 128 - model_name = "Qdrant/ColPali-v1.3-fp16" +def test_embedding_size(): + model_name = "Qdrant/colmodernvbert" model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) assert model.embedding_size == 128 -def test_token_count() -> None: - if os.getenv("CI"): - pytest.skip("Colpali is too large to test in CI") - model_name = "Qdrant/colpali-v1.3-fp16" - model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) - - documents = ["short doc", "it is a long document to check attention mask for paddings"] - short_doc_token_count = model.token_count(documents[0]) - long_doc_token_count = model.token_count(documents[1]) - documents_token_count = model.token_count(documents) - assert short_doc_token_count + long_doc_token_count == documents_token_count - assert short_doc_token_count + long_doc_token_count == model.token_count( - documents, batch_size=1 - ) - assert short_doc_token_count + long_doc_token_count < model.token_count( - documents, include_extension=True - ) +def test_token_count(model_cache) -> None: + model_name = "Qdrant/colmodernvbert" + with model_cache(model_name) as model: + documents = ["short doc", "it is a long document to check attention mask for paddings"] + short_doc_token_count = model.token_count(documents[0]) + long_doc_token_count = 
model.token_count(documents[1]) + documents_token_count = model.token_count(documents) + assert short_doc_token_count + long_doc_token_count == documents_token_count + assert short_doc_token_count + long_doc_token_count == model.token_count( + documents, batch_size=1 + ) + assert short_doc_token_count + long_doc_token_count < model.token_count( + documents, include_extension=True + )
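
Example usage of the model added in this PR (a minimal sketch: the model name, methods, and 128-dim late-interaction output come from the diff above, while the file path and query string are placeholders):

    from fastembed import LateInteractionMultimodalEmbedding

    # Downloads the ONNX model on first use.
    model = LateInteractionMultimodalEmbedding("Qdrant/colmodernvbert")

    # Late-interaction embeddings: one (num_tokens, 128) matrix per image and per query.
    image_embeddings = list(model.embed_image(["path/to/page.jpeg"], batch_size=2))
    query_embeddings = list(model.embed_text(["what is shown in the figure?"]))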