From 70678dd6bd49549c50fde702f87cf1e6c8efcfac Mon Sep 17 00:00:00 2001 From: Devasena Inupakutika Date: Wed, 21 Jan 2026 07:05:38 -0800 Subject: [PATCH 1/2] Optimized vector generation for VDB Benchmark --- vdb_benchmark/README.md | 72 +++- vdb_benchmark/pyproject.toml | 3 +- vdb_benchmark/vdbbench/load_vdb.py | 658 ++++++++++++++++++++++++++--- 3 files changed, 680 insertions(+), 53 deletions(-) diff --git a/vdb_benchmark/README.md b/vdb_benchmark/README.md index e8ea20e4..9d874d05 100644 --- a/vdb_benchmark/README.md +++ b/vdb_benchmark/README.md @@ -56,12 +56,82 @@ The benchmark process consists of three main steps: ### Step 1: Load Vectors into the Database Use the load_vdb.py script to generate and load 10 million vectors into your vector database: (this process can take up to 8 hours) + +#### Default/ Standard Mode + +##### Basic execution with config file ```bash python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml ``` +##### With explicit parameters (no config) +```bash +python vdbbench/load_vdb.py --collection-name benchmark_test \ + --dimension 1536 \ + --num-vectors 1000000 \ + --batch-size 10000 +``` + +##### Override config values +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml \ + --collection-name custom_collection \ + --num-vectors 500000 \ + --force +``` + +##### With reproducible seed +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml \ + --seed 42 +``` + +#### Adaptive Mode (Memory-Aware Batch Sizing) + +##### Enable adaptive batching (auto-scales based on memory pressure) +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/100m_diskann.yaml \ + --adaptive +``` + +##### With explicit memory budget +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/100m_diskann.yaml \ + --adaptive \ + --memory-budget 4G +``` +##### Adaptive with smaller budget for constrained systems +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/100m_diskann.yaml \ + --adaptive \ + --memory-budget 2G \ + --batch-size 5000 +``` + +#### Disk-Backed Mode (Billion-Scale / Low Memory) + +##### Use memory-mapped temp file (default temp directory) +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/1b_diskann.yaml \ + --disk-backed +``` + +##### Specify fast NVMe for temp storage +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/1b_diskann.yaml \ + --disk-backed \ + --temp-dir /mnt/nvme/tmp +``` + +##### Disk-backed with seed for reproducibility +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/1b_diskann.yaml \ + --disk-backed \ + --temp-dir /mnt/nvme/tmp \ + --seed 12345 +``` -For testing, I recommend using a smaller data by passing the num_vectors option: +For testing, we recommend using a smaller data by passing the num_vectors option: ```bash python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml --collection-name mlps_500k_10shards_1536dim_uniform_diskann --num-vectors 500000 ``` diff --git a/vdb_benchmark/pyproject.toml b/vdb_benchmark/pyproject.toml index f4d56d8f..97b4c5d1 100644 --- a/vdb_benchmark/pyproject.toml +++ b/vdb_benchmark/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "pandas", "pymilvus", "pyyaml", - "tabulate" + "tabulate", + "psutil" ] [project.urls] diff --git a/vdb_benchmark/vdbbench/load_vdb.py b/vdb_benchmark/vdbbench/load_vdb.py index 0a7a9324..b524e182 100644 --- a/vdb_benchmark/vdbbench/load_vdb.py +++ b/vdb_benchmark/vdbbench/load_vdb.py @@ -1,8 +1,14 @@ import argparse +import gc import logging -import sys +import mmap import os +import sys +import tempfile import time +from pathlib import Path +from typing import Dict, Any, Optional, Generator, Tuple + import numpy as np from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility @@ -11,6 +17,13 @@ from vdbbench.config_loader import load_config, merge_config_with_args from vdbbench.compact_and_watch import monitor_progress +# Optional psutil for adaptive mode +try: + import psutil + PSUTIL_AVAILABLE = True +except ImportError: + PSUTIL_AVAILABLE = False + # Configure logging logging.basicConfig( level=logging.INFO, @@ -18,6 +31,168 @@ ) logger = logging.getLogger(__name__) + +# ============================================================================= +# Memory Management Utilities +# ============================================================================= + +def parse_memory_string(mem_str: str) -> int: + """Parse memory string like '4G', '512M' to bytes.""" + if isinstance(mem_str, (int, float)): + return int(mem_str) + if not mem_str: + return 0 + mem_str = str(mem_str).strip().upper() + multipliers = {'B': 1, 'K': 1024, 'M': 1024**2, 'G': 1024**3, 'T': 1024**4} + if mem_str[-1] in multipliers: + return int(float(mem_str[:-1]) * multipliers[mem_str[-1]]) + return int(mem_str) + + +def get_memory_percent() -> float: + """Get current memory usage percentage.""" + if PSUTIL_AVAILABLE: + return psutil.virtual_memory().percent + return 50.0 # Default if psutil not available + + +def get_available_memory() -> int: + """Get available memory in bytes.""" + if PSUTIL_AVAILABLE: + return psutil.virtual_memory().available + return 8 * 1024**3 # Assume 8GB if psutil not available + + +class AdaptiveBatchController: + """ + Adaptive batch size controller based on memory pressure. + Only active when --adaptive flag is used. + """ + + def __init__(self, initial_batch_size: int, + min_batch_size: int = 100, + max_batch_size: int = 100000, + memory_threshold: float = 80.0): + self.current_batch_size = initial_batch_size + self.initial_batch_size = initial_batch_size + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + self.memory_threshold = memory_threshold + self.adjustment_count = 0 + self.batches_since_adjustment = 0 + + def get_batch_size(self) -> int: + """Get current batch size, adjusting based on memory if needed.""" + if not PSUTIL_AVAILABLE: + return self.current_batch_size + + self.batches_since_adjustment += 1 + mem_percent = get_memory_percent() + + # Scale down if memory pressure + if mem_percent > self.memory_threshold: + new_size = max(self.min_batch_size, int(self.current_batch_size * 0.5)) + if new_size < self.current_batch_size: + logger.info(f"[Adaptive] Memory at {mem_percent:.1f}%, reducing batch: " + f"{self.current_batch_size} -> {new_size}") + self.current_batch_size = new_size + self.adjustment_count += 1 + self.batches_since_adjustment = 0 + gc.collect() + # Scale up if plenty of headroom + elif (self.batches_since_adjustment > 10 and + mem_percent < self.memory_threshold - 25 and + self.current_batch_size < self.initial_batch_size): + new_size = min(self.initial_batch_size, int(self.current_batch_size * 1.25)) + if new_size > self.current_batch_size: + logger.info(f"[Adaptive] Memory at {mem_percent:.1f}%, increasing batch: " + f"{self.current_batch_size} -> {new_size}") + self.current_batch_size = new_size + self.adjustment_count += 1 + self.batches_since_adjustment = 0 + + return self.current_batch_size + + def force_scale_down(self): + """Force scale down after an error.""" + new_size = max(self.min_batch_size, int(self.current_batch_size * 0.5)) + if new_size < self.current_batch_size: + logger.info(f"[Adaptive] Forcing batch reduction: {self.current_batch_size} -> {new_size}") + self.current_batch_size = new_size + self.adjustment_count += 1 + gc.collect() + + +class DiskBackedBuffer: + """ + Memory-mapped file buffer for disk-backed vector generation. + Only used when --disk-backed flag is specified. + """ + + def __init__(self, dimension: int, max_vectors: int, + temp_dir: Optional[str] = None): + self.dimension = dimension + self.max_vectors = max_vectors + self.dtype_size = 4 # float32 + self.vector_size = dimension * self.dtype_size + self.file_size = self.vector_size * max_vectors + + # Create temp file + temp_dir_path = Path(temp_dir) if temp_dir else Path(tempfile.gettempdir()) + temp_dir_path.mkdir(parents=True, exist_ok=True) + + self.temp_file = tempfile.NamedTemporaryFile( + dir=temp_dir_path, prefix='vdbbench_', suffix='.mmap', delete=False + ) + self.temp_path = Path(self.temp_file.name) + + # Pre-allocate + self.temp_file.seek(self.file_size - 1) + self.temp_file.write(b'\0') + self.temp_file.flush() + + # Memory map + self.mmap = mmap.mmap(self.temp_file.fileno(), self.file_size) + self.vectors_stored = 0 + + logger.info(f"Created disk buffer: {self.temp_path} ({self.file_size / (1024**3):.2f} GB)") + + def write_batch(self, vectors: np.ndarray, start_index: int): + """Write vectors to disk buffer.""" + start_offset = start_index * self.vector_size + end_offset = start_offset + len(vectors) * self.vector_size + self.mmap[start_offset:end_offset] = vectors.astype(np.float32).tobytes() + self.vectors_stored = max(self.vectors_stored, start_index + len(vectors)) + + def read_batch(self, start_index: int, count: int) -> np.ndarray: + """Read vectors from disk buffer.""" + start_offset = start_index * self.vector_size + end_offset = start_offset + count * self.vector_size + data = self.mmap[start_offset:end_offset] + return np.frombuffer(data, dtype=np.float32).reshape(count, self.dimension) + + def cleanup(self): + """Clean up resources.""" + try: + self.mmap.close() + self.temp_file.close() + if self.temp_path.exists(): + self.temp_path.unlink() + logger.info(f"Cleaned up disk buffer: {self.temp_path}") + except Exception as e: + logger.warning(f"Error cleaning up disk buffer: {e}") + + def __enter__(self): + return self + + def __exit__(self, *args): + self.cleanup() + + +# ============================================================================= +# Argument Parsing +# ============================================================================= + def parse_args(): parser = argparse.ArgumentParser(description="Load vectors into Milvus database") @@ -39,6 +214,7 @@ def parse_args(): choices=["uniform", "normal"], help="Distribution for vector generation") parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for insertion") parser.add_argument("--chunk-size", type=int, default=1000000, help="Number of vectors to generate in each chunk (for memory management)") + parser.add_argument("--seed", type=int, help="Random seed for reproducible vector generation") # Index parameters parser.add_argument("--index-type", type=str, default="DISKANN", help="Index type") @@ -48,6 +224,16 @@ def parse_args(): parser.add_argument("--M", type=int, default=16, help="HNSW M parameter") parser.add_argument("--ef-construction", type=int, default=200, help="HNSW efConstruction parameter") + # Memory optimization parameters + parser.add_argument("--adaptive", action="store_true", + help="Enable adaptive batch sizing based on memory pressure") + parser.add_argument("--memory-budget", type=str, default="0", + help="Memory budget (e.g., 4G, 512M). Default: auto") + parser.add_argument("--disk-backed", action="store_true", + help="Use disk-backed buffer for memory-constrained systems") + parser.add_argument("--temp-dir", type=str, + help="Temp directory for disk-backed mode (fast NVMe recommended)") + # Monitoring parameters parser.add_argument("--monitor-interval", type=int, default=5, help="Interval in seconds for monitoring index building") parser.add_argument("--compact", action="store_true", help="Perform compaction after loading") @@ -82,7 +268,11 @@ def parse_args(): 'compact': not args.compact, # Default is False 'force': not args.force, # Default is False 'what_if': not args.what_if, # Default is False - 'debug': not args.debug # Default is False + 'debug': not args.debug, # Default is False + 'adaptive': not args.adaptive, # Default is False + 'memory_budget': args.memory_budget == "0", + 'disk_backed': not args.disk_backed, # Default is False + 'seed': args.seed is None, } # Set logging level to DEBUG if --debug is specified @@ -118,6 +308,10 @@ def parse_args(): return args +# ============================================================================= +# Milvus Connection and Collection Management +# ============================================================================= + def connect_to_milvus(host, port): """Connect to Milvus server""" try: @@ -169,30 +363,85 @@ def create_collection(collection_name, dim, num_shards, vector_dtype, force=Fals return None -def generate_vectors(num_vectors, dim, distribution='uniform'): - """Generate random vectors based on the specified distribution""" +# ============================================================================= +# Vector Generation (Enhanced with reproducibility and memory efficiency) +# ============================================================================= + +def generate_vectors(num_vectors: int, dim: int, distribution: str = 'uniform', + seed: Optional[int] = None, batch_index: int = 0) -> np.ndarray: + """ + Generate random vectors based on the specified distribution. + + Args: + num_vectors: Number of vectors to generate + dim: Vector dimension + distribution: 'uniform' or 'normal' + seed: Optional seed for reproducibility (combined with batch_index) + batch_index: Batch index for deterministic seeding across batches + + Returns: + numpy array of shape (num_vectors, dim), normalized + """ + # Use seeded RNG if seed provided (enables reproducibility) + if seed is not None: + rng = np.random.default_rng(seed + batch_index) + else: + rng = np.random.default_rng() + if distribution == 'uniform': - vectors = np.random.random((num_vectors, dim)).astype('float16') + vectors = rng.uniform(-1, 1, (num_vectors, dim)).astype(np.float32) elif distribution == 'normal': - vectors = np.random.normal(0, 1, (num_vectors, dim)).astype('float16') + vectors = rng.standard_normal((num_vectors, dim)).astype(np.float32) elif distribution == 'zipfian': # Simplified zipfian-like distribution - base = np.random.random((num_vectors, dim)).astype('float16') - skew = np.random.zipf(1.5, (num_vectors, 1)).astype('float16') + base = rng.uniform(0, 1, (num_vectors, dim)).astype(np.float32) + skew = rng.zipf(1.5, (num_vectors, 1)).astype(np.float32) vectors = base * (skew / 10) else: - vectors = np.random.random((num_vectors, dim)).astype('float16') + vectors = rng.uniform(-1, 1, (num_vectors, dim)).astype(np.float32) # Normalize vectors norms = np.linalg.norm(vectors, axis=1, keepdims=True) + # Avoid division by zero + norms = np.where(norms == 0, 1, norms) normalized_vectors = vectors / norms - return normalized_vectors.tolist() + return normalized_vectors -def insert_data(collection, vectors, batch_size=10000): +def generate_vectors_streaming(total_vectors: int, dimension: int, + batch_size: int, + distribution: str = 'uniform', + seed: Optional[int] = None) -> Generator[Tuple[int, np.ndarray], None, None]: + """ + Generator that yields vectors in batches without bulk allocation. + + This is the memory-efficient version that generates vectors on-demand. + + Yields: + Tuple of (start_id, vectors_array) + """ + num_batches = (total_vectors + batch_size - 1) // batch_size + vectors_remaining = total_vectors + current_id = 0 + + for batch_idx in range(num_batches): + current_batch_size = min(batch_size, vectors_remaining) + vectors = generate_vectors(current_batch_size, dimension, distribution, seed, batch_idx) + + yield current_id, vectors + + current_id += current_batch_size + vectors_remaining -= current_batch_size + + +# ============================================================================= +# Data Insertion Functions +# ============================================================================= + +def insert_data(collection, vectors, batch_size=10000, start_id=0): """Insert vectors into the collection in batches""" - total_vectors = len(vectors) + total_vectors = len(vectors) if isinstance(vectors, (list, np.ndarray)) else vectors.shape[0] num_batches = (total_vectors + batch_size - 1) // batch_size start_time = time.time() @@ -204,8 +453,11 @@ def insert_data(collection, vectors, batch_size=10000): batch_size_actual = batch_end - batch_start # Prepare batch data - ids = list(range(batch_start, batch_end)) - batch_vectors = vectors[batch_start:batch_end] + ids = list(range(start_id + batch_start, start_id + batch_end)) + if isinstance(vectors, np.ndarray): + batch_vectors = vectors[batch_start:batch_end].tolist() + else: + batch_vectors = vectors[batch_start:batch_end] # Insert batch try: @@ -216,9 +468,10 @@ def insert_data(collection, vectors, batch_size=10000): progress = total_inserted / total_vectors * 100 elapsed = time.time() - start_time rate = total_inserted / elapsed if elapsed > 0 else 0 + mem_info = f", Mem: {get_memory_percent():.1f}%" if PSUTIL_AVAILABLE else "" logger.info(f"Inserted batch {i+1}/{num_batches}: {progress:.2f}% complete, " - f"rate: {rate:.2f} vectors/sec") + f"rate: {rate:.2f} vectors/sec{mem_info}") except Exception as e: logger.error(f"Error inserting batch {i+1}: {str(e)}") @@ -226,8 +479,235 @@ def insert_data(collection, vectors, batch_size=10000): return total_inserted, time.time() - start_time +def insert_data_standard(collection, num_vectors: int, dimension: int, + distribution: str, batch_size: int, + seed: Optional[int] = None) -> Dict[str, Any]: + """ + Standard vector loading - generates and inserts vectors in batches. + """ + logger.info(f"Loading {num_vectors:,} vectors in batches of {batch_size:,} (standard mode)") + + start_time = time.time() + vectors_loaded = 0 + batch_idx = 0 + total_gen_time = 0.0 + total_insert_time = 0.0 + + for batch_start in range(0, num_vectors, batch_size): + batch_end = min(batch_start + batch_size, num_vectors) + current_batch_size = batch_end - batch_start + + # Generate vectors + gen_start = time.time() + vectors = generate_vectors(current_batch_size, dimension, distribution, seed, batch_idx) + total_gen_time += time.time() - gen_start + + # Prepare data + ids = list(range(batch_start, batch_end)) + data = [ids, vectors.tolist()] + + # Insert + insert_start = time.time() + collection.insert(data) + total_insert_time += time.time() - insert_start + + vectors_loaded += current_batch_size + batch_idx += 1 + + # Progress reporting + if batch_idx % 100 == 0 or vectors_loaded == num_vectors: + elapsed = time.time() - start_time + rate = vectors_loaded / elapsed if elapsed > 0 else 0 + progress = (vectors_loaded / num_vectors) * 100 + mem_info = f", Mem: {get_memory_percent():.1f}%" if PSUTIL_AVAILABLE else "" + logger.info(f"Progress: {vectors_loaded:,}/{num_vectors:,} ({progress:.1f}%) - " + f"Rate: {rate:,.0f} vec/s{mem_info}") + + # Cleanup + del vectors, data + if batch_idx % 50 == 0: + gc.collect() + + total_time = time.time() - start_time + + return { + 'vectors_loaded': vectors_loaded, + 'total_time': total_time, + 'generation_time': total_gen_time, + 'insertion_time': total_insert_time, + 'batches': batch_idx, + 'rate': vectors_loaded / total_time if total_time > 0 else 0, + } + + +def insert_data_adaptive(collection, num_vectors: int, dimension: int, + distribution: str, batch_size: int, + memory_budget: int = 0, + seed: Optional[int] = None) -> Dict[str, Any]: + """ + Adaptive vector loading with memory-aware batch sizing. + + Monitors memory pressure and adjusts batch sizes dynamically. + """ + logger.info(f"Loading {num_vectors:,} vectors with adaptive batch sizing") + if memory_budget > 0: + logger.info(f"Memory budget: {memory_budget / (1024**3):.1f} GB") + + # Initialize adaptive controller + controller = AdaptiveBatchController( + initial_batch_size=batch_size, + min_batch_size=max(100, batch_size // 20), + max_batch_size=min(100000, batch_size * 5) + ) + + start_time = time.time() + vectors_loaded = 0 + batch_idx = 0 + total_gen_time = 0.0 + total_insert_time = 0.0 + errors = 0 + + while vectors_loaded < num_vectors: + current_batch_size = controller.get_batch_size() + remaining = num_vectors - vectors_loaded + current_batch_size = min(current_batch_size, remaining) + + try: + # Generate vectors + gen_start = time.time() + vectors = generate_vectors(current_batch_size, dimension, distribution, seed, batch_idx) + total_gen_time += time.time() - gen_start + + # Prepare data + ids = list(range(vectors_loaded, vectors_loaded + current_batch_size)) + data = [ids, vectors.tolist()] + + # Insert + insert_start = time.time() + collection.insert(data) + total_insert_time += time.time() - insert_start + + vectors_loaded += current_batch_size + batch_idx += 1 + + except Exception as e: + logger.error(f"Error in batch {batch_idx}: {e}") + errors += 1 + controller.force_scale_down() + continue + + # Progress reporting + if batch_idx % 100 == 0 or vectors_loaded >= num_vectors: + elapsed = time.time() - start_time + rate = vectors_loaded / elapsed if elapsed > 0 else 0 + progress = (vectors_loaded / num_vectors) * 100 + logger.info(f"Progress: {vectors_loaded:,}/{num_vectors:,} ({progress:.1f}%) - " + f"Rate: {rate:,.0f} vec/s, Batch: {controller.current_batch_size:,}, " + f"Mem: {get_memory_percent():.1f}%") + + # Cleanup + del vectors, data + if batch_idx % 50 == 0: + gc.collect() + + total_time = time.time() - start_time + + return { + 'vectors_loaded': vectors_loaded, + 'total_time': total_time, + 'generation_time': total_gen_time, + 'insertion_time': total_insert_time, + 'batches': batch_idx, + 'rate': vectors_loaded / total_time if total_time > 0 else 0, + 'batch_adjustments': controller.adjustment_count, + 'errors': errors, + } + + +def insert_data_disk_backed(collection, num_vectors: int, dimension: int, + distribution: str, batch_size: int, + temp_dir: Optional[str] = None, + seed: Optional[int] = None) -> Dict[str, Any]: + """ + Disk-backed vector loading for memory-constrained systems. + + Two-phase approach: + 1. Generate all vectors to disk (memory-mapped file) + 2. Stream from disk to database + """ + logger.info(f"Loading {num_vectors:,} vectors using disk-backed buffer") + + start_time = time.time() + + with DiskBackedBuffer(dimension, num_vectors, temp_dir) as disk_buffer: + # Phase 1: Generate to disk + logger.info("Phase 1/2: Generating vectors to disk...") + gen_start = time.time() + vectors_generated = 0 + batch_idx = 0 + + while vectors_generated < num_vectors: + remaining = num_vectors - vectors_generated + current_batch_size = min(batch_size, remaining) + + vectors = generate_vectors(current_batch_size, dimension, distribution, seed, batch_idx) + disk_buffer.write_batch(vectors, vectors_generated) + + vectors_generated += current_batch_size + batch_idx += 1 + + if batch_idx % 100 == 0: + progress = (vectors_generated / num_vectors) * 100 + logger.info(f"Generation: {vectors_generated:,}/{num_vectors:,} ({progress:.1f}%)") + + del vectors + if batch_idx % 50 == 0: + gc.collect() + + gen_time = time.time() - gen_start + logger.info(f"Phase 1 complete: {vectors_generated:,} vectors in {gen_time:.1f}s") + + # Phase 2: Load from disk to database + logger.info("Phase 2/2: Loading vectors to database...") + insert_start = time.time() + vectors_loaded = 0 + insert_batch_idx = 0 + + for start_id in range(0, num_vectors, batch_size): + count = min(batch_size, num_vectors - start_id) + vectors = disk_buffer.read_batch(start_id, count) + + ids = list(range(start_id, start_id + count)) + data = [ids, vectors.tolist()] + + collection.insert(data) + vectors_loaded += count + insert_batch_idx += 1 + + if insert_batch_idx % 100 == 0 or vectors_loaded >= num_vectors: + progress = (vectors_loaded / num_vectors) * 100 + logger.info(f"Loading: {vectors_loaded:,}/{num_vectors:,} ({progress:.1f}%)") + + insert_time = time.time() - insert_start + + total_time = time.time() - start_time + + return { + 'vectors_loaded': vectors_loaded, + 'total_time': total_time, + 'generation_time': gen_time, + 'insertion_time': insert_time, + 'batches': batch_idx + insert_batch_idx, + 'rate': vectors_loaded / total_time if total_time > 0 else 0, + } + + +# ============================================================================= +# Collection Operations +# ============================================================================= + def flush_collection(collection): - # Flush the collection + """Flush the collection""" flush_start = time.time() collection.flush() flush_time = time.time() - flush_start @@ -248,9 +728,40 @@ def create_index(collection, index_params): return False +# ============================================================================= +# Main Entry Point +# ============================================================================= + def main(): args = parse_args() + # Determine loading mode + if args.disk_backed: + mode = "disk-backed" + elif args.adaptive: + mode = "adaptive" + else: + mode = "standard" + + memory_budget = parse_memory_string(args.memory_budget) + + # Print configuration summary + logger.info("=" * 60) + logger.info("VDB Benchmark - Vector Loader") + logger.info("=" * 60) + logger.info(f"Collection: {args.collection_name}") + logger.info(f"Vectors: {args.num_vectors:,}") + logger.info(f"Dimension: {args.dimension}") + logger.info(f"Distribution: {args.distribution}") + logger.info(f"Batch size: {args.batch_size:,}") + logger.info(f"Shards: {args.num_shards}") + logger.info(f"Mode: {mode}") + if args.seed: + logger.info(f"Seed: {args.seed}") + if PSUTIL_AVAILABLE: + logger.info(f"Available RAM: {get_available_memory() / (1024**3):.1f} GB") + logger.info("=" * 60) + # Connect to Milvus if not connect_to_milvus(args.host, args.port): logger.error("Failed to connect to Milvus.") @@ -308,46 +819,76 @@ def main(): if not create_index(collection, index_params): return 1 - # Generate vectors - logger.info( - f"Generating {args.num_vectors} vectors with {args.dimension} dimensions using {args.distribution} distribution") + # Load vectors based on mode + logger.info(f"Starting vector generation and insertion using {mode} mode") start_gen_time = time.time() - # Split vector generation into chunks if num_vectors is large - if args.num_vectors > args.chunk_size: - logger.info(f"Large vector count detected. Generating in chunks of {args.chunk_size:,} vectors") - vectors = [] - remaining = args.num_vectors - chunks_processed = 0 - - while remaining > 0: - chunk_size = min(args.chunk_size, remaining) - logger.info(f"Generating chunk {chunks_processed+1}: {chunk_size:,} vectors") - chunk_start = time.time() - chunk_vectors = generate_vectors(chunk_size, args.dimension, args.distribution) - chunk_time = time.time() - chunk_start - - logger.info(f"Generated chunk {chunks_processed} ({chunk_size:,} vectors) in {chunk_time:.2f} seconds. " - f"Progress: {(args.num_vectors - remaining):,}/{args.num_vectors:,} vectors " - f"({(args.num_vectors - remaining) / args.num_vectors * 100:.1f}%)") - - # Insert data - logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'") - total_inserted, insert_time = insert_data(collection, chunk_vectors, args.batch_size) - logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds") - - remaining -= chunk_size - chunks_processed += 1 + if mode == "disk-backed": + result = insert_data_disk_backed( + collection, args.num_vectors, args.dimension, args.distribution, + args.batch_size, temp_dir=args.temp_dir, seed=args.seed + ) + elif mode == "adaptive": + result = insert_data_adaptive( + collection, args.num_vectors, args.dimension, args.distribution, + args.batch_size, memory_budget=memory_budget, seed=args.seed + ) else: - # For smaller vector counts, generate all at once - vectors = generate_vectors(args.num_vectors, args.dimension, args.distribution) - # Insert data - logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'") - total_inserted, insert_time = insert_data(collection, vectors, args.batch_size) - logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds") + # Standard mode - use chunk-based approach for large datasets + if args.num_vectors > args.chunk_size: + logger.info(f"Large vector count detected. Generating in chunks of {args.chunk_size:,} vectors") + total_inserted = 0 + remaining = args.num_vectors + chunks_processed = 0 + total_gen_time = 0.0 + total_insert_time = 0.0 + + while remaining > 0: + chunk_size = min(args.chunk_size, remaining) + logger.info(f"Generating chunk {chunks_processed+1}: {chunk_size:,} vectors") + chunk_start = time.time() + chunk_vectors = generate_vectors(chunk_size, args.dimension, args.distribution, + args.seed, chunks_processed) + chunk_gen_time = time.time() - chunk_start + total_gen_time += chunk_gen_time + + logger.info(f"Generated chunk {chunks_processed+1} ({chunk_size:,} vectors) in {chunk_gen_time:.2f} seconds. " + f"Progress: {(args.num_vectors - remaining):,}/{args.num_vectors:,} vectors " + f"({(args.num_vectors - remaining) / args.num_vectors * 100:.1f}%)") + + # Insert data + logger.info(f"Inserting {chunk_size:,} vectors into collection '{args.collection_name}'") + insert_start = time.time() + inserted, insert_time = insert_data(collection, chunk_vectors, args.batch_size, + start_id=args.num_vectors - remaining) + total_insert_time += insert_time + total_inserted += inserted + logger.info(f"Inserted {inserted:,} vectors in {insert_time:.2f} seconds") + + remaining -= chunk_size + chunks_processed += 1 + + # Cleanup after each chunk + del chunk_vectors + gc.collect() + + result = { + 'vectors_loaded': total_inserted, + 'total_time': time.time() - start_gen_time, + 'generation_time': total_gen_time, + 'insertion_time': total_insert_time, + 'batches': chunks_processed, + 'rate': total_inserted / (time.time() - start_gen_time) if (time.time() - start_gen_time) > 0 else 0, + } + else: + # For smaller vector counts, use the standard insertion function + result = insert_data_standard( + collection, args.num_vectors, args.dimension, args.distribution, + args.batch_size, seed=args.seed + ) gen_time = time.time() - start_gen_time - logger.info(f"Generated all {args.num_vectors:,} vectors in {gen_time:.2f} seconds") + logger.info(f"Completed loading {result['vectors_loaded']:,} vectors in {gen_time:.2f} seconds") flush_collection(collection) @@ -362,6 +903,21 @@ def main(): logger.info(f"Collection '{args.collection_name}' compacted successfully.") # Summary + logger.info("=" * 60) + logger.info("Loading Summary") + logger.info("=" * 60) + logger.info(f"Vectors loaded: {result['vectors_loaded']:,}") + logger.info(f"Total time: {result['total_time']:.1f}s") + logger.info(f"Throughput: {result['rate']:,.0f} vectors/sec") + logger.info(f"Generation time: {result['generation_time']:.1f}s") + logger.info(f"Insertion time: {result['insertion_time']:.1f}s") + logger.info(f"Batches: {result['batches']:,}") + if 'batch_adjustments' in result and result['batch_adjustments'] > 0: + logger.info(f"Batch adjustments: {result['batch_adjustments']}") + if 'errors' in result and result['errors'] > 0: + logger.info(f"Errors: {result['errors']}") + logger.info("=" * 60) + logger.info("Benchmark completed successfully!") return 0 From f9ab288979c4ad76bf4056abab8cc7cb08f419a4 Mon Sep 17 00:00:00 2001 From: Devasena Inupakutika Date: Tue, 10 Feb 2026 07:46:23 -0800 Subject: [PATCH 2/2] added recall metric implementatin to vdb benchmark script --- vdb_benchmark/vdbbench/simple_bench.py | 835 +++++++++++++++++++++++-- 1 file changed, 798 insertions(+), 37 deletions(-) diff --git a/vdb_benchmark/vdbbench/simple_bench.py b/vdb_benchmark/vdbbench/simple_bench.py index b679cd11..f04fa26b 100644 --- a/vdb_benchmark/vdbbench/simple_bench.py +++ b/vdb_benchmark/vdbbench/simple_bench.py @@ -1,9 +1,23 @@ -#!/usr/bin/env python3 """ -Milvus Vector Database Benchmark Script - -This script executes random vector queries against a Milvus collection using multiple processes. -It measures and reports query latency statistics. +simple_bench.py - Vector Database Benchmark with Recall Metrics + +Benchmarks vector search performance (throughput, latency, disk I/O) and +measures recall accuracy by comparing ANN index results against brute-force +(FLAT) ground truth. + +Recall metric design (addresses review comments on PR #7): + 1. Ground truth is pre-computed OUTSIDE the timed benchmark using a + duplicate FLAT-indexed collection (brute-force exact search). + This avoids doubling queries (which halved throughput in the original + approach) and avoids circular ANN-validating-ANN logic. + 2. batch_end timing is placed BEFORE any recall-related result capture, + so performance numbers (QPS, latency) only reflect the primary search. + 3. Query vectors are pre-generated with a fixed seed before the benchmark + to ensure identical queries hit both FLAT (ground truth) and ANN + (benchmark) collections. + +Recall calculation follows the VectorDBBench methodology: + recall@k = |ANN_top_k ∩ FLAT_top_k| / k """ import argparse @@ -25,7 +39,7 @@ from vdbbench.list_collections import get_collection_info try: - from pymilvus import connections, Collection, utility + from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility except ImportError: print("Error: pymilvus package not found. Please install it with 'pip install pymilvus'") sys.exit(1) @@ -47,6 +61,514 @@ ] +# =========================================================================== +# Recall metric calculation (following VectorDBBench methodology) +# =========================================================================== + +def calc_recall( + ann_results: Dict[int, List[int]], + ground_truth: Dict[int, List[int]], + k: int, +) -> Dict[str, Any]: + """ + Calculate recall@k by comparing ANN search results against ground truth. + + Follows the VectorDBBench approach: + recall@k = |ANN_top_k ∩ GT_top_k| / k + + Ground truth comes from a FLAT (brute-force) index which guarantees exact + nearest neighbor results — NOT from the ANN index itself. + + Args: + ann_results: Dict mapping query_index -> list of IDs from ANN search. + ground_truth: Dict mapping query_index -> list of true nearest neighbor + IDs from FLAT index search. + k: Number of top results to evaluate. + + Returns: + Dict with recall statistics (mean, min, max, percentiles). + """ + per_query_recall = [] + + for query_idx in sorted(ann_results.keys()): + if query_idx not in ground_truth: + continue + + ann_ids = set(ann_results[query_idx][:k]) + gt_ids = set(ground_truth[query_idx][:k]) + + if len(gt_ids) == 0: + continue + + # recall = size of intersection / k + intersection_size = len(ann_ids & gt_ids) + recall_value = intersection_size / k + per_query_recall.append(recall_value) + + if not per_query_recall: + return { + "recall_at_k": 0.0, + "num_queries_evaluated": 0, + "k": k, + "min_recall": 0.0, + "max_recall": 0.0, + "mean_recall": 0.0, + "median_recall": 0.0, + "p95_recall": 0.0, + "p99_recall": 0.0, + } + + recalls_arr = np.array(per_query_recall) + return { + "recall_at_k": float(np.mean(recalls_arr)), + "num_queries_evaluated": len(per_query_recall), + "k": k, + "min_recall": float(np.min(recalls_arr)), + "max_recall": float(np.max(recalls_arr)), + "mean_recall": float(np.mean(recalls_arr)), + "median_recall": float(np.median(recalls_arr)), + "p95_recall": float(np.percentile(recalls_arr, 95)), + "p99_recall": float(np.percentile(recalls_arr, 99)), + } + + +# =========================================================================== +# Ground truth pre-computation using FLAT index +# =========================================================================== + +def _detect_schema_fields(collection: Collection) -> Tuple[str, str, DataType]: + """ + Detect primary key and vector field names from a collection's schema. + + Returns: + (pk_field_name, vector_field_name, pk_dtype) tuple. + + Raises: + ValueError if required fields cannot be detected. + """ + pk_field = None + pk_dtype = None + vec_field = None + for field in collection.schema.fields: + if field.is_primary: + pk_field = field.name + pk_dtype = field.dtype + if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR, + DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR): + vec_field = field.name + + if pk_field is None: + raise ValueError(f"Cannot detect primary key field in collection " + f"'{collection.name}'. Schema: {collection.schema}") + if vec_field is None: + raise ValueError(f"Cannot detect vector field in collection " + f"'{collection.name}'. Schema: {collection.schema}") + + return pk_field, vec_field, pk_dtype + + +def create_flat_collection( + host: str, + port: str, + source_collection_name: str, + flat_collection_name: str, + vector_dim: int, + metric_type: str = "COSINE", +) -> bool: + """ + Create a duplicate collection with FLAT index for ground truth computation. + + FLAT index performs brute-force exact search which gives true nearest + neighbors — unlike ANN indexes (DiskANN, HNSW, IVF) which approximate. + + CRITICAL: The FLAT collection preserves the source collection's primary + key values (auto_id=False). This ensures that the IDs returned by FLAT + search match the IDs returned by the ANN search on the source collection, + so the recall set-intersection calculation works correctly. + + Uses query_iterator() to avoid the Milvus maxQueryResultWindow offset + limit (default 16384) that breaks offset-based pagination on collections + larger than ~16K vectors. + + Args: + host: Milvus server host. + port: Milvus server port. + source_collection_name: Name of the original ANN-indexed collection. + flat_collection_name: Name for the new FLAT-indexed collection. + vector_dim: Vector dimension. + metric_type: Distance metric (COSINE, L2, IP). + + Returns: + True if the FLAT collection is ready, False on failure. + """ + conn_alias = "flat_setup" + try: + connections.connect(alias=conn_alias, host=host, port=port) + except Exception as e: + print(f"Failed to connect for FLAT collection setup: {e}") + return False + + try: + # Check if FLAT collection already exists and is populated + if utility.has_collection(flat_collection_name, using=conn_alias): + flat_coll = Collection(flat_collection_name, using=conn_alias) + source_coll = Collection(source_collection_name, using=conn_alias) + if flat_coll.num_entities > 0 and flat_coll.num_entities == source_coll.num_entities: + print(f"FLAT collection '{flat_collection_name}' already exists " + f"with {flat_coll.num_entities} vectors, reusing it.") + flat_coll.load() + return True + else: + print(f"FLAT collection exists but has {flat_coll.num_entities} vs " + f"{source_coll.num_entities} vectors. Dropping and recreating...") + utility.drop_collection(flat_collection_name, using=conn_alias) + + print(f"Creating FLAT collection '{flat_collection_name}' " + f"from source '{source_collection_name}'...") + + # Get source collection and detect field names + PK type from schema + source_coll = Collection(source_collection_name, using=conn_alias) + source_coll.load() + # Flush to ensure num_entities is up-to-date (unflushed collections + # can return 0 which makes the copy loop never run) + source_coll.flush() + total_vectors = source_coll.num_entities + if total_vectors == 0: + print(f"ERROR: Source collection '{source_collection_name}' " + f"reports 0 vectors after flush. Cannot create ground truth.") + return False + + src_pk_field, src_vec_field, src_pk_dtype = _detect_schema_fields(source_coll) + print(f"Source schema: pk_field='{src_pk_field}' ({src_pk_dtype.name}), " + f"vec_field='{src_vec_field}', vectors={total_vectors}") + + # Define schema for FLAT collection. + # CRITICAL: auto_id=False — we copy the source PK values so that + # IDs from FLAT search match IDs from ANN search on source. + pk_kwargs = {"max_length": 256} if src_pk_dtype == DataType.VARCHAR else {} + fields = [ + FieldSchema(name="pk", dtype=src_pk_dtype, + is_primary=True, auto_id=False, **pk_kwargs), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, + dim=vector_dim), + ] + schema = CollectionSchema( + fields, description="FLAT index ground truth collection") + flat_coll = Collection(flat_collection_name, schema, using=conn_alias) + + # Copy vectors AND PK values from source to FLAT collection. + # We try query_iterator (pymilvus >=2.3) first, then fall back to + # pk-cursor pagination which works on any version and avoids the + # offset+limit > maxQueryResultWindow (default 16384) error. + copy_batch_size = 5000 + print(f"Copying {total_vectors} vectors to FLAT collection " + f"(batch_size={copy_batch_size})...") + + copied = 0 + use_iterator = hasattr(source_coll, 'query_iterator') + + if use_iterator: + # pymilvus >= 2.3: use built-in iterator + try: + iterator = source_coll.query_iterator( + batch_size=copy_batch_size, + output_fields=[src_pk_field, src_vec_field], + ) + while True: + batch = iterator.next() + if not batch: + break + pk_values = [row[src_pk_field] for row in batch] + vectors = [row[src_vec_field] for row in batch] + flat_coll.insert([pk_values, vectors]) + copied += len(vectors) + if copied % (copy_batch_size * 20) < copy_batch_size: + print(f" Copied {copied}/{total_vectors} vectors " + f"({100.0 * copied / total_vectors:.1f}%)") + iterator.close() + except Exception as iter_err: + print(f" query_iterator failed ({iter_err}), " + f"falling back to pk-cursor pagination...") + use_iterator = False + copied = 0 + # Drop and recreate if partial data was inserted + utility.drop_collection(flat_collection_name, using=conn_alias) + flat_coll = Collection(flat_collection_name, schema, using=conn_alias) + + if not use_iterator: + # Fallback: pk-cursor pagination + search-based vector retrieval. + # query() cannot return vector fields on many Milvus versions + # (MilvusException: vector field not supported in query output). + # Instead: query PKs only, then search filtered by those PKs + # with output_fields to retrieve vectors. search() always + # supports vector output. + is_int_pk = src_pk_dtype in (DataType.INT64, DataType.INT32, + DataType.INT16, DataType.INT8) + last_pk = -2**63 if is_int_pk else "" + page_limit = min(copy_batch_size, 16384) # stay under Milvus limit + + # Need a dummy vector for search calls + dummy_vec = np.random.random(vector_dim).astype(np.float32) + dummy_vec = (dummy_vec / np.linalg.norm(dummy_vec)).tolist() + + while copied < total_vectors: + if is_int_pk: + expr = f"{src_pk_field} > {last_pk}" + else: + expr = f'{src_pk_field} > "{last_pk}"' + + # Step A: query PKs only (works on all Milvus versions) + try: + pk_batch = source_coll.query( + expr=expr, + output_fields=[src_pk_field], + limit=page_limit, + ) + except Exception as qe: + print(f" query() failed: {qe}") + break + if not pk_batch: + break + + # Sort by PK so cursor advances correctly + if is_int_pk: + pk_batch.sort(key=lambda r: r[src_pk_field]) + else: + pk_batch.sort(key=lambda r: str(r[src_pk_field])) + last_pk = pk_batch[-1][src_pk_field] + + pk_values_batch = [row[src_pk_field] for row in pk_batch] + + # Step B: retrieve vectors via search filtered to these PKs. + # search() supports output_fields with vector data on all + # Milvus versions (unlike query()). + if is_int_pk: + pk_filter = f"{src_pk_field} in {pk_values_batch}" + else: + escaped = [str(v).replace('"', '\\"') for v in pk_values_batch] + pk_filter = f'{src_pk_field} in [' + ','.join(f'"{v}"' for v in escaped) + ']' + + try: + search_results = source_coll.search( + data=[dummy_vec], + anns_field=src_vec_field, + param={"metric_type": metric_type, "params": {}}, + limit=len(pk_values_batch), + expr=pk_filter, + output_fields=[src_vec_field], + ) + except Exception as se: + print(f" search() for vector retrieval failed: {se}") + break + + # Build pk -> vector map from search results + pk_vec_map = {} + if search_results: + for hit in search_results[0]: + hit_pk = hit.id + hit_vec = hit.entity.get(src_vec_field) + if hit_vec is not None: + pk_vec_map[hit_pk] = hit_vec + + # Insert matched pk+vector pairs + insert_pks = [] + insert_vecs = [] + for pk_val in pk_values_batch: + if pk_val in pk_vec_map: + insert_pks.append(pk_val) + insert_vecs.append(pk_vec_map[pk_val]) + + if insert_pks: + flat_coll.insert([insert_pks, insert_vecs]) + copied += len(insert_pks) + else: + # If search returned no vectors, try direct query with + # vector output as last resort (works on pymilvus >= 2.3) + try: + vec_batch = source_coll.query( + expr=pk_filter, + output_fields=[src_pk_field, src_vec_field], + limit=len(pk_values_batch), + ) + if vec_batch: + pks = [row[src_pk_field] for row in vec_batch] + vecs = [row[src_vec_field] for row in vec_batch] + flat_coll.insert([pks, vecs]) + copied += len(pks) + except Exception: + print(f" WARNING: Could not retrieve vectors for " + f"{len(pk_values_batch)} PKs, skipping batch.") + continue + + if copied % (page_limit * 20) < page_limit: + pct = min(100.0, 100.0 * copied / total_vectors) + print(f" Copied {copied}/{total_vectors} vectors " + f"({pct:.1f}%)") + + print(f" Copied {copied}/{total_vectors} vectors (100.0%)") + flat_coll.flush() + + # Wait for entity count to stabilize after flush — Milvus can + # take a moment before num_entities reflects the flushed data. + for attempt in range(10): + actual_count = flat_coll.num_entities + if actual_count >= copied: + break + time.sleep(1) + print(f" Waiting for flush to complete " + f"({actual_count}/{copied} visible)...") + + if actual_count < copied: + print(f" WARNING: Only {actual_count}/{copied} vectors visible " + f"after flush. Proceeding anyway.") + + # Create FLAT index (brute-force, exact results) + print("Building FLAT index...") + flat_coll.create_index( + field_name="vector", + index_params={ + "index_type": "FLAT", + "metric_type": metric_type, + "params": {}, + }, + ) + flat_coll.load() + print(f"FLAT collection '{flat_collection_name}' ready with " + f"{flat_coll.num_entities} vectors.") + return True + + except Exception as e: + print(f"Error creating FLAT collection: {e}") + import traceback + traceback.print_exc() + return False + finally: + try: + connections.disconnect(conn_alias) + except: + pass + + +def precompute_ground_truth( + host: str, + port: str, + flat_collection_name: str, + query_vectors: List[List[float]], + top_k: int, + metric_type: str = "COSINE", +) -> Dict[int, List[int]]: + """ + Pre-compute ground truth by running queries against the FLAT collection. + + This runs OUTSIDE the timed benchmark so it has zero impact on + performance measurements. + + Args: + host: Milvus host. + port: Milvus port. + flat_collection_name: Name of the FLAT-indexed collection. + query_vectors: List of query vectors. + top_k: Number of nearest neighbors to retrieve. + metric_type: Distance metric. + + Returns: + Dict mapping query_index -> list of ground truth nearest neighbor IDs. + """ + conn_alias = "gt_compute" + try: + connections.connect(alias=conn_alias, host=host, port=port) + except Exception as e: + print(f"Failed to connect for ground truth computation: {e}") + return {} + + try: + flat_coll = Collection(flat_collection_name, using=conn_alias) + flat_coll.load() + + # Cap top_k to collection size to avoid Milvus search errors + entity_count = flat_coll.num_entities + effective_top_k = min(top_k, entity_count) if entity_count > 0 else top_k + if effective_top_k != top_k: + print(f" NOTE: top_k capped from {top_k} to {effective_top_k} " + f"(collection has {entity_count} vectors)") + # Milvus also enforces a max topk (typically 16384) + effective_top_k = min(effective_top_k, 16384) + + ground_truth: Dict[int, List[int]] = {} + gt_batch_size = 100 # Process queries in batches for efficiency + + print(f"Pre-computing ground truth for {len(query_vectors)} queries " + f"using FLAT index (top_k={effective_top_k})...") + + gt_start = time.time() + + for batch_start in range(0, len(query_vectors), gt_batch_size): + batch_end_idx = min(batch_start + gt_batch_size, len(query_vectors)) + batch_vectors = query_vectors[batch_start:batch_end_idx] + + results = flat_coll.search( + data=batch_vectors, + anns_field="vector", + param={"metric_type": metric_type, "params": {}}, + limit=effective_top_k, + ) + + for i, hits in enumerate(results): + query_idx = batch_start + i + ground_truth[query_idx] = [hit.id for hit in hits] + + gt_elapsed = time.time() - gt_start + print(f"Ground truth pre-computation complete: " + f"{len(ground_truth)} queries in {gt_elapsed:.2f}s") + + return ground_truth + + except Exception as e: + print(f"Error computing ground truth: {e}") + import traceback + traceback.print_exc() + return {} + finally: + try: + connections.disconnect(conn_alias) + except: + pass + + +def generate_query_vectors( + num_queries: int, + dimension: int, + seed: int = 42, +) -> List[List[float]]: + """ + Pre-generate a fixed set of query vectors. + + Pre-generating ensures: + - Consistent queries between ANN and FLAT searches + - Ground truth can be computed before the timed benchmark + - No random generation overhead during the benchmark + + Args: + num_queries: Number of query vectors to generate. + dimension: Vector dimension. + seed: Random seed for reproducibility. + + Returns: + List of normalized query vectors. + """ + rng = np.random.RandomState(seed) + vectors = rng.random((num_queries, dimension)).astype(np.float32) + # Normalize for cosine similarity + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + norms[norms == 0] = 1.0 + vectors = vectors / norms + return vectors.tolist() + + +# =========================================================================== +# Utility functions +# =========================================================================== + def signal_handler(sig, frame): """Handle interrupt signals to gracefully shut down worker processes""" print("\nReceived interrupt signal. Shutting down workers gracefully...") @@ -134,11 +656,28 @@ def connect_to_milvus(host: str, port: str) -> connections: return False +# =========================================================================== +# Benchmark worker — always captures ANN result IDs for recall +# =========================================================================== + def execute_batch_queries(process_id: int, host: str, port: str, collection_name: str, vector_dim: int, batch_size: int, report_count: int, max_queries: Optional[int], runtime_seconds: Optional[int], output_dir: str, - shutdown_flag: mp.Value) -> None: + shutdown_flag: mp.Value, + pre_generated_queries: List[List[float]] = None, + ann_results_dict: dict = None, + search_limit: int = 10, + search_ef: int = 200, + anns_field: str = "vector") -> None: """ - Execute batches of vector queries and log results to disk + Execute batches of vector queries and log results to disk. + + Always uses pre-generated query vectors and captures ANN result IDs + for post-hoc recall calculation. + + CRITICAL TIMING NOTE (Review Comment #2): + batch_end is measured IMMEDIATELY after collection.search() returns. + ANN result ID capture happens AFTER batch_end, so performance + numbers only reflect the primary ANN search. Args: process_id: ID of the current process @@ -147,10 +686,16 @@ def execute_batch_queries(process_id: int, host: str, port: str, collection_name collection_name: Name of the collection to query vector_dim: Dimension of vectors batch_size: Number of queries to execute in each batch + report_count: Number of batches between progress reports max_queries: Maximum number of queries to execute (None for unlimited) runtime_seconds: Maximum runtime in seconds (None for unlimited) output_dir: Directory to save results shutdown_flag: Shared value to signal process termination + pre_generated_queries: Pre-generated query vectors (deterministic, seed-based). + ann_results_dict: Shared dict to capture ANN result IDs for recall. + search_limit: Number of results per query (top-k). + search_ef: Search ef parameter. + anns_field: Name of the vector field in the collection (auto-detected from schema). """ print(f'Process {process_id} initialized') # Connect to Milvus @@ -174,6 +719,9 @@ def execute_batch_queries(process_id: int, host: str, port: str, collection_name # Create output directory if it doesn't exist os.makedirs(os.path.dirname(output_file), exist_ok=True) + # Pre-generated query count for cycling + num_pre_generated = len(pre_generated_queries) if pre_generated_queries else 0 + # Track execution start_time = time.time() query_count = 0 @@ -202,26 +750,44 @@ def execute_batch_queries(process_id: int, host: str, port: str, collection_name if max_queries is not None and query_count >= max_queries: break - # Generate batch of query vectors - batch_vectors = [generate_random_vector(vector_dim) for _ in range(batch_size)] + # Build batch from pre-generated queries (cycle deterministically) + batch_vectors = [] + batch_query_indices = [] + for b in range(batch_size): + idx = (query_count + b) % num_pre_generated + batch_vectors.append(pre_generated_queries[idx]) + batch_query_indices.append(idx) - # Execute batch and measure time + # ---- TIMED SECTION: Only the primary ANN search ---- batch_start = time.time() try: - search_params = {"metric_type": "COSINE", "params": {"ef": 200}} + search_params = {"metric_type": "COSINE", "params": {"ef": search_ef}} results = collection.search( data=batch_vectors, - anns_field="vector", + anns_field=anns_field, param=search_params, - limit=10, - output_fields=["id"] + limit=search_limit, ) + # CRITICAL (Review Comment #2): batch_end is placed HERE, + # BEFORE any recall result capture below. batch_end = time.time() batch_success = True except Exception as e: print(f"Process {process_id}: Search error: {e}") batch_end = time.time() batch_success = False + results = None + # ---- END TIMED SECTION ---- + + # Capture ANN result IDs for post-hoc recall (NOT timed). + # Review Comment #1: this capture is outside the timed section. + if results is not None and ann_results_dict is not None: + for i, hits in enumerate(results): + global_query_idx = batch_query_indices[i] + result_ids = [hit.id for hit in hits] + key = f"{process_id}_{global_query_idx}" + if key not in ann_results_dict: + ann_results_dict[key] = result_ids # Record batch results batch_time = batch_end - batch_start @@ -261,16 +827,30 @@ def execute_batch_queries(process_id: int, host: str, port: str, collection_name f"Process {process_id}: Finished. Executed {query_count} queries in {time.time() - start_time:.2f} seconds", flush=True) -def calculate_statistics(results_dir: str) -> Dict[str, Union[str, int, float, Dict[str, int]]]: - """Calculate statistics from benchmark results""" +# =========================================================================== +# Statistics calculation — always includes recall +# =========================================================================== + +def calculate_statistics(results_dir: str, + recall_stats: Dict[str, Any] = None, + ) -> Dict[str, Union[str, int, float, Dict[str, int]]]: + """Calculate statistics from benchmark results. + + Args: + results_dir: Directory containing per-process CSV result files. + recall_stats: Recall metrics dict from calc_recall(). + + Returns: + Dict with latency, batch, throughput, and recall statistics. + """ import pandas as pd - + # Find all result files file_paths = list(Path(results_dir).glob("milvus_benchmark_p*.csv")) - + if not file_paths: return {"error": "No benchmark result files found"} - + # Read and concatenate all CSV files into a single DataFrame dfs = [] for file_path in file_paths: @@ -280,10 +860,10 @@ def calculate_statistics(results_dir: str) -> Dict[str, Union[str, int, float, D dfs.append(df) except Exception as e: print(f"Error reading result file {file_path}: {e}") - + if not dfs: return {"error": "No valid data found in benchmark result files"} - + # Concatenate all dataframes all_data = pd.concat(dfs, ignore_index=True) all_data.sort_values('timestamp', inplace=True) @@ -298,15 +878,15 @@ def calculate_statistics(results_dir: str) -> Dict[str, Union[str, int, float, D for _, row in all_data.iterrows(): query_time_ms = row['avg_query_time_seconds'] * 1000 all_latencies.extend([query_time_ms] * row['batch_size']) - + # Convert batch times to milliseconds batch_times_ms = all_data['batch_time_seconds'] * 1000 - + # Calculate statistics latencies = np.array(all_latencies) batch_times = np.array(batch_times_ms) total_queries = len(latencies) - + stats = { "total_queries": total_queries, "total_time_seconds": total_time_seconds, @@ -329,12 +909,19 @@ def calculate_statistics(results_dir: str) -> Dict[str, Union[str, int, float, D "p95_batch_time_ms": float(np.percentile(batch_times, 95)) if len(batch_times) > 0 else 0, "p99_batch_time_ms": float(np.percentile(batch_times, 99)) if len(batch_times) > 0 else 0, "p999_batch_time_ms": float(np.percentile(batch_times, 99.9)) if len(batch_times) > 0 else 0, - "p9999_batch_time_ms": float(np.percentile(batch_times, 99.99)) if len(batch_times) > 0 else 0 + "p9999_batch_time_ms": float(np.percentile(batch_times, 99.99)) if len(batch_times) > 0 else 0, + + # Recall statistics — always present + "recall": recall_stats, } return stats +# =========================================================================== +# Database loading +# =========================================================================== + def load_database(host: str, port: str, collection_name: str, reload=False) -> Union[dict, None]: print(f'Connecting to Milvus server at {host}:{port}...', flush=True) connections = connect_to_milvus(host, port) @@ -393,6 +980,10 @@ def load_database(host: str, port: str, collection_name: str, reload=False) -> U return collection_info +# =========================================================================== +# Main entry point +# =========================================================================== + def main(): parser = argparse.ArgumentParser(description="Milvus Vector Database Benchmark") @@ -409,6 +1000,12 @@ def main(): parser.add_argument("--port", type=str, default="19530", help="Milvus server port") parser.add_argument("--collection-name", type=str, help="Collection name to query") + # Search parameters + parser.add_argument("--search-limit", type=int, default=10, + help="Number of results per query (top-k)") + parser.add_argument("--search-ef", type=int, default=200, + help="Search ef parameter (search_list_size)") + # Termination conditions (at least one must be specified) termination_group = parser.add_argument_group("termination conditions (at least one required)") termination_group.add_argument("--runtime", type=int, help="Maximum runtime in seconds") @@ -418,6 +1015,17 @@ def main(): parser.add_argument("--output-dir", type=str, help="Directory to save benchmark results") parser.add_argument("--json-output", action="store_true", help="Print benchmark results as JSON document") + # Recall parameters (always active — recall is a standard metric) + parser.add_argument("--gt-collection", type=str, default=None, + help="Name for FLAT ground truth collection " + "(default: _flat_gt)") + parser.add_argument("--num-query-vectors", type=int, default=1000, + help="Number of pre-generated query vectors for recall " + "(default: 1000)") + parser.add_argument("--recall-k", type=int, default=None, + help="K value for recall@k calculation " + "(default: same as --search-limit)") + args = parser.parse_args() # Validate termination conditions @@ -448,7 +1056,10 @@ def main(): os.makedirs(output_dir, exist_ok=True) - # Save benchmark configuration + # Preliminary recall_k (will be capped after collection loads) + recall_k = args.recall_k if args.recall_k else args.search_limit + + # Save benchmark configuration (after recall_k capping below) config = { "timestamp": datetime.now().isoformat(), "processes": args.processes, @@ -459,13 +1070,14 @@ def main(): "port": args.port, "collection_name": args.collection_name, "runtime_seconds": args.runtime, - "total_queries": args.queries + "total_queries": args.queries, + "search_limit": args.search_limit, + "search_ef": args.search_ef, + "gt_collection": args.gt_collection, + "num_query_vectors": args.num_query_vectors, } print(f"Results will be saved to: {output_dir}") - print(f'Writing configuration to {output_dir}/config.json') - with open(os.path.join(output_dir, "config.json"), 'w') as f: - json.dump(config, f, indent=2) print("") print("=" * 50) @@ -482,6 +1094,104 @@ def main(): print("Unable to load the specified collection") sys.exit(1) + # Cap recall_k to collection vector count and Milvus topk hard limit. + # Must happen AFTER load_database so collection_info is available. + vec_count = collection_info.get("row_count", 0) + if isinstance(vec_count, str): + try: + vec_count = int(vec_count) + except ValueError: + vec_count = 0 + if vec_count > 0 and recall_k > vec_count: + print(f"NOTE: recall_k capped from {recall_k} to {vec_count} " + f"(collection vector count)") + recall_k = vec_count + recall_k = min(recall_k, 16384) # Milvus topk hard limit + + # Now save config with the actual capped recall_k + config["recall_k"] = recall_k + print(f'Writing configuration to {output_dir}/config.json') + with open(os.path.join(output_dir, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + + # ================================================================== + # RECALL SETUP: Always pre-compute ground truth OUTSIDE the benchmark + # (Review Comment #1: ground truth computation is completely + # separated from the timed benchmark portion) + # ================================================================== + print("") + print("=" * 50) + print("RECALL SETUP (outside benchmark timing)", flush=True) + print("=" * 50) + print("Ground truth is pre-computed using a FLAT (brute-force) index.") + print("This does NOT affect performance measurements.\n") + + # Determine metric type from collection info + metric_type = "COSINE" + if collection_info and collection_info.get("index_info"): + mt = collection_info["index_info"][0].get("metric_type") + if mt: + metric_type = mt + print(f"Using metric type: {metric_type}") + + # Detect the source collection's vector field name for search calls. + # We connect briefly to read the schema, then disconnect before fork. + source_vec_field = "vector" # default fallback + try: + conn_detect = connect_to_milvus(args.host, args.port) + if conn_detect: + _src_coll = Collection(args.collection_name) + _, source_vec_field, _ = _detect_schema_fields(_src_coll) + connections.disconnect("default") + print(f"Detected source vector field: '{source_vec_field}'") + except Exception as e: + print(f"Could not detect vector field, using default '{source_vec_field}': {e}") + + # Step 1: Pre-generate deterministic query vectors + print(f"\nGenerating {args.num_query_vectors} query vectors " + f"(dim={args.vector_dim}, seed=42)...") + pre_generated_queries = generate_query_vectors( + args.num_query_vectors, args.vector_dim, seed=42 + ) + print(f"Generated {len(pre_generated_queries)} query vectors.") + + # Step 2: Create or reuse FLAT ground truth collection + gt_collection_name = args.gt_collection or f"{args.collection_name}_flat_gt" + print(f"\nSetting up FLAT collection: {gt_collection_name}") + + flat_ok = create_flat_collection( + host=args.host, + port=args.port, + source_collection_name=args.collection_name, + flat_collection_name=gt_collection_name, + vector_dim=args.vector_dim, + metric_type=metric_type, + ) + + if not flat_ok: + print("ERROR: FLAT collection setup failed. Cannot compute recall.") + sys.exit(1) + + # Step 3: Pre-compute ground truth + ground_truth = precompute_ground_truth( + host=args.host, + port=args.port, + flat_collection_name=gt_collection_name, + query_vectors=pre_generated_queries, + top_k=recall_k, + metric_type=metric_type, + ) + + if not ground_truth: + print("ERROR: Ground truth computation failed. Cannot compute recall.") + sys.exit(1) + + print(f"Ground truth ready: {len(ground_truth)} queries pre-computed.") + + # Create shared dict for workers to store ANN result IDs + manager = mp.Manager() + ann_results_dict = manager.dict() + # Read initial disk stats print(f'\nCollecting initial disk statistics...') start_disk_stats = read_disk_stats() @@ -505,6 +1215,8 @@ def main(): print(f"Starting benchmark with {args.processes} processes and {max_queries_per_process} queries per process") else: print(f'Starting benchmark with {args.processes} processes and running for {args.runtime} seconds') + print(f"Recall measurement: using {len(pre_generated_queries)} pre-generated queries, recall@{recall_k}") + print(f"NOTE: batch_end timing is placed BEFORE recall capture — performance is unaffected.") if args.processes > 1: print(f"Staggering benchmark execution by {stagger_interval_secs} seconds between processes") try: @@ -529,7 +1241,12 @@ def main(): process_max_queries, args.runtime, output_dir, - shutdown_flag + shutdown_flag, + pre_generated_queries, + ann_results_dict, + args.search_limit, + args.search_ef, + source_vec_field, ) ) print(f'Starting process {i}...') @@ -554,7 +1271,9 @@ def main(): else: print(f'Running single process benchmark...') execute_batch_queries(0, args.host, args.port, args.collection_name, args.vector_dim, args.batch_size, - args.report_count, args.queries, args.runtime, output_dir, shutdown_flag) + args.report_count, args.queries, args.runtime, output_dir, shutdown_flag, + pre_generated_queries, ann_results_dict, + args.search_limit, args.search_ef, source_vec_field) # Read final disk stats print('Reading final disk statistics...') @@ -563,9 +1282,38 @@ def main(): # Calculate disk I/O during benchmark disk_io_diff = calculate_disk_io_diff(start_disk_stats, end_disk_stats) - # Calculate and print statistics - print("\nCalculating benchmark statistics...") - stats = calculate_statistics(output_dir) + # ================================================================== + # RECALL CALCULATION (post-hoc, OUTSIDE benchmark timing) + # Review Comment #1: recall is computed from captured results after + # the benchmark completes, not during the timed search loop. + # ================================================================== + print("\nCalculating recall from captured ANN results...") + + # Deduplicate: for each query index, take the first worker's result + ann_results_by_query: Dict[int, List[int]] = {} + for key, ids in ann_results_dict.items(): + # key format: "workerID_queryIdx" + parts = str(key).rsplit("_", 1) + if len(parts) == 2: + try: + query_idx = int(parts[1]) + if query_idx not in ann_results_by_query: + ann_results_by_query[query_idx] = list(ids) + except ValueError: + continue + + recall_stats = calc_recall(ann_results_by_query, ground_truth, recall_k) + + # Save recall details to separate file + recall_output_file = os.path.join(output_dir, "recall_stats.json") + with open(recall_output_file, 'w') as f: + json.dump(recall_stats, f, indent=2) + + # ================================================================== + # Calculate and aggregate all statistics + # ================================================================== + print("Calculating benchmark statistics...") + stats = calculate_statistics(output_dir, recall_stats=recall_stats) # Add disk I/O statistics to the stats dictionary if disk_io_diff: @@ -638,6 +1386,18 @@ def main(): print(f"Max Batch Time: {stats.get('max_batch_time_ms', 0):.2f} ms") print(f"Batch Throughput: {1000 / stats.get('mean_batch_time_ms', float('inf')):.2f} batches/second") + # Print recall statistics — always shown + r = stats["recall"] + print(f"\nRECALL STATISTICS (recall@{r['k']})") + print("-" * 50) + print(f"Mean Recall: {r['mean_recall']:.4f}") + print(f"Median Recall: {r['median_recall']:.4f}") + print(f"Min Recall: {r['min_recall']:.4f}") + print(f"Max Recall: {r['max_recall']:.4f}") + print(f"P95 Recall: {r['p95_recall']:.4f}") + print(f"P99 Recall: {r['p99_recall']:.4f}") + print(f"Queries Evaluated: {r['num_queries_evaluated']}") + # Print disk I/O statistics print("\nDISK I/O DURING BENCHMARK") print("-" * 50) @@ -661,8 +1421,9 @@ def main(): print("Disk I/O statistics not available") print("\nDetailed results saved to:", output_dir) + print(f"Recall details saved to: {recall_output_file}") print("=" * 50) if __name__ == "__main__": - main() \ No newline at end of file + main()