From f9b5ae272844eb0d8dae6dd4a823f4780d26f005 Mon Sep 17 00:00:00 2001 From: Tanmay Arya Date: Wed, 28 Jan 2026 19:00:00 +0530 Subject: [PATCH 1/4] add(snapshots): implement index snapshots for kd tree and flat index --- Cargo.lock | 297 +++++++++++++++++++++- Cargo.toml | 2 + crates/api/Cargo.toml | 1 + crates/api/src/lib.rs | 203 ++++++++++++++- crates/defs/src/error.rs | 7 + crates/defs/src/lib.rs | 6 + crates/defs/src/types.rs | 2 + crates/index/Cargo.toml | 3 + crates/index/src/flat.rs | 270 -------------------- crates/index/src/flat/index.rs | 71 ++++++ crates/index/src/flat/mod.rs | 9 + crates/index/src/flat/serialize.rs | 108 ++++++++ crates/index/src/flat/tests.rs | 238 +++++++++++++++++ crates/index/src/hnsw/mod.rs | 1 + crates/index/src/hnsw/serialize.rs | 21 ++ crates/index/src/hnsw/tests.rs | 2 +- crates/index/src/kd_tree/mod.rs | 5 + crates/index/src/kd_tree/serialize.rs | 204 +++++++++++++++ crates/index/src/kd_tree/tests.rs | 33 ++- crates/index/src/kd_tree/types.rs | 20 +- crates/index/src/lib.rs | 25 +- crates/snapshot/Cargo.toml | 21 ++ crates/snapshot/README.md | 0 crates/snapshot/src/constants.rs | 5 + crates/snapshot/src/engine/mod.rs | 172 +++++++++++++ crates/snapshot/src/lib.rs | 294 +++++++++++++++++++++ crates/snapshot/src/manifest.rs | 47 ++++ crates/snapshot/src/metadata.rs | 114 +++++++++ crates/snapshot/src/registry/constants.rs | 1 + crates/snapshot/src/registry/local.rs | 271 ++++++++++++++++++++ crates/snapshot/src/registry/mod.rs | 32 +++ crates/snapshot/src/util.rs | 148 +++++++++++ crates/storage/Cargo.toml | 3 + crates/storage/src/checkpoint.rs | 51 ++++ crates/storage/src/in_memory.rs | 17 +- crates/storage/src/lib.rs | 15 +- crates/storage/src/rocks_db.rs | 224 ++++++++++++++-- 37 files changed, 2634 insertions(+), 309 deletions(-) delete mode 100644 crates/index/src/flat.rs create mode 100644 crates/index/src/flat/index.rs create mode 100644 crates/index/src/flat/mod.rs create mode 100644 crates/index/src/flat/serialize.rs 
create mode 100644 crates/index/src/flat/tests.rs create mode 100644 crates/index/src/hnsw/serialize.rs create mode 100644 crates/index/src/kd_tree/serialize.rs create mode 100644 crates/snapshot/Cargo.toml create mode 100644 crates/snapshot/README.md create mode 100644 crates/snapshot/src/constants.rs create mode 100644 crates/snapshot/src/engine/mod.rs create mode 100644 crates/snapshot/src/lib.rs create mode 100644 crates/snapshot/src/manifest.rs create mode 100644 crates/snapshot/src/metadata.rs create mode 100644 crates/snapshot/src/registry/constants.rs create mode 100644 crates/snapshot/src/registry/local.rs create mode 100644 crates/snapshot/src/registry/mod.rs create mode 100644 crates/snapshot/src/util.rs create mode 100644 crates/storage/src/checkpoint.rs diff --git a/Cargo.lock b/Cargo.lock index 7b9a6a4..31cbb4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,15 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -44,6 +53,7 @@ version = "0.1.0" dependencies = [ "defs", "index", + "snapshot", "storage", "tempfile", "uuid", @@ -66,6 +76,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "axum" version = "0.8.8" @@ -199,6 +215,15 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -263,6 +288,20 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "chrono" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -330,6 +369,24 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossterm" version = "0.27.0" @@ -355,6 +412,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "data-encoding" +version = "2.10.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" + [[package]] name = "defs" version = "0.1.0" @@ -363,6 +436,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -427,6 +510,17 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +dependencies = [ + "cfg-if", + "libc", + "libredox", +] + [[package]] name = "find-msvc-tools" version = "0.1.6" @@ -439,6 +533,16 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "flate2" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -475,6 +579,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -525,6 +639,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -777,6 +901,30 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -889,8 +1037,11 @@ checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5" name = "index" version = "0.1.0" dependencies = [ + "bincode", "defs", "rand", + "serde", + "storage", "uuid", ] @@ -1001,6 +1152,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "libredox" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +dependencies = [ + "bitflags 2.10.0", + "libc", + "redox_syscall 0.7.0", +] + [[package]] name = "librocksdb-sys" version = "0.11.0+8.1.1" @@ -1115,6 +1277,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -1182,6 +1345,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.37.3" @@ -1265,7 +1437,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", "windows-link", ] @@ -1520,6 +1692,15 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "redox_syscall" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" +dependencies = [ + "bitflags 2.10.0", +] + [[package]] name = "regex" version = "1.12.2" @@ -1730,6 +1911,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1819,6 +2006,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1865,6 +2063,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "slab" version = "0.4.11" @@ -1877,6 +2081,26 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "snapshot" +version = "0.1.0" +dependencies = [ + "chrono", + "data-encoding", + "defs", + "flate2", + "fs2", + "index", + 
"semver", + "serde", + "serde_json", + "sha2", + "storage", + "tar", + "tempfile", + "uuid", +] + [[package]] name = "socket2" version = "0.6.1" @@ -1915,7 +2139,10 @@ version = "0.1.0" dependencies = [ "bincode", "defs", + "flate2", "rocksdb", + "serde", + "tar", "tempfile", "uuid", ] @@ -2000,6 +2227,17 @@ dependencies = [ "libc", ] +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tempfile" version = "3.24.0" @@ -2314,6 +2552,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.8.1" @@ -2397,6 +2641,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "want" version = "0.3.1" @@ -2511,6 +2761,41 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" @@ -2780,6 +3065,16 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/Cargo.toml b/Cargo.toml index 2958714..ca5b979 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/http", "crates/tui", "crates/grpc", + "crates/snapshot", ] [workspace.package] @@ -51,4 +52,5 @@ http = { path = "crates/http" } index = { path = "crates/index" } server = { path = "crates/server" } storage = { path = "crates/storage" } +snapshot = { path = "crates/snapshot" } tui = { path = "crates/tui" } diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 8ade9d8..fe40536 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -12,3 +12,4 @@ index.workspace = true storage.workspace = true tempfile.workspace = true uuid.workspace = true +snapshot.workspace = true diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 6ca6849..3395b26 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,13 +1,15 @@ -use defs::{DbError, Dimension, IndexedVector, Similarity}; - +use defs::{DbError, Dimension, IndexedVector, Similarity, 
SnapshottableDb}; use defs::{DenseVector, Payload, Point, PointId}; use index::hnsw::HnswIndex; -use std::path::PathBuf; +use index::kd_tree::index::KDTree; +use std::path::{Path, PathBuf}; +use tempfile::tempdir; // use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, RwLock}; -use index::flat::FlatIndex; +use index::flat::index::FlatIndex; use index::{IndexType, VectorIndex}; +use snapshot::Snapshot; use storage::rocks_db::RocksDbStorage; use storage::{StorageEngine, StorageType, VectorPage}; @@ -132,6 +134,31 @@ impl VectorDb { } } +impl SnapshottableDb for VectorDb { + fn create_snapshot(&self, dir_path: &Path) -> Result { + if !dir_path.is_dir() { + return Err(DbError::SnapshotError(format!( + "Invalid path: {}", + dir_path.display() + ))); + } + + let index_snapshot = self + .index + .read() + .map_err(|_| DbError::LockError)? + .snapshot()?; + + let tempdir = tempdir().unwrap(); + let storage_checkpoint = self.storage.checkpoint_at(tempdir.path())?; + + let snapshot = Snapshot::new(index_snapshot, storage_checkpoint, self.dimension)?; + let snapshot_path = snapshot.save(dir_path)?; + + Ok(snapshot_path) + } +} + #[derive(Debug)] pub struct DbConfig { pub storage_type: StorageType, @@ -141,6 +168,28 @@ pub struct DbConfig { pub similarity: Similarity, } +#[derive(Debug)] +pub struct DbRestoreConfig { + pub data_path: PathBuf, + pub snapshot_path: PathBuf, +} + +impl DbRestoreConfig { + pub fn new(data_path: PathBuf, snapshot_path: PathBuf) -> Self { + Self { + data_path, + snapshot_path, + } + } +} + +pub fn restore_from_snapshot(config: &DbRestoreConfig) -> Result { + // restore the index from the snapshot + let (storage_engine, index, dimensions) = + Snapshot::load(&config.snapshot_path, &config.data_path)?; + Ok(VectorDb::_new(storage_engine, index, dimensions)) +} + pub fn init_api(config: DbConfig) -> Result { // Initialize the storage engine let storage = match config.storage_type { @@ -151,11 +200,11 @@ pub fn init_api(config: DbConfig) 
-> Result { // Initialize the vector index let index: Arc> = match config.index_type { IndexType::Flat => Arc::new(RwLock::new(FlatIndex::new())), + IndexType::KDTree => Arc::new(RwLock::new(KDTree::build_empty(config.dimension))), IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::new( config.similarity, config.dimension, ))), - _ => Arc::new(RwLock::new(FlatIndex::new())), }; // Init the db @@ -172,8 +221,11 @@ mod tests { // TODO: Add more exhaustive tests + use std::sync::Mutex; + use super::*; use defs::ContentType; + use snapshot::{engine::SnapshotEngine, registry::local::LocalRegistry}; use tempfile::{TempDir, tempdir}; // Helper function to create a test database @@ -307,7 +359,7 @@ mod tests { // Search with limit 3 let query = vec![0.0, 0.0, 0.0]; - let results = db.search(query, Similarity::Euclidean, 3).unwrap(); + let results = db.search(query, Similarity::Cosine, 3).unwrap(); assert_eq!(results.len(), 3); } @@ -377,4 +429,143 @@ mod tests { let inserted = db.build_index().unwrap(); assert_eq!(inserted, 10); } + + #[test] + fn test_create_and_load_snapshot() { + let (old_db, temp_dir) = create_test_db(); + + let v1 = vec![0.0, 1.0, 2.0]; + let v2 = vec![3.0, 4.0, 5.0]; + let v3 = vec![6.0, 7.0, 8.0]; + + let id1 = old_db + .insert( + v1.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let id2 = old_db + .insert( + v2.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let temp_snapshot_dir = tempdir().unwrap(); + let snapshot_path = old_db.create_snapshot(temp_snapshot_dir.path()).unwrap(); + + // insert v3 after snapshot + let id3 = old_db + .insert( + v3.clone(), + Payload { + content_type: ContentType::Text, + content: "test".to_string(), + }, + ) + .unwrap(); + + let reload_config = DbRestoreConfig { + data_path: temp_dir.path().to_path_buf(), + snapshot_path, + }; + + std::mem::drop(old_db); + let loaded_db = 
restore_from_snapshot(&reload_config).unwrap(); + + assert!(loaded_db.get(id1).unwrap_or(None).is_some()); + assert!(loaded_db.get(id2).unwrap_or(None).is_some()); + assert!(loaded_db.get(id3).unwrap_or(None).is_none()); // v3 was inserted after snapshot was taken + + // vector restore check + assert!(loaded_db.get(id1).unwrap().unwrap().vector.unwrap() == v1); + assert!(loaded_db.get(id2).unwrap().unwrap().vector.unwrap() == v2); + } + + #[test] + fn test_snapshot_engine() { + let (_db, _temp_dir) = create_test_db(); + let db = Arc::new(Mutex::new(_db)); + + let registry_tempdir = tempdir().unwrap(); + + let registry = Arc::new(Mutex::new( + LocalRegistry::new(registry_tempdir.path()).unwrap(), + )); + + let last_k = 4; + let mut se = SnapshotEngine::new(last_k, db.clone(), registry.clone()); + + let v1 = vec![0.0, 1.0, 2.0]; + let v2 = vec![3.0, 4.0, 5.0]; + let v3 = vec![6.0, 7.0, 8.0]; + + let test_vectors = vec![v1.clone(), v2.clone(), v3.clone()]; + let mut inserted_ids = Vec::new(); + + for (i, vector) in test_vectors.clone().into_iter().enumerate() { + se.snapshot().unwrap(); + let id = db + .lock() + .unwrap() + .insert( + vector.clone(), + Payload { + content_type: ContentType::Text, + content: format!("{}", i), + }, + ) + .unwrap(); + inserted_ids.push(id); + } + se.snapshot().unwrap(); + let snapshots = se.list_alive_snapshots().unwrap(); + + // asserting these cases: + // snapshot 0 : no vectors + // snapshot 1 : v1 + // snapshot 2 : v1, v2 + // snapshot 3 : v1, v2, v3 + + std::mem::drop(db); + std::mem::drop(se); + + for (i, snapshot) in snapshots.iter().enumerate() { + let temp_dir = tempdir().unwrap(); + let db = restore_from_snapshot(&DbRestoreConfig { + data_path: temp_dir.path().to_path_buf(), + snapshot_path: snapshot.path.clone(), + }) + .unwrap(); + for j in 0..i { + // test if point is present + assert!(db.get(inserted_ids[j]).unwrap_or(None).is_some()); + // test vector restore + assert!( + 
db.get(inserted_ids[j]).unwrap().unwrap().vector.unwrap() == test_vectors[j] + ); + // test payload restore + assert!( + db.get(inserted_ids[j]) + .unwrap() + .unwrap() + .payload + .unwrap() + .content + == format!("{}", j) + ); + } + for absent_id in inserted_ids.iter().skip(i) { + assert!(db.get(*absent_id).unwrap_or(None).is_none()); + } + std::mem::drop(db); + } + } } diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 89116ce..14ba7be 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -12,6 +12,13 @@ pub enum DbError { IndexInitError, //TODO: Change this UnsupportedSimilarity, DimensionMismatch, + SnapshotError(String), + StorageInitializationError, + StorageCheckpointError(String), + InvalidMagicBytes(String), + VectorNotFound(uuid::Uuid), + SnapshotRegistryError(String), + StorageEngineError(String), InvalidDimension { expected: Dimension, got: Dimension }, PointAlreadyExists { id: PointId }, PointNotFound { id: PointId }, diff --git a/crates/defs/src/lib.rs b/crates/defs/src/lib.rs index c2a79bf..cf3b3f3 100644 --- a/crates/defs/src/lib.rs +++ b/crates/defs/src/lib.rs @@ -3,4 +3,10 @@ pub mod types; // Without re-exports, users would need to write defs::types::SomeType instead of just defs::SomeType. Re-exports simplify the API by flattening the module hierarchy. The * means "everything public" from that module. pub use error::*; +use std::path::{Path, PathBuf}; pub use types::*; + +// hoisted trait so it can be used by the snapshots crate +pub trait SnapshottableDb: Send + Sync { + fn create_snapshot(&self, dir_path: &Path) -> Result; +} diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index 65c861a..03f3f47 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -15,6 +15,8 @@ pub type Dimension = usize; // Sparse vector implementation not supported yet. 
Refer lib/sparse/src/common/sparse_vector.rs pub type DenseVector = Vec; +pub type Magic = [u8; 4]; + pub enum StoredVector { Dense(DenseVector), } diff --git a/crates/index/Cargo.toml b/crates/index/Cargo.toml index 8fe2733..1e92124 100644 --- a/crates/index/Cargo.toml +++ b/crates/index/Cargo.toml @@ -10,3 +10,6 @@ license.workspace = true defs.workspace = true rand.workspace = true uuid.workspace = true +bincode.workspace = true +serde.workspace = true +storage.workspace = true diff --git a/crates/index/src/flat.rs b/crates/index/src/flat.rs deleted file mode 100644 index c0910e3..0000000 --- a/crates/index/src/flat.rs +++ /dev/null @@ -1,270 +0,0 @@ -use defs::{DbError, DenseVector, DistanceOrderedVector, IndexedVector, PointId, Similarity}; - -use crate::{VectorIndex, distance}; - -pub struct FlatIndex { - index: Vec, -} - -impl FlatIndex { - pub fn new() -> Self { - Self { index: Vec::new() } - } - - pub fn build(vectors: Vec) -> Self { - FlatIndex { index: vectors } - } -} - -impl Default for FlatIndex { - fn default() -> Self { - Self::new() - } -} - -impl VectorIndex for FlatIndex { - fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { - self.index.push(vector); - Ok(()) - } - - fn delete(&mut self, point_id: PointId) -> Result { - if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) { - self.index.remove(pos); - Ok(true) - } else { - Ok(false) - } - } - - fn search( - &self, - query_vector: DenseVector, - similarity: Similarity, - k: usize, - ) -> Result, DbError> { - let scores = self - .index - .iter() - .map(|point| DistanceOrderedVector { - distance: distance(&point.vector, &query_vector, similarity), - query_vector: &query_vector, - point_id: Some(point.id), - }) - .collect::>(); - - // select k smallest elements in scores using a max heap - let mut heap = std::collections::BinaryHeap::::new(); - for score in scores { - if heap.len() < k { - heap.push(score); - } else if score < *heap.peek().unwrap() { - 
heap.pop(); - heap.push(score); - } - } - Ok(heap - .into_sorted_vec() - .into_iter() - .map(|v| v.point_id.unwrap()) - .collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use uuid::Uuid; - - #[test] - fn test_flat_index_new() { - let index = FlatIndex::new(); - assert_eq!(index.index.len(), 0); - } - - #[test] - fn test_flat_index_build() { - let vectors = vec![ - IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }, - IndexedVector { - id: Uuid::new_v4(), - vector: vec![4.0, 5.0, 6.0], - }, - ]; - let index = FlatIndex::build(vectors.clone()); - assert_eq!(index.index, vectors); - } - - #[test] - fn test_insert() { - let mut index = FlatIndex::new(); - let vector = IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }; - - assert!(index.insert(vector.clone()).is_ok()); - assert_eq!(index.index.len(), 1); - assert_eq!(index.index[0], vector); - } - - #[test] - fn test_delete_existing() { - let mut index = FlatIndex::new(); - let existing_id = Uuid::new_v4(); - let vector = IndexedVector { - id: existing_id, - vector: vec![1.0, 2.0, 3.0], - }; - index.insert(vector).unwrap(); - - let result = index.delete(existing_id).unwrap(); - assert!(result); - assert_eq!(index.index.len(), 0); - } - - #[test] - fn test_delete_non_existing() { - let mut index = FlatIndex::new(); - let vector = IndexedVector { - id: Uuid::new_v4(), - vector: vec![1.0, 2.0, 3.0], - }; - index.insert(vector).unwrap(); - - let result = index.delete(Uuid::new_v4()).unwrap(); - assert!(!result); - assert_eq!(index.index.len(), 1); - } - - #[test] - fn test_search_euclidean() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![2.0, 2.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![10.0, 10.0], - }) - 
.unwrap(); - - let results = index - .search(vec![0.0, 0.0], Similarity::Euclidean, 2) - .unwrap(); - assert_eq!(results, vec![id1, id2]); - } - - #[test] - fn test_search_cosine() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 0.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![0.5, 0.5], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![0.0, 1.0], - }) - .unwrap(); - - let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap(); - assert_eq!(results, vec![id2, id1]); - } - - #[test] - fn test_search_manhattan() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![2.0, 2.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![5.0, 5.0], - }) - .unwrap(); - - let results = index - .search(vec![0.0, 0.0], Similarity::Manhattan, 2) - .unwrap(); - assert_eq!(results, vec![id1, id2]); - } - - #[test] - fn test_search_hamming() { - let mut index = FlatIndex::new(); - let id1 = Uuid::new_v4(); - let id2 = Uuid::new_v4(); - let id3 = Uuid::new_v4(); - index - .insert(IndexedVector { - id: id1, - vector: vec![1.0, 0.0, 1.0, 1.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id2, - vector: vec![1.0, 0.0, 0.0, 0.0], - }) - .unwrap(); - index - .insert(IndexedVector { - id: id3, - vector: vec![0.0, 0.0, 0.0, 0.0], - }) - .unwrap(); - - let results = index - .search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2) - .unwrap(); - assert_eq!(results, vec![id2, id3]); - } - - #[test] - fn test_default() { - let index = FlatIndex::default(); - assert_eq!(index.index.len(), 0); - } -} diff --git 
a/crates/index/src/flat/index.rs b/crates/index/src/flat/index.rs new file mode 100644 index 0000000..87f814f --- /dev/null +++ b/crates/index/src/flat/index.rs @@ -0,0 +1,71 @@ +use crate::{VectorIndex, distance}; +use defs::{DbError, DenseVector, DistanceOrderedVector, IndexedVector, PointId, Similarity}; + +pub struct FlatIndex { + pub index: Vec, +} + +impl FlatIndex { + pub fn new() -> Self { + Self { index: Vec::new() } + } + + pub fn build(vectors: Vec) -> Self { + FlatIndex { index: vectors } + } +} + +impl Default for FlatIndex { + fn default() -> Self { + Self::new() + } +} + +impl VectorIndex for FlatIndex { + fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError> { + self.index.push(vector); + Ok(()) + } + + fn delete(&mut self, point_id: PointId) -> Result { + if let Some(pos) = self.index.iter().position(|vector| vector.id == point_id) { + self.index.remove(pos); + Ok(true) + } else { + Ok(false) + } + } + + fn search( + &self, + query_vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result, DbError> { + let scores = self + .index + .iter() + .map(|point| DistanceOrderedVector { + distance: distance(&point.vector, &query_vector, similarity), + query_vector: &query_vector, + point_id: Some(point.id), + }) + .collect::>(); + + // select k smallest elements in scores using a max heap + let mut heap = std::collections::BinaryHeap::::new(); + for score in scores { + if heap.len() < k { + heap.push(score); + } else if score < *heap.peek().unwrap() { + heap.pop(); + heap.push(score); + } + } + Ok(heap + .into_sorted_vec() + .into_iter() + .map(|v| v.point_id.unwrap()) + .collect()) + } +} diff --git a/crates/index/src/flat/mod.rs b/crates/index/src/flat/mod.rs new file mode 100644 index 0000000..5e3f726 --- /dev/null +++ b/crates/index/src/flat/mod.rs @@ -0,0 +1,9 @@ +use defs::Magic; + +pub mod index; +mod serialize; + +#[cfg(test)] +mod tests; + +pub const FLAT_MAGIC_BYTES: Magic = [0x00, 0x00, 0x00, 0x01]; diff --git 
a/crates/index/src/flat/serialize.rs b/crates/index/src/flat/serialize.rs new file mode 100644 index 0000000..32af97e --- /dev/null +++ b/crates/index/src/flat/serialize.rs @@ -0,0 +1,108 @@ +use super::FLAT_MAGIC_BYTES; +use crate::IndexType; +use crate::flat::index::FlatIndex; +use crate::{IndexSnapshot, SerializableIndex}; +use defs::{DbError, IndexedVector}; +use serde::{Deserialize, Serialize}; +use std::io::{Cursor, Read}; +use storage::StorageEngine; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlatIndexMetadata { + total_points: usize, +} + +impl FlatIndex { + pub fn deserialize( + IndexSnapshot { + index_type, + magic, + topology_b, + metadata_b, + }: &IndexSnapshot, + ) -> Result { + if index_type != &IndexType::Flat { + return Err(DbError::SerializationError( + "Invalid index type".to_string(), + )); + } + + if magic != &FLAT_MAGIC_BYTES { + return Err(DbError::SerializationError( + "Invalid magic bytes".to_string(), + )); + } + + let metadata: FlatIndexMetadata = bincode::deserialize(metadata_b).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize FlatIndex Metadata: {}", e)) + })?; + let total_points = metadata.total_points; + + let mut cursor = Cursor::new(topology_b); + let mut vectors = Vec::new(); + + for _ in 0..total_points { + let mut uuid_slice = [0u8; 16]; + cursor.read_exact(&mut uuid_slice).map_err(|e| { + DbError::SerializationError(format!( + "Failed to deserialize FlatIndex Topology: {}", + e + )) + })?; + let id = Uuid::from_bytes_le(uuid_slice); + vectors.push(IndexedVector { + id, + vector: Vec::new(), + }); + } + + Ok(FlatIndex { index: vectors }) + } +} + +impl SerializableIndex for FlatIndex { + fn serialize_topology(&self) -> Result, DbError> { + let mut buffer: Vec = Vec::new(); + for point in &self.index { + buffer.extend_from_slice(&point.id.to_bytes_le()); + } + + Ok(buffer) + } + + fn serialize_metadata(&self) -> Result, DbError> { + let mut buffer: Vec = Vec::new(); + 
let metadata = FlatIndexMetadata { + total_points: self.index.len(), + }; + + let metadata_bytes = bincode::serialize(&metadata).map_err(|e| { + DbError::SerializationError(format!("Failed to serialize FlatIndex Metadata: {}", e)) + })?; + + buffer.extend_from_slice(&metadata_bytes); + Ok(buffer) + } + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { + for item in &mut self.index { + item.vector = storage + .get_vector(item.id)? + .ok_or(DbError::VectorNotFound(item.id))?; + } + Ok(()) + } + + fn snapshot(&self) -> Result { + let topology = self.serialize_topology()?; + let metadata = self.serialize_metadata()?; + + Ok(IndexSnapshot { + metadata_b: metadata, + topology_b: topology, + magic: FLAT_MAGIC_BYTES, + index_type: IndexType::Flat, + }) + } +} diff --git a/crates/index/src/flat/tests.rs b/crates/index/src/flat/tests.rs new file mode 100644 index 0000000..6d43c3d --- /dev/null +++ b/crates/index/src/flat/tests.rs @@ -0,0 +1,238 @@ +use super::index::FlatIndex; +use crate::{SerializableIndex, VectorIndex}; +use defs::{IndexedVector, Similarity}; +use uuid::Uuid; + +#[test] +fn test_flat_index_new() { + let index = FlatIndex::new(); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_flat_index_build() { + let vectors = vec![ + IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }, + IndexedVector { + id: Uuid::new_v4(), + vector: vec![4.0, 5.0, 6.0], + }, + ]; + let index = FlatIndex::build(vectors.clone()); + assert_eq!(index.index, vectors); +} + +#[test] +fn test_insert() { + let mut index = FlatIndex::new(); + let vector = IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }; + + assert!(index.insert(vector.clone()).is_ok()); + assert_eq!(index.index.len(), 1); + assert_eq!(index.index[0], vector); +} + +#[test] +fn test_delete_existing() { + let mut index = FlatIndex::new(); + let existing_id = Uuid::new_v4(); + let vector = IndexedVector { + id: existing_id, + vector: 
vec![1.0, 2.0, 3.0], + }; + index.insert(vector).unwrap(); + + let result = index.delete(existing_id).unwrap(); + assert!(result); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_delete_non_existing() { + let mut index = FlatIndex::new(); + let vector = IndexedVector { + id: Uuid::new_v4(), + vector: vec![1.0, 2.0, 3.0], + }; + index.insert(vector).unwrap(); + + let result = index.delete(Uuid::new_v4()).unwrap(); + assert!(!result); + assert_eq!(index.index.len(), 1); +} + +#[test] +fn test_search_euclidean() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![2.0, 2.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![10.0, 10.0], + }) + .unwrap(); + + let results = index + .search(vec![0.0, 0.0], Similarity::Euclidean, 2) + .unwrap(); + assert_eq!(results, vec![id1, id2]); +} + +#[test] +fn test_search_cosine() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 0.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![0.5, 0.5], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![0.0, 1.0], + }) + .unwrap(); + + let results = index.search(vec![1.0, 1.0], Similarity::Cosine, 2).unwrap(); + assert_eq!(results, vec![id2, id1]); +} + +#[test] +fn test_search_manhattan() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![2.0, 2.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![5.0, 
5.0], + }) + .unwrap(); + + let results = index + .search(vec![0.0, 0.0], Similarity::Manhattan, 2) + .unwrap(); + assert_eq!(results, vec![id1, id2]); +} + +#[test] +fn test_search_hamming() { + let mut index = FlatIndex::new(); + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + index + .insert(IndexedVector { + id: id1, + vector: vec![1.0, 0.0, 1.0, 1.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id2, + vector: vec![1.0, 0.0, 0.0, 0.0], + }) + .unwrap(); + index + .insert(IndexedVector { + id: id3, + vector: vec![0.0, 0.0, 0.0, 0.0], + }) + .unwrap(); + + let results = index + .search(vec![1.0, 0.0, 0.0, 0.0], Similarity::Hamming, 2) + .unwrap(); + assert_eq!(results, vec![id2, id3]); +} + +#[test] +fn test_default() { + let index = FlatIndex::default(); + assert_eq!(index.index.len(), 0); +} + +#[test] +fn test_serialize_and_deserialize_topo() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + + let v1 = IndexedVector { + id: id1, + vector: vec![0.0, 0.0, 0.0, 0.0], + }; + let v2 = IndexedVector { + id: id2, + vector: vec![1.0, 0.0, 0.0, 0.0], + }; + let v3 = IndexedVector { + id: id3, + vector: vec![2.0, 0.0, 0.0, 0.0], + }; + let v4 = IndexedVector { + id: id4, + vector: vec![3.0, 0.0, 0.0, 0.0], + }; + + let vectors = vec![v1.clone(), v2.clone(), v3.clone(), v4.clone()]; + let mut index_before = FlatIndex::build(vectors); + index_before.insert(v4.clone()).unwrap(); + + index_before.delete(id1).unwrap(); + + let snapshot = index_before.snapshot().unwrap(); + + let idx = FlatIndex::deserialize(&snapshot).unwrap(); + + assert_eq!(idx.index.len(), 4); + assert!(!idx.index.iter().any(|v| v.id == id1)); + assert!(idx.index.iter().any(|v| v.id == id2)); + assert!(idx.index.iter().any(|v| v.id == id3)); + assert!(idx.index.iter().any(|v| v.id == id3)); + assert!(idx.index.iter().any(|v| v.id == id4)); +} diff --git a/crates/index/src/hnsw/mod.rs 
b/crates/index/src/hnsw/mod.rs index dfbc7ae..35ae15f 100644 --- a/crates/index/src/hnsw/mod.rs +++ b/crates/index/src/hnsw/mod.rs @@ -3,6 +3,7 @@ pub mod index; pub mod search; +pub mod serialize; pub mod types; pub use index::HnswIndex; diff --git a/crates/index/src/hnsw/serialize.rs b/crates/index/src/hnsw/serialize.rs new file mode 100644 index 0000000..f632c82 --- /dev/null +++ b/crates/index/src/hnsw/serialize.rs @@ -0,0 +1,21 @@ +use defs::DbError; +use storage::StorageEngine; + +use crate::{IndexSnapshot, SerializableIndex, hnsw::HnswIndex}; + +impl SerializableIndex for HnswIndex { + fn serialize_topology(&self) -> Result, DbError> { + return Err(DbError::SerializationError("not implemented".to_string())); + } + fn serialize_metadata(&self) -> Result, DbError> { + return Err(DbError::SerializationError("not implemented".to_string())); + } + + fn snapshot(&self) -> Result { + return Err(DbError::SerializationError("not implemented".to_string())); + } + + fn populate_vectors(&mut self, _storage: &dyn StorageEngine) -> Result<(), DbError> { + return Err(DbError::SerializationError("not implemented".to_string())); + } +} diff --git a/crates/index/src/hnsw/tests.rs b/crates/index/src/hnsw/tests.rs index 3f0e4f3..605c8f7 100644 --- a/crates/index/src/hnsw/tests.rs +++ b/crates/index/src/hnsw/tests.rs @@ -1,6 +1,6 @@ use super::*; use crate::VectorIndex; -use crate::flat::FlatIndex; +use crate::flat::index::FlatIndex; use defs::{IndexedVector, Similarity}; use uuid::Uuid; diff --git a/crates/index/src/kd_tree/mod.rs b/crates/index/src/kd_tree/mod.rs index fa8a23d..52aa7b9 100644 --- a/crates/index/src/kd_tree/mod.rs +++ b/crates/index/src/kd_tree/mod.rs @@ -1,6 +1,11 @@ +use defs::Magic; + pub mod helpers; pub mod index; +mod serialize; pub mod types; #[cfg(test)] mod tests; + +pub const KD_TREE_MAGIC_BYTES: Magic = [0x00, 0x00, 0x00, 0x00]; diff --git a/crates/index/src/kd_tree/serialize.rs b/crates/index/src/kd_tree/serialize.rs new file mode 100644 index 
0000000..5d8519f --- /dev/null +++ b/crates/index/src/kd_tree/serialize.rs @@ -0,0 +1,204 @@ +use std::collections::HashSet; +use std::io::{Cursor, Read, Write}; + +use super::KD_TREE_MAGIC_BYTES; +use super::index::KDTree; +use super::types::KDTreeNode; +use crate::{IndexSnapshot, IndexType, SerializableIndex}; +use bincode; +use defs::{DbError, IndexedVector, PointId}; +use serde::{Deserialize, Serialize}; +use storage::StorageEngine; +use uuid::Uuid; + +#[derive(Serialize, Deserialize)] +pub struct KDTreeMetadata { + pub dim: usize, + pub total_nodes: usize, + pub deleted_count: usize, +} + +impl SerializableIndex for KDTree { + fn serialize_topology(&self) -> Result, DbError> { + let mut buffer = Vec::new(); + let mut cursor = Cursor::new(&mut buffer); + serialize_topology_recursive(&self.root, &mut cursor)?; + Ok(buffer) + } + + fn serialize_metadata(&self) -> Result, DbError> { + let mut buffer = Vec::new(); + let km = KDTreeMetadata { + dim: self.dim, + total_nodes: self.total_nodes, + deleted_count: self.deleted_count, + }; + let metadata_bytes = bincode::serialize(&km).map_err(|e| { + DbError::SerializationError(format!("Failed to serailize KD Tree Metadata: {}", e)) + })?; + buffer.extend_from_slice(metadata_bytes.as_slice()); + Ok(buffer) + } + + fn snapshot(&self) -> Result { + let topology_bytes = self.serialize_topology()?; + let metadata_bytes = self.serialize_metadata()?; + Ok(IndexSnapshot { + index_type: crate::IndexType::KDTree, + magic: KD_TREE_MAGIC_BYTES, + topology_b: topology_bytes, + metadata_b: metadata_bytes, + }) + } + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { + populate_vectors_recursive(&mut self.root, storage)?; + Ok(()) + } +} + +const NODE_MARKER_BYTE: u8 = 1u8; +const SKIP_MARKER_BYTE: u8 = 0u8; + +const DELETED_MASK: u8 = 2u8; + +impl KDTree { + pub fn deserialize( + IndexSnapshot { + index_type, + magic, + topology_b, + metadata_b, + }: &IndexSnapshot, + ) -> Result { + if index_type 
!= &IndexType::KDTree { + return Err(DbError::SerializationError( + "Invalid index type".to_string(), + )); + } + + if magic != &KD_TREE_MAGIC_BYTES { + return Err(DbError::SerializationError( + "Invalid magic bytes".to_string(), + )); + } + + let metadata: KDTreeMetadata = + bincode::deserialize(metadata_b.as_slice()).map_err(|e| { + DbError::SerializationError(format!( + "Failed to deserailize KD Tree Metadata: {}", + e + )) + })?; + + let mut buf = Cursor::new(topology_b); + let mut non_deleted = HashSet::new(); + let root = deserialize_topology_recursive(metadata.dim, 0, &mut buf, &mut non_deleted)?; + + Ok(KDTree { + dim: metadata.dim, + root, + point_ids: non_deleted, + total_nodes: metadata.total_nodes, + deleted_count: metadata.deleted_count, + }) + } +} + +// helper functions + +fn serialize_topology_recursive( + current_opt: &Option>, + buffer: &mut Cursor<&mut Vec>, +) -> Result<(), DbError> { + if let Some(current) = current_opt { + let mut marker = NODE_MARKER_BYTE; + if current.is_deleted { + marker |= DELETED_MASK; + } + buffer + .write_all(&[marker]) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + + let uuid_bytes = current.indexed_vector.id.to_bytes_le(); + buffer + .write_all(&uuid_bytes) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + + // serialize left subtree topology + serialize_topology_recursive(¤t.left, buffer)?; + // serialize right subtree topology + serialize_topology_recursive(¤t.right, buffer)?; + } else { + buffer + .write_all(&[SKIP_MARKER_BYTE]) + .map_err(|e| DbError::SerializationError(e.to_string()))?; + } + Ok(()) +} + +fn populate_vectors_recursive( + node: &mut Option>, + storage: &dyn StorageEngine, +) -> Result<(), DbError> { + if let Some(node) = node { + let vector = storage + .get_vector(node.indexed_vector.id)? 
+ .ok_or(DbError::VectorNotFound(node.indexed_vector.id))?; + node.indexed_vector.vector = vector; + + populate_vectors_recursive(&mut node.left, storage)?; + populate_vectors_recursive(&mut node.right, storage)?; + } + Ok(()) +} + +fn deserialize_topology_recursive( + dimensions: usize, + depth: usize, + buffer: &mut Cursor<&Vec>, + non_deleted: &mut HashSet, +) -> Result>, DbError> { + let mut current_marker: [u8; 1] = [0u8; 1]; + buffer.read_exact(&mut current_marker).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize KD Topology: {}", e)) + })?; + + if current_marker[0] == SKIP_MARKER_BYTE { + return Ok(None); + } + + let mut uuid_bytes = [0u8; 16]; + buffer.read_exact(&mut uuid_bytes).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize KD Topology: {}", e)) + })?; + let uuid = Uuid::from_bytes_le(uuid_bytes); + let indexed_vector = IndexedVector { + id: uuid, + vector: Vec::new(), + }; + + let is_deleted = current_marker[0] & DELETED_MASK == DELETED_MASK; + if !is_deleted { + non_deleted.insert(uuid); + } + + // pre order deserialization + let lower_dim = (depth + 1) % dimensions; + let left_node = deserialize_topology_recursive(dimensions, lower_dim, buffer, non_deleted)?; + let right_node = deserialize_topology_recursive(dimensions, lower_dim, buffer, non_deleted)?; + + let left_size = left_node.as_ref().map_or(0, |n| n.subtree_size); + let right_size = right_node.as_ref().map_or(0, |n| n.subtree_size); + + let current_node = KDTreeNode { + indexed_vector, + left: left_node, + right: right_node, + is_deleted, + axis: depth, + subtree_size: left_size + right_size + 1, + }; + + Ok(Some(Box::new(current_node))) +} diff --git a/crates/index/src/kd_tree/tests.rs b/crates/index/src/kd_tree/tests.rs index faefae5..5b30952 100644 --- a/crates/index/src/kd_tree/tests.rs +++ b/crates/index/src/kd_tree/tests.rs @@ -1,7 +1,8 @@ use super::index::KDTree; +use crate::SerializableIndex; use crate::VectorIndex; use 
crate::distance; -use crate::flat::FlatIndex; +use crate::flat::index::FlatIndex; use defs::{DbError, IndexedVector, Similarity}; use std::collections::HashSet; use uuid::Uuid; @@ -701,3 +702,33 @@ fn test_kdtree_vs_flat_euclidean_5d() { } } } + +#[test] +fn test_serialize_and_deserialize_topo() { + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + let id3 = Uuid::new_v4(); + let id4 = Uuid::new_v4(); + + let vectors = vec![ + make_vector_with_id(id1, vec![1.0, 2.0, 3.0]), + make_vector_with_id(id2, vec![4.0, 5.0, 6.0]), + make_vector_with_id(id3, vec![7.0, 8.0, 9.0]), + ]; + let mut tree_before = KDTree::build(vectors).unwrap(); + tree_before + .insert(make_vector_with_id(id4, vec![10.0, 11.0, 12.0])) + .unwrap(); + tree_before.delete(id1).unwrap(); + + let snapshot = tree_before.snapshot().unwrap(); + let tree = KDTree::deserialize(&snapshot).unwrap(); + + assert!(tree.root.is_some()); + assert_eq!(tree.dim, 3); + assert_eq!(tree.total_nodes, 4); + assert!(!tree.point_ids.contains(&id1)); + assert!(tree.point_ids.contains(&id2)); + assert!(tree.point_ids.contains(&id3)); + assert!(tree.point_ids.contains(&id3)); +} diff --git a/crates/index/src/kd_tree/types.rs b/crates/index/src/kd_tree/types.rs index aca187c..06d2234 100644 --- a/crates/index/src/kd_tree/types.rs +++ b/crates/index/src/kd_tree/types.rs @@ -1,4 +1,5 @@ use defs::{IndexedVector, OrdF32, PointId}; +use std::cmp::Ordering; // the node which will be the part of the KD Tree pub struct KDTreeNode { @@ -12,8 +13,25 @@ pub struct KDTreeNode { // The struct definition which is present in max heap while search // distance is first for correct Ord derivation (primary sort key) -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq)] pub struct Neighbor { pub distance: OrdF32, pub id: PointId, } + +impl Eq for Neighbor {} + +// Custom Ord implementation for the max-heap +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> Ordering { + self.distance + 
.partial_cmp(&other.distance) + .unwrap_or(Ordering::Equal) + } +} + +impl PartialOrd for Neighbor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index 4e59ced..e166f3b 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -1,10 +1,11 @@ -use defs::{DbError, DenseVector, IndexedVector, PointId, Similarity}; - +use defs::{DbError, DenseVector, IndexedVector, Magic, PointId, Similarity}; +use serde::{Deserialize, Serialize}; +use storage::StorageEngine; pub mod flat; pub mod hnsw; pub mod kd_tree; -pub trait VectorIndex: Send + Sync { +pub trait VectorIndex: Send + Sync + SerializableIndex { fn insert(&mut self, vector: IndexedVector) -> Result<(), DbError>; // Returns true if point id existed and is deleted, else returns false @@ -60,9 +61,25 @@ pub fn distance(a: &DenseVector, b: &DenseVector, dist_type: Similarity) -> f32 } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] pub enum IndexType { Flat, KDTree, HNSW, } + +pub struct IndexSnapshot { + pub index_type: IndexType, + pub magic: Magic, + pub topology_b: Vec, + pub metadata_b: Vec, +} + +pub trait SerializableIndex { + fn serialize_topology(&self) -> Result, DbError>; + fn serialize_metadata(&self) -> Result, DbError>; + + fn snapshot(&self) -> Result; + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError>; +} diff --git a/crates/snapshot/Cargo.toml b/crates/snapshot/Cargo.toml new file mode 100644 index 0000000..10328f1 --- /dev/null +++ b/crates/snapshot/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "snapshot" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +chrono.workspace = true +data-encoding = "2.9.0" +defs.workspace = true +flate2 = "1.1.5" +fs2 = "0.4.3" +index.workspace = true +semver = "1.0.27" +serde.workspace = true +serde_json.workspace = true +sha2 
= "0.10.9" +storage.workspace = true +tar = "0.4.44" +tempfile.workspace = true +uuid.workspace = true diff --git a/crates/snapshot/README.md b/crates/snapshot/README.md new file mode 100644 index 0000000..e69de29 diff --git a/crates/snapshot/src/constants.rs b/crates/snapshot/src/constants.rs new file mode 100644 index 0000000..3dd46d3 --- /dev/null +++ b/crates/snapshot/src/constants.rs @@ -0,0 +1,5 @@ +use semver::Version; + +pub const SNAPSHOT_PARSER_VER: Version = Version::new(0, 1, 0); +pub const SMALL_ID_LEN: usize = 8; +pub const MANIFEST_FILE: &str = "manifest.json"; diff --git a/crates/snapshot/src/engine/mod.rs b/crates/snapshot/src/engine/mod.rs new file mode 100644 index 0000000..fe138f8 --- /dev/null +++ b/crates/snapshot/src/engine/mod.rs @@ -0,0 +1,172 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Condvar, Mutex}, + time::Duration, +}; + +use defs::{DbError, SnapshottableDb}; + +use crate::{metadata::Metadata, registry::SnapshotRegistry}; + +pub struct SnapshotEngine { + last_k: usize, // only retain the last k snapshots on disk. 
old/stale snapshots are marked as dead on the registry + snapshot_queue: Arc>>, + db: Arc>, + registry: Arc>, + worker_cv: Arc, + worker_running: Arc>, +} +impl SnapshotEngine { + pub fn new( + last_k: usize, + db: Arc>, + registry: Arc>, + ) -> Self { + Self { + last_k, + snapshot_queue: Arc::new(Mutex::new(VecDeque::new())), + db, + registry, + worker_cv: Arc::new(Condvar::new()), + worker_running: Arc::new(Mutex::new(false)), + } + } + + pub fn stop_worker(&mut self) -> Result<(), DbError> { + // acquire lock for worker_running + let mut worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if !*worker_running { + return Err(DbError::StorageEngineError( + "Worker thread not running".to_string(), + )); + } + *worker_running = false; + self.worker_cv.notify_one(); + Ok(()) + } + + // notify the worker thread to take a snapshot now + pub fn worker_snapshot(&mut self) -> Result<(), DbError> { + // acquire lock for worker_running + let worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if !*worker_running { + return Err(DbError::StorageEngineError( + "Worker thread not running".to_string(), + )); + } + self.worker_cv.notify_one(); + Ok(()) + } + + // take a snapshot on the callers thread + pub fn snapshot(&mut self) -> Result<(), DbError> { + Self::take_snapshot( + &mut self.db, + &mut self.registry, + &mut self.snapshot_queue, + self.last_k, + ) + } + + pub fn list_alive_snapshots(&mut self) -> Result, DbError> { + Ok(self + .snapshot_queue + .lock() + .map_err(|_| DbError::LockError)? 
+ .iter() + .cloned() + .collect()) + } + + pub fn start_worker(&mut self, interval: i64) -> Result<(), DbError> { + // acquire lock for worker_running + let mut worker_running = self.worker_running.lock().map_err(|_| DbError::LockError)?; + if *worker_running { + return Err(DbError::StorageEngineError( + "Worker thread already running".to_string(), + )); + } + *worker_running = true; + + let worker_running_clone = Arc::clone(&self.worker_running); + let db_clone = Arc::clone(&self.db); + let registry_clone = Arc::clone(&self.registry); + let worker_cv_clone = Arc::clone(&self.worker_cv); + let snapshot_queue_clone = Arc::clone(&self.snapshot_queue); + let last_k_clone = self.last_k; + + let dur_interval = Duration::from_secs(interval as u64); + let _ = std::thread::spawn(move || { + Self::worker( + dur_interval, + last_k_clone, + worker_running_clone, + db_clone, + registry_clone, + worker_cv_clone, + snapshot_queue_clone, + ); + }); + Ok(()) + } + + // helper function to take snapshot + fn take_snapshot( + db: &mut Arc>, + registry: &mut Arc>, + snapshot_queue: &mut Arc>>, + last_k: usize, + ) -> Result<(), DbError> { + let snapshot_path = db + .lock() + .unwrap() + .create_snapshot(registry.lock().unwrap().dir().as_path()) + .unwrap(); + let snapshot_metadata = Metadata::parse(&snapshot_path).unwrap(); + + // add the snapshot to registry + registry + .lock() + .unwrap() + .add_snapshot(&snapshot_path) + .unwrap(); + + { + let mut queue = snapshot_queue.lock().unwrap(); + queue.push_back(snapshot_metadata); + + while queue.len() > last_k { + let old = queue.pop_front().unwrap(); + registry.lock().unwrap().mark_dead(old.small_id).unwrap(); + } + // drop queue lock + } + Ok(()) + } + + // TODO: fix sync issues if any (i dont think there are any) + fn worker( + interval: Duration, + last_k: usize, + worker_running: Arc>, + mut db: Arc>, + mut registry: Arc>, + worker_cv: Arc, + mut snapshot_queue: Arc>>, + ) { + loop { + // acquire the lock and exit if its false + 
let worker_running = worker_running + .lock() + .map_err(|_| DbError::LockError) + .unwrap(); + if !*worker_running { + break; + } + + Self::take_snapshot(&mut db, &mut registry, &mut snapshot_queue, last_k).unwrap(); + + let _ = worker_cv.wait_timeout(worker_running, interval).unwrap(); + } + } +} diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs new file mode 100644 index 0000000..32311d2 --- /dev/null +++ b/crates/snapshot/src/lib.rs @@ -0,0 +1,294 @@ +pub mod constants; +pub mod engine; +pub mod manifest; +pub mod metadata; +pub mod registry; +mod util; + +use crate::{ + constants::{MANIFEST_FILE, SNAPSHOT_PARSER_VER}, + manifest::Manifest, + util::{compress_archive, save_index_metadata, save_topology}, +}; + +use chrono::{DateTime, Local}; +use defs::DbError; +use flate2::read::GzDecoder; +use index::{ + IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, kd_tree::index::KDTree, +}; +use semver::Version; +use std::{ + fs::File, + path::{Path, PathBuf}, + sync::{Arc, RwLock}, + time::SystemTime, +}; +use storage::{ + StorageEngine, StorageType, checkpoint::StorageCheckpoint, rocks_db::RocksDbStorage, +}; +use tar::Archive; +use tempfile::tempdir; +use uuid::Uuid; + +type VectorDbRestore = (Arc, Arc>, usize); + +pub struct Snapshot { + pub id: Uuid, + pub date: SystemTime, + pub sem_ver: Version, + pub index_snapshot: IndexSnapshot, + pub storage_snapshot: StorageCheckpoint, + pub dimensions: usize, +} + +impl Snapshot { + pub fn new( + index_snapshot: IndexSnapshot, + storage_snapshot: StorageCheckpoint, + dimensions: usize, + ) -> Result { + let id = Uuid::new_v4(); + let date = SystemTime::now(); + + Ok(Snapshot { + id, + date, + sem_ver: SNAPSHOT_PARSER_VER, + index_snapshot, + storage_snapshot, + dimensions, + }) + } + + pub fn save(&self, dir_path: &Path) -> Result { + if !dir_path.is_dir() { + return Err(DbError::SnapshotError(format!( + "Invalid path: {}", + dir_path.display() + ))); + } + + let temp_dir = 
tempdir().map_err(|e| DbError::SnapshotError(e.to_string()))?; + + // save index snapshots + let index_metadata_path = save_index_metadata( + temp_dir.path(), + self.id, + &self.index_snapshot.metadata_b, + &self.index_snapshot.magic, + )?; + + let topology_path = save_topology( + temp_dir.path(), + self.id, + &self.index_snapshot.topology_b, + &self.index_snapshot.magic, + )?; + + // take checksums + let index_metadata_checksum = util::sha256_digest(&index_metadata_path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + let index_topo_checksum = util::sha256_digest(&topology_path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + let storage_checkpoint_checksum = util::sha256_digest(&self.storage_snapshot.path) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + + let dt_now_local: DateTime = self.date.into(); + + // need this for manifest + let storage_checkpoint_filename = self + .storage_snapshot + .path + .file_name() + .ok_or(DbError::SnapshotError( + "Storage checkpoint was not properly made".to_string(), + ))? 
+ .to_str() + .unwrap() + .to_string(); + + // create manifest file + let manifest = Manifest { + id: self.id, + date: dt_now_local.timestamp(), + sem_ver: constants::SNAPSHOT_PARSER_VER.to_string(), + index_metadata_checksum, + index_topo_checksum, + storage_checkpoint_checksum, + storage_type: self.storage_snapshot.storage_type, + index_type: self.index_snapshot.index_type, + dimensions: self.dimensions, + storage_checkpoint_filename, + }; + + let manifest_path = manifest + .save(temp_dir.path()) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + + let tar_filename = format!( + "{}.tar.gz", + metadata::Metadata::new( + self.id, + self.date, + index_metadata_path.clone(), + constants::SNAPSHOT_PARSER_VER + ) + ); + let tar_gz_path = dir_path.join(tar_filename); + + compress_archive( + &tar_gz_path, + &[ + &index_metadata_path, + &topology_path, + &self.storage_snapshot.path, + &manifest_path, + ], + ) + .map_err(|e| DbError::SnapshotError(e.to_string()))?; + Ok(tar_gz_path.to_path_buf()) + } + + pub fn load(path: &Path, storage_data_path: &Path) -> Result { + let tar_gz = File::open(path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't open snapshot: {}", e)))?; + + let tar = GzDecoder::new(tar_gz); + let mut archive = Archive::new(tar); + + let snapshot_filename = path.file_name().ok_or(DbError::SnapshotError( + "Invalid snapshot filename".to_string(), + ))?; + let temp_dir = std::env::temp_dir().join(snapshot_filename); + + // remove any existing data + if temp_dir.exists() && !temp_dir.is_dir() { + std::fs::remove_file(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't remove existing file: {}", e)) + })?; + } else if temp_dir.is_dir() { + std::fs::remove_dir_all(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't remove existing directory: {}", e)) + })?; + } + + std::fs::create_dir(temp_dir.clone()).map_err(|e| { + DbError::SnapshotError(format!("Couldn't create temporary directory: {}", e)) + })?; 
+ + archive + .unpack(temp_dir.clone()) + .map_err(|e| DbError::SnapshotError(format!("Couldn't unpack archive: {}", e)))?; + + // read manifest and validate + let manifest_path = temp_dir.join(MANIFEST_FILE); + if !manifest_path.is_file() { + return Err(DbError::SnapshotError( + "Manifest file not found".to_string(), + )); + } + + let manifest = Manifest::load(&manifest_path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't load manifest: {}", e)))?; + + if manifest.sem_ver != SNAPSHOT_PARSER_VER.to_string() { + return Err(DbError::SnapshotError( + "Incompatible snapshot version".to_string(), + )); + } + + // only rocksdb is supported for snapshots as of now + let mut storage_engine: Box = match manifest.storage_type { + StorageType::RocksDb => Box::new(RocksDbStorage::new(storage_data_path)?), + _ => { + return Err(DbError::SnapshotError( + "Unsupported storage type".to_string(), + )); + } + }; + + let id = manifest.id; + let index_metadata_path = temp_dir.join(util::metadata_filename(&id)); + let topology_path = temp_dir.join(util::topology_filename(&id)); + let storage_checkpoint_path = temp_dir.join(manifest.storage_checkpoint_filename); + + if !index_metadata_path.exists() + || !topology_path.exists() + || !storage_checkpoint_path.exists() + { + return Err(DbError::SnapshotError(format!( + "Missing snapshot files {} , {}, {}", + index_metadata_path.display(), + topology_path.display(), + storage_checkpoint_path.display() + ))); + } + + // match checksums + if util::sha256_digest(&index_metadata_path).map_err(|_| { + DbError::SnapshotError("Could not calculate index metadata hash".to_string()) + })? != manifest.index_metadata_checksum + { + return Err(DbError::SnapshotError( + "Index metadata hash mismatch".to_string(), + )); + } + if util::sha256_digest(&topology_path) + .map_err(|_| DbError::SnapshotError("Could not calculate topology hash".to_string()))? 
+ != manifest.index_topo_checksum + { + return Err(DbError::SnapshotError("Topology hash mismatch".to_string())); + } + if util::sha256_digest(&storage_checkpoint_path).map_err(|_| { + DbError::SnapshotError("Could not calculate storage checkpoint hash".to_string()) + })? != manifest.storage_checkpoint_checksum + { + return Err(DbError::SnapshotError( + "Storage checkpoint hash mismatch".to_string(), + )); + } + + let (mgmeta, meta_bytes) = util::read_index_metadata(&index_metadata_path) + .map_err(|_| DbError::SnapshotError("Could not read metadata".to_string()))?; + let (mgtopo, topo_bytes) = util::read_index_topology(&topology_path) + .map_err(|_| DbError::SnapshotError("Could not read topology".to_string()))?; + + if mgtopo != mgmeta { + return Err(DbError::InvalidMagicBytes( + "Magic bytes don't match".to_string(), + )); + } + + // validates if manifest storage type matches that in the filename of storage checkpoint + let storage_checkpoint = StorageCheckpoint::open(storage_checkpoint_path.as_path())?; + if storage_checkpoint.storage_type != manifest.storage_type { + return Err(DbError::SnapshotError( + "Storage type mismatch from manifest and checkpoint".to_string(), + )); + } + + storage_engine.restore_checkpoint(&storage_checkpoint)?; + + let index_snapshot = IndexSnapshot { + index_type: manifest.index_type, + magic: mgmeta, + metadata_b: meta_bytes, + topology_b: topo_bytes, + }; + + // dynamic dispatch based on index type + let vector_index: Arc> = match manifest.index_type { + IndexType::Flat => Arc::new(RwLock::new(FlatIndex::deserialize(&index_snapshot)?)), + IndexType::KDTree => Arc::new(RwLock::new(KDTree::deserialize(&index_snapshot)?)), + _ => return Err(DbError::SnapshotError("Unsupported index type".to_string())), + }; + + vector_index + .write() + .map_err(|_| DbError::LockError)? 
+ .populate_vectors(&*storage_engine)?; + + Ok((storage_engine.into(), vector_index, manifest.dimensions)) + } +} diff --git a/crates/snapshot/src/manifest.rs b/crates/snapshot/src/manifest.rs new file mode 100644 index 0000000..0993435 --- /dev/null +++ b/crates/snapshot/src/manifest.rs @@ -0,0 +1,47 @@ +use index::IndexType; +use serde::{Deserialize, Serialize}; +use std::path::Path; +use std::{ + io::{BufReader, BufWriter, Error, Write}, + path::PathBuf, +}; +use storage::StorageType; +use uuid::Uuid; + +use crate::constants::MANIFEST_FILE; + +type UnixTimestamp = i64; + +#[derive(Serialize, Deserialize)] +pub struct Manifest { + pub id: Uuid, + pub date: UnixTimestamp, + pub sem_ver: String, + pub index_metadata_checksum: String, + pub index_topo_checksum: String, + pub storage_checkpoint_checksum: String, + pub index_type: IndexType, + pub storage_type: StorageType, + pub dimensions: usize, + pub storage_checkpoint_filename: String, +} + +impl Manifest { + pub fn save(&self, path: &Path) -> Result { + let manifest_path = path.join(MANIFEST_FILE); + + let file = std::fs::File::create(manifest_path.clone())?; + let mut writer = BufWriter::new(file); + serde_json::to_writer(&mut writer, self)?; + writer.flush()?; + + Ok(manifest_path) + } + + pub fn load(path: &Path) -> Result { + let file = std::fs::File::open(path)?; + let mut reader = BufReader::new(file); + let manifest: Manifest = serde_json::from_reader(&mut reader)?; + Ok(manifest) + } +} diff --git a/crates/snapshot/src/metadata.rs b/crates/snapshot/src/metadata.rs new file mode 100644 index 0000000..cd73185 --- /dev/null +++ b/crates/snapshot/src/metadata.rs @@ -0,0 +1,114 @@ +use crate::constants::SMALL_ID_LEN; +use chrono::DateTime; +use chrono::Local; +use defs::DbError; +use semver::Version; +use std::{fmt::Display, path::PathBuf, time::SystemTime}; +use std::{fs, path::Path}; +use uuid::Uuid; + +pub type SmallID = String; + +// Metadata is the data that can be parsed from the snapshot filename 
+#[derive(Debug, Clone)] +pub struct Metadata { + pub small_id: SmallID, + pub date: SystemTime, + pub path: PathBuf, + pub sem_ver: Version, +} + +const FILENAME_METADATA_SEPARATOR: &str = "-x"; + +impl Metadata { + pub fn new(id: Uuid, date: SystemTime, path: PathBuf, sem_ver: Version) -> Self { + Metadata { + small_id: id.to_string()[..SMALL_ID_LEN].to_string(), + date, + path, + sem_ver, + } + } + + pub fn parse(path: &Path) -> Result { + if !path.is_file() { + return Err(DbError::SnapshotError("File not found".to_string())); + } + let filename = path + .file_name() + .ok_or(DbError::SnapshotError("No filename".to_string()))? + .to_str() + .ok_or(DbError::SnapshotError( + "Invalid UTF-8 in filename".to_string(), + ))? + .strip_suffix(".tar.gz") + .ok_or(DbError::SnapshotError( + "Snapshot filename doesnt end with .tar.gz".to_string(), + ))?; + + let parts = filename + .split(FILENAME_METADATA_SEPARATOR) + .collect::>(); + + if parts.len() != 3 { + return Err(DbError::SnapshotError("Invalid filename".to_string())); + } + + let id = parts[1]; + if id.len() != SMALL_ID_LEN { + return Err(DbError::SnapshotError("Invalid UUID".to_string())); + } + + let date = chrono::DateTime::parse_from_rfc3339(parts[0]) + .map_err(|_| DbError::SnapshotError("Invalid date".to_string()))?; + let version = Version::parse(parts[2]) + .map_err(|_| DbError::SnapshotError("Invalid version".to_string()))?; + + Ok(Metadata { + small_id: id.to_string(), + date: date.into(), + path: path.to_path_buf(), + sem_ver: version, + }) + } + + pub fn snapshot_dir_metadata(path: &Path) -> Result, DbError> { + if !path.is_dir() { + return Err(DbError::SnapshotError( + "Path is not a directory".to_string(), + )); + } + + let mut metadata_vec = Vec::new(); + + for item in fs::read_dir(path).map_err(|_| { + DbError::SnapshotError(format!("Cannot read directory: {}", path.display())) + })? 
{ + let entry = item.map_err(|_| { + DbError::SnapshotError(format!("Invalid entry: {}", path.display())) + })?; + let path = entry.path(); + if path.is_file() + && let Ok(metadata) = Self::parse(&path) + { + metadata_vec.push(metadata); + } + } + Ok(metadata_vec) + } +} + +impl Display for Metadata { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let dt_now_local: DateTime = self.date.into(); + write!( + f, + "{}{}{}{}{}", + dt_now_local.to_rfc3339_opts(chrono::SecondsFormat::Secs, true), + FILENAME_METADATA_SEPARATOR, + self.small_id, + FILENAME_METADATA_SEPARATOR, + self.sem_ver + ) + } +} diff --git a/crates/snapshot/src/registry/constants.rs b/crates/snapshot/src/registry/constants.rs new file mode 100644 index 0000000..9138454 --- /dev/null +++ b/crates/snapshot/src/registry/constants.rs @@ -0,0 +1 @@ +pub const LOCAL_REGISTRY_LOCKFILE: &str = "LOCKFILE"; diff --git a/crates/snapshot/src/registry/local.rs b/crates/snapshot/src/registry/local.rs new file mode 100644 index 0000000..8f79a73 --- /dev/null +++ b/crates/snapshot/src/registry/local.rs @@ -0,0 +1,271 @@ +use std::{ + collections::HashMap, + fs, + path::{Path, PathBuf}, +}; + +use crate::registry::{INFINITY_LIMIT, NO_OFFSET, SnapshotRegistry}; +use crate::registry::{SnapshotMetaPage, constants::LOCAL_REGISTRY_LOCKFILE}; +use crate::{ + Snapshot, VectorDbRestore, + metadata::{Metadata, SmallID}, +}; +use defs::DbError; +use fs2::FileExt; + +pub struct LocalRegistry { + pub dir: PathBuf, + filename_cache: HashMap, +} + +impl LocalRegistry { + pub fn new(dir: &Path) -> Result { + fs::create_dir_all(dir).map_err(|e| DbError::SnapshotRegistryError(e.to_string()))?; + let lock_file_path = dir.join(LOCAL_REGISTRY_LOCKFILE); + let lock_file = if !lock_file_path.exists() { + fs::File::create(&lock_file_path).map_err(|e| { + DbError::SnapshotRegistryError(format!("Couldn't create LOCKFILE : {}", e)) + })? 
+ } else { + fs::OpenOptions::new() + .read(true) + .write(true) + .open(&lock_file_path) + .map_err(|e| { + DbError::SnapshotRegistryError(format!("Couldn't open LOCKFILE : {}", e)) + })? + }; + + // try to acquire lockfile + lock_file + .try_lock_exclusive() + .map_err(|_| DbError::SnapshotRegistryError("Couldn't acquire LOCKFILE".to_string()))?; + + Ok(LocalRegistry { + dir: dir.to_path_buf(), + filename_cache: HashMap::new(), + }) + } +} + +impl SnapshotRegistry for LocalRegistry { + fn add_snapshot(&mut self, snapshot_path: &Path) -> Result { + // move the snapshot file to the directory and cache its metadata + + let filename = snapshot_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Invalid snapshot path".to_string(), + ))?; + let final_snapshot_path = self.dir.join(filename); + + // if the snapshot is already in the managed directory then do nothing + if snapshot_path != final_snapshot_path.as_path() { + fs::rename(snapshot_path, final_snapshot_path.clone()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to move snapshot: {}", e)) + })?; + } + + let metadata = Metadata::parse(final_snapshot_path.as_path())?; + self.filename_cache.insert( + metadata.small_id.clone(), + filename.to_string_lossy().to_string(), + ); + Ok(metadata) + } + + fn list_snapshots(&mut self, limit: usize, offset: usize) -> Result { + let mut res = Vec::new(); + let filtered_files = fs::read_dir(self.dir.as_path()) + .map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? + .skip(offset) + .take(limit); + + for file in filtered_files { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + + if let Ok(metadata) = Metadata::parse(file_path.as_path()) { + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? 
+ .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + + res.push(metadata); + } + } + Ok(res) + } + + fn get_latest_snapshot(&mut self) -> Result { + let mut latest_record: Option = None; + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + + if let Ok(metadata) = Metadata::parse(file_path.as_path()) { + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? + .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + + latest_record = match latest_record { + None => Some(metadata), + Some(existing) => { + if metadata.date > existing.date { + Some(metadata) + } else { + Some(existing) + } + } + }; + } + } + match latest_record { + Some(metadata) => Ok(metadata), + None => Err(DbError::SnapshotRegistryError( + "No snapshots found".to_string(), + )), + } + } + + fn list_alive_snapshots(&mut self) -> Result { + self.list_snapshots(INFINITY_LIMIT, NO_OFFSET) + } + + fn remove_snapshot(&mut self, small_id: SmallID) -> Result { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + + let metadata = Metadata::parse(snapshot_filepath.as_path())?; + fs::remove_file(snapshot_filepath.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to remove snapshot: {}", e)) + })?; + self.filename_cache.remove_entry(&small_id); + Ok(metadata) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? 
{ + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + if let Ok(metadata) = Metadata::parse(file_path.as_path()) + && metadata.small_id == small_id + { + fs::remove_file(metadata.path.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Failed to remove snapshot: {}", e)) + })?; + return Ok(metadata); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn get_metadata(&mut self, small_id: SmallID) -> Result { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + let metadata = Metadata::parse(snapshot_filepath.as_path())?; + Ok(metadata) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? { + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + if let Ok(metadata) = Metadata::parse(file_path.as_path()) + && metadata.small_id == small_id + { + return Ok(metadata); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn mark_dead(&mut self, small_id: String) -> Result { + self.remove_snapshot(small_id) + } + + fn load( + &mut self, + small_id: String, + storage_data_path: &Path, + ) -> Result { + if let Some(filename) = self.filename_cache.get(&small_id) { + let snapshot_filepath = self.dir.join(filename); + Snapshot::load(snapshot_filepath.as_path(), storage_data_path) + } else { + for file in fs::read_dir(self.dir.as_path()).map_err(|e| { + DbError::SnapshotRegistryError(format!("Cannot read local registry dir: {}", e)) + })? 
{ + let file = match file { + Ok(file) => file, + Err(_) => continue, + }; + let file_path = file.path(); + let metadata = Metadata::parse(file_path.as_path())?; + let filename = file_path + .file_name() + .ok_or(DbError::SnapshotRegistryError( + "Could not load filename of snapshot".to_string(), + ))? + .to_string_lossy(); + self.filename_cache + .insert(metadata.small_id.clone(), filename.to_string()); + if metadata.small_id == small_id { + return Snapshot::load(file_path.as_path(), storage_data_path); + } + } + Err(DbError::SnapshotRegistryError( + "Snapshot not found".to_string(), + )) + } + } + + fn dir(&self) -> PathBuf { + self.dir.clone() + } +} + +impl Drop for LocalRegistry { + fn drop(&mut self) { + // remove exclusive lock on lockfile + let lock_file_path = self.dir.join(LOCAL_REGISTRY_LOCKFILE); + if let Ok(lock_file) = fs::OpenOptions::new() + .read(true) + .write(true) + .open(&lock_file_path) + { + let _ = lock_file.unlock(); + } + } +} diff --git a/crates/snapshot/src/registry/mod.rs b/crates/snapshot/src/registry/mod.rs new file mode 100644 index 0000000..6513d97 --- /dev/null +++ b/crates/snapshot/src/registry/mod.rs @@ -0,0 +1,32 @@ +use std::path::{Path, PathBuf}; + +use defs::DbError; +pub mod constants; +pub mod local; +use crate::{VectorDbRestore, metadata::Metadata}; + +pub type SnapshotMetaPage = Vec; + +pub const INFINITY_LIMIT: usize = 100000; +pub const NO_OFFSET: usize = 0; + +pub trait SnapshotRegistry: Send + Sync { + fn add_snapshot(&mut self, snapshot_path: &Path) -> Result; + + fn list_snapshots(&mut self, limit: usize, offset: usize) -> Result; + fn get_latest_snapshot(&mut self) -> Result; + + fn get_metadata(&mut self, small_id: String) -> Result; + fn remove_snapshot(&mut self, small_id: String) -> Result; + + fn load( + &mut self, + small_id: String, + storage_data_path: &Path, + ) -> Result; + fn dir(&self) -> PathBuf; + + // in the future this could be used to maybe move an old/stale snapshot to cold storage or to a remote 
registry + fn mark_dead(&mut self, small_id: String) -> Result; // current behaviour is to call remove_snapshot; + fn list_alive_snapshots(&mut self) -> Result; // current behaviour is to call list_snapshots; +} diff --git a/crates/snapshot/src/util.rs b/crates/snapshot/src/util.rs new file mode 100644 index 0000000..7d0510e --- /dev/null +++ b/crates/snapshot/src/util.rs @@ -0,0 +1,148 @@ +use data_encoding::HEXLOWER; +use sha2::{Digest, Sha256}; +use std::fs::File; +use std::io::{BufReader, Error, Read}; +use std::path::PathBuf; + +use defs::{DbError, Magic}; +use flate2::{Compression, write::GzEncoder}; +use std::{io::Write, path::Path}; +use tar::Builder; +use uuid::Uuid; + +type BinFileContent = (Magic, Vec); + +#[inline] +pub fn metadata_filename(id: &Uuid) -> String { + format!("{}-index-meta.bin", id) +} + +#[inline] +pub fn topology_filename(id: &Uuid) -> String { + format!("{}-index-topo.bin", id) +} + +// source: https://stackoverflow.com/questions/69787906/how-to-hash-a-binary-file-in-rust +pub fn sha256_digest(path: &PathBuf) -> Result { + let input = File::open(path)?; + let mut reader = BufReader::new(input); + + let digest = { + let mut hasher = Sha256::new(); + let mut buffer = [0; 1024]; + loop { + let count = reader.read(&mut buffer)?; + if count == 0 { + break; + } + hasher.update(&buffer[..count]); + } + hasher.finalize() + }; + Ok(HEXLOWER.encode(digest.as_ref())) +} + +pub fn save_index_metadata( + path: &Path, + uuid: Uuid, + bytes: &[u8], + magic: &Magic, +) -> Result { + let file_name = metadata_filename(&uuid); + let metadata_file_path = path.join(file_name); + + let mut file = std::fs::File::create(metadata_file_path.clone()) + .map_err(|e| DbError::SnapshotError(format!("Could not create metadata file: {}", e)))?; + + file.write_all(magic) + .map_err(|e| DbError::SnapshotError(format!("Could not write metadata file: {}", e)))?; + file.write_all(&bytes.len().to_le_bytes()) + .map_err(|e| DbError::SnapshotError(format!("Could not write 
metadata file: {}", e)))?; + file.write_all(bytes) + .map_err(|e| DbError::SnapshotError(format!("Could not write metadata file: {}", e)))?; + + Ok(metadata_file_path) +} + +pub fn save_topology( + path: &Path, + uuid: Uuid, + bytes: &[u8], + magic: &Magic, +) -> Result { + let file_name = topology_filename(&uuid); + let topology_file_path = path.join(file_name); + + let mut file = std::fs::File::create(topology_file_path.clone()) + .map_err(|e| DbError::SnapshotError(format!("Could not create topology file: {}", e)))?; + + file.write_all(magic) + .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?; + file.write_all(&bytes.len().to_le_bytes()) + .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?; + file.write_all(bytes) + .map_err(|e| DbError::SnapshotError(format!("Could not write topology file: {}", e)))?; + + Ok(topology_file_path) +} + +pub fn compress_archive(path: &Path, files: &[&Path]) -> Result<(), Error> { + let tar_gz = File::create(path)?; + let enc = GzEncoder::new(tar_gz, Compression::default()); + let mut tar = Builder::new(enc); + + for file in files { + let rel_path = file.file_name().unwrap(); + let mut f = File::open(file)?; + tar.append_file(rel_path, &mut f)?; + } + + tar.into_inner()?; + Ok(()) +} + +pub fn read_index_topology(path: &Path) -> Result { + let mut file = File::open(path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't open topology file: {}", e)))?; + + let mut magic = Magic::default(); + file.read_exact(&mut magic).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read magic from topology file: {}", e)) + })?; + + let mut len_bytes = [0u8; size_of::()]; + file.read_exact(&mut len_bytes).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read length from topology file: {}", e)) + })?; + let len = usize::from_le_bytes(len_bytes); + + let mut bytes = vec![0u8; len]; + file.read_exact(&mut bytes).map_err(|e| { + 
DbError::SnapshotError(format!("Couldn't read bytes from topology file: {}", e)) + })?; + + Ok((magic, bytes)) +} + +pub fn read_index_metadata(path: &Path) -> Result { + let mut file = File::open(path) + .map_err(|e| DbError::SnapshotError(format!("Couldn't open metadata file: {}", e)))?; + + let mut magic = Magic::default(); + file.read_exact(&mut magic).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read magic from metadata file: {}", e)) + })?; + + let mut len_bytes = [0u8; size_of::()]; + file.read_exact(&mut len_bytes).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read length from metadata file: {}", e)) + })?; + + let len = usize::from_le_bytes(len_bytes); + let mut bytes = vec![0u8; len]; + file.read_exact(&mut bytes).map_err(|e| { + DbError::SnapshotError(format!("Couldn't read bytes from metadata file: {}", e)) + })?; + + Ok((magic, bytes)) +} diff --git a/crates/storage/Cargo.toml b/crates/storage/Cargo.toml index c786373..4d2afb3 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -9,6 +9,9 @@ license.workspace = true [dependencies] bincode.workspace = true defs.workspace = true +flate2 = "1.1.5" rocksdb.workspace = true +serde.workspace = true +tar = "0.4.44" tempfile.workspace = true uuid.workspace = true diff --git a/crates/storage/src/checkpoint.rs b/crates/storage/src/checkpoint.rs new file mode 100644 index 0000000..09827fd --- /dev/null +++ b/crates/storage/src/checkpoint.rs @@ -0,0 +1,51 @@ +use crate::StorageType; +use crate::in_memory::INMEMORY_CHECKPOINT_FILENAME_MARKER; +use crate::rocks_db::ROCKSDB_CHECKPOINT_FILENAME_MARKER; +use defs::DbError; +use std::path::{Path, PathBuf}; + +impl StorageType { + #[inline] + pub fn checkpoint_filename_marker(&self) -> &str { + match self { + StorageType::InMemory => INMEMORY_CHECKPOINT_FILENAME_MARKER, + StorageType::RocksDb => ROCKSDB_CHECKPOINT_FILENAME_MARKER, + } + } +} + +pub struct StorageCheckpoint { + pub path: PathBuf, + pub storage_type: StorageType, 
+} + +impl StorageCheckpoint { + pub fn open(path: &Path) -> Result { + let filename = path + .file_name() + .ok_or_else(|| DbError::StorageCheckpointError("Invalid filename".to_string()))? + .to_str() + .ok_or_else(|| { + DbError::StorageCheckpointError("Invalid UTF-8 in filename".to_string()) + })? + .to_owned(); + let marker = filename + .split_once("-") + .ok_or_else(|| DbError::StorageCheckpointError("Invalid filename".to_string()))? + .0; + + let storage_type = match marker { + ROCKSDB_CHECKPOINT_FILENAME_MARKER => StorageType::RocksDb, + _ => { + return Err(DbError::StorageCheckpointError( + "Invalid storage type".to_string(), + )); + } + }; + + Ok(StorageCheckpoint { + path: path.to_path_buf(), + storage_type, + }) + } +} diff --git a/crates/storage/src/in_memory.rs b/crates/storage/src/in_memory.rs index 5190082..647627d 100644 --- a/crates/storage/src/in_memory.rs +++ b/crates/storage/src/in_memory.rs @@ -1,5 +1,9 @@ -use crate::{StorageEngine, VectorPage}; +use crate::StorageType; +use crate::{StorageEngine, VectorPage, checkpoint::StorageCheckpoint}; use defs::{DbError, DenseVector, Payload, PointId}; +use std::path::{Path, PathBuf}; + +pub const INMEMORY_CHECKPOINT_FILENAME_MARKER: &str = "inmemory"; pub struct MemoryStorage { // define here how MemoryStorage will be defined @@ -41,4 +45,15 @@ impl StorageEngine for MemoryStorage { fn list_vectors(&self, _offset: PointId, _limit: usize) -> Result, DbError> { Ok(None) } + + fn checkpoint_at(&self, _path: &Path) -> Result { + Ok(StorageCheckpoint { + path: PathBuf::default(), + storage_type: StorageType::InMemory, + }) + } + + fn restore_checkpoint(&mut self, _checkpoint: &StorageCheckpoint) -> Result<(), DbError> { + Ok(()) + } } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f7c067e..8228f72 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -1,8 +1,9 @@ +use crate::rocks_db::RocksDbStorage; use defs::{DbError, DenseVector, Payload, PointId}; -use 
std::path::PathBuf; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; use std::sync::Arc; - -use crate::rocks_db::RocksDbStorage; +pub mod checkpoint; pub type VectorPage = (Vec<(PointId, DenseVector)>, PointId); @@ -18,12 +19,18 @@ pub trait StorageEngine: Send + Sync { fn delete_point(&self, id: PointId) -> Result<(), DbError>; fn contains_point(&self, id: PointId) -> Result; fn list_vectors(&self, offset: PointId, limit: usize) -> Result, DbError>; + + fn checkpoint_at(&self, path: &Path) -> Result; + fn restore_checkpoint( + &mut self, + checkpoint: &checkpoint::StorageCheckpoint, + ) -> Result<(), DbError>; } pub mod in_memory; pub mod rocks_db; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] pub enum StorageType { InMemory, RocksDb, diff --git a/crates/storage/src/rocks_db.rs b/crates/storage/src/rocks_db.rs index c85f7a7..7d184d4 100644 --- a/crates/storage/src/rocks_db.rs +++ b/crates/storage/src/rocks_db.rs @@ -1,26 +1,44 @@ // Rewrite needed -use crate::{StorageEngine, VectorPage}; +use crate::StorageType; +use crate::{StorageEngine, VectorPage, checkpoint::StorageCheckpoint}; use bincode::{deserialize, serialize}; use defs::{DbError, DenseVector, Payload, Point, PointId}; +use flate2::{Compression, read::GzDecoder, write::GzEncoder}; use rocksdb::{DB, Error, Options}; -use std::path::PathBuf; +use std::{ + fs::File, + path::{Path, PathBuf}, +}; +use tar::{Archive, Builder}; +use tempfile::tempdir; //TODO: Implement RocksDbStorage with necessary fields and implementations //TODO: Optimize the basic design pub struct RocksDbStorage { pub path: PathBuf, - pub db: DB, + pub db: Option, } pub enum RocksDBStorageError { RocksDBError(Error), - SerializationError, } +pub const ROCKSDB_CHECKPOINT_FILENAME_MARKER: &str = "rocksdb"; + impl RocksDbStorage { // Creates new db or switches to existing db pub fn new(path: impl Into) -> Result { + let converted_path = path.into(); + let db = 
Self::initialize_db(&converted_path)?; + + Ok(RocksDbStorage { + path: converted_path, + db: Some(db), + }) + } + + fn initialize_db(path: &Path) -> Result { // Initialize a db at the given location let mut options = Options::default(); @@ -30,15 +48,8 @@ impl RocksDbStorage { options.create_if_missing(true); - let converted_path = path.into(); - - let db = DB::open(&options, converted_path.clone()) - .map_err(|e| DbError::StorageError(e.into_string()))?; - - Ok(RocksDbStorage { - path: converted_path, - db, - }) + let db = DB::open(&options, path).map_err(|e| DbError::StorageError(e.into_string()))?; + Ok(db) } pub fn get_current_path(&self) -> PathBuf { @@ -60,7 +71,12 @@ impl StorageEngine for RocksDbStorage { payload, }; let value = serialize(&point).map_err(|e| DbError::SerializationError(e.to_string()))?; - match self.db.put(key.as_bytes(), value.as_slice()) { + match self + .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? + .put(key.as_bytes(), value.as_slice()) + { Ok(_) => Ok(()), Err(e) => Err(DbError::StorageError(e.into_string())), } @@ -69,9 +85,16 @@ impl StorageEngine for RocksDbStorage { fn contains_point(&self, id: PointId) -> Result { // Efficient lookup inspired from https://github.com/facebook/rocksdb/issues/11586#issuecomment-1890429488 let key = id.to_string(); - if self.db.key_may_exist(key.clone()) { + if self + .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? + .key_may_exist(key.clone()) + { let key_exist = self .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? .get(key) .map_err(|e| DbError::StorageError(e.into_string()))? .is_some(); @@ -84,6 +107,8 @@ impl StorageEngine for RocksDbStorage { fn delete_point(&self, id: PointId) -> Result<(), DbError> { let key = id.to_string(); self.db + .as_ref() + .ok_or(DbError::StorageInitializationError)? 
.delete(key) .map_err(|e| DbError::StorageError(e.into_string()))?; @@ -94,6 +119,8 @@ impl StorageEngine for RocksDbStorage { let key = id.to_string(); let Some(value_serialized) = self .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? .get(key) .map_err(|e| DbError::StorageError(e.into_string()))? else { @@ -110,6 +137,8 @@ impl StorageEngine for RocksDbStorage { let key = id.to_string(); let Some(value_serialized) = self .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? .get(key) .map_err(|e| DbError::StorageError(e.into_string()))? else { @@ -128,10 +157,14 @@ impl StorageEngine for RocksDbStorage { } let mut result = Vec::with_capacity(limit); - let iter = self.db.iterator(rocksdb::IteratorMode::From( - offset.to_string().as_bytes(), - rocksdb::Direction::Forward, - )); + let iter = self + .db + .as_ref() + .ok_or(DbError::StorageInitializationError)? + .iterator(rocksdb::IteratorMode::From( + offset.to_string().as_bytes(), + rocksdb::Direction::Forward, + )); let mut last_id = offset; for item in iter { @@ -152,6 +185,124 @@ impl StorageEngine for RocksDbStorage { } Ok(Some((result, last_id))) } + + fn checkpoint_at(&self, path: &Path) -> Result { + // flush db first for durability + self.db + .as_ref() + .ok_or(DbError::StorageInitializationError)? 
+ .flush() + .map_err(|e| { + DbError::StorageCheckpointError(format!( + "Failed to flush database: {}", + e.into_string() + )) + })?; + + // filename is rocksdb-{uuid}.tar.gz + let checkpoint_filename = format!( + "{}-{}.tar.gz", + ROCKSDB_CHECKPOINT_FILENAME_MARKER, + uuid::Uuid::new_v4() + ); + let checkpoint_path = path.join(checkpoint_filename); + + let temp_dir_parent = tempdir().unwrap(); + let temp_dir = temp_dir_parent.path().join("checkpoint"); + + let db_ref = self + .db + .as_ref() + .ok_or(DbError::StorageInitializationError)?; + let checkpoint = rocksdb::checkpoint::Checkpoint::new(db_ref) + .map_err(|e| DbError::StorageCheckpointError(e.into_string()))?; + checkpoint + .create_checkpoint(temp_dir.clone()) + .map_err(|e| DbError::StorageCheckpointError(e.into_string()))?; + + // compress the checkpoint into an archive + let tar_gz = File::create(checkpoint_path.clone()).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't create tar archive file: {}", e)) + })?; + let enc = GzEncoder::new(tar_gz, Compression::default()); + let mut archive = Builder::new(enc); + + archive.append_dir_all("", temp_dir).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't append directory to archive: {}", e)) + })?; + + let enc = archive.into_inner().map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't compress tar archive: {}", e)) + })?; + + enc.finish().map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't compress tar archive: {}", e)) + })?; + + Ok(StorageCheckpoint { + path: checkpoint_path, + storage_type: crate::StorageType::RocksDb, + }) + } + + fn restore_checkpoint(&mut self, checkpoint: &StorageCheckpoint) -> Result<(), DbError> { + // enforce storage type + if checkpoint.storage_type != StorageType::RocksDb { + return Err(DbError::StorageCheckpointError( + "Invalid storage type".to_string(), + )); + } + // enforce filename marker - should have been enforced during StorageCheckpoint::open anyway + let
checkpoint_filename = checkpoint + .path + .file_name() + .ok_or(DbError::StorageCheckpointError( + "Could not read checkpoint filename".to_string(), + ))? + .to_str() + .ok_or(DbError::StorageCheckpointError( + "Could not read checkpoint filename".to_string(), + ))?; + if !checkpoint_filename.ends_with(".tar.gz") + || !checkpoint_filename.starts_with(ROCKSDB_CHECKPOINT_FILENAME_MARKER) + { + return Err(DbError::StorageCheckpointError( + "Invalid filename".to_string(), + )); + } + + let tar_gz = File::open(&checkpoint.path).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't open rocksdb checkpoint: {}", e)) + })?; + let tar = GzDecoder::new(tar_gz); + let mut archive = Archive::new(tar); + + // remove existing stuff in data path + self.db + .as_ref() + .ok_or(DbError::StorageInitializationError)? + .cancel_all_background_work(true); + // drop db early + self.db = None; + + std::fs::remove_dir_all(&self.path).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't remove existing data: {}", e)) + })?; + + // create new data path + std::fs::create_dir_all(&self.path).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't create data path: {}", e)) + })?; + + archive.unpack(&self.path).map_err(|e| { + DbError::StorageCheckpointError(format!("Couldn't unpack tar.gz archive: {}", e)) + })?; + + // reinitialize db + self.db = Some(Self::initialize_db(&self.path)?); + + Ok(()) + } } #[cfg(test)] @@ -172,7 +323,7 @@ mod tests { #[test] fn test_new_rocksdb_storage() { let (db, temp_dir) = create_test_db(); - assert_eq!(db.get_current_path(), PathBuf::from(temp_dir.path())); + assert_eq!(db.get_current_path(), temp_dir.path()); } #[test] @@ -264,4 +415,37 @@ mod tests { assert_eq!(db.get_payload(id).unwrap(), None); } + + #[test] + fn test_create_and_load_checkpoint() { + let (mut db, temp_dir) = create_test_db(); + + let id1 = Uuid::new_v4(); + let id2 = Uuid::new_v4(); + + let vector = Some(vec![0.1, 0.2, 0.3]); + let payload = 
Some(Payload { + content_type: ContentType::Text, + content: "Test".to_string(), + }); + + assert!( + db.insert_point(id1, vector.clone(), payload.clone()) + .is_ok() + ); + + let checkpoint = db + .checkpoint_at(temp_dir.path()) + .expect("Failed to create checkpoint"); + + assert!( + db.insert_point(id2, vector.clone(), payload.clone()) + .is_ok() + ); + + db.restore_checkpoint(&checkpoint).unwrap(); + + assert!(db.contains_point(id1).unwrap()); + assert!(!db.contains_point(id2).unwrap()); + } } From 44c8db7f41a613825790c818074f8877dc101e86 Mon Sep 17 00:00:00 2001 From: Tanmay Arya Date: Wed, 28 Jan 2026 20:01:45 +0530 Subject: [PATCH 2/4] add(snapshots): implement snapshots for hnsw index --- crates/api/src/lib.rs | 2 +- crates/defs/src/types.rs | 2 +- crates/index/src/hnsw/index.rs | 2 +- crates/index/src/hnsw/mod.rs | 4 +- crates/index/src/hnsw/serialize.rs | 139 +++++++++++++++++++++++++++-- crates/index/src/hnsw/types.rs | 2 + crates/index/src/kd_tree/mod.rs | 2 +- crates/snapshot/src/lib.rs | 3 +- 8 files changed, 143 insertions(+), 13 deletions(-) diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 3395b26..ea42a3b 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -233,7 +233,7 @@ mod tests { let temp_dir = tempdir().unwrap(); let config = DbConfig { storage_type: StorageType::RocksDb, - index_type: IndexType::Flat, + index_type: IndexType::HNSW, data_path: temp_dir.path().to_path_buf(), dimension: 3, similarity: Similarity::Cosine, diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index 03f3f47..f6b7c40 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -47,7 +47,7 @@ pub struct IndexedVector { pub vector: DenseVector, } -#[derive(Debug, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize , Deserialize, Copy, Clone)] pub enum Similarity { Euclidean, Manhattan, diff --git a/crates/index/src/hnsw/index.rs b/crates/index/src/hnsw/index.rs index c781d7e..80aeb1f 100644 --- 
a/crates/index/src/hnsw/index.rs +++ b/crates/index/src/hnsw/index.rs @@ -18,7 +18,7 @@ pub struct HnswIndex { // Default query beam width (ef); recommended ef ≥ k at query time pub ef: usize, // In-memory vector cache owned by the index - cache: HashMap, + pub cache: HashMap, // Fixed metric for this index; used consistently in insert and search pub similarity: Similarity, } diff --git a/crates/index/src/hnsw/mod.rs b/crates/index/src/hnsw/mod.rs index 35ae15f..0cedc17 100644 --- a/crates/index/src/hnsw/mod.rs +++ b/crates/index/src/hnsw/mod.rs @@ -5,8 +5,10 @@ pub mod index; pub mod search; pub mod serialize; pub mod types; - +use defs::Magic; pub use index::HnswIndex; +pub const HNSW_MAGIC_BYTES: Magic = [0x02, 0x01, 0x03, 0x00]; + #[cfg(test)] mod tests; diff --git a/crates/index/src/hnsw/serialize.rs b/crates/index/src/hnsw/serialize.rs index f632c82..b4706ab 100644 --- a/crates/index/src/hnsw/serialize.rs +++ b/crates/index/src/hnsw/serialize.rs @@ -1,21 +1,146 @@ -use defs::DbError; +use std::collections::HashMap; + +use defs::{DbError, Dimension, PointId, Similarity}; +use serde::{Deserialize, Serialize}; use storage::StorageEngine; -use crate::{IndexSnapshot, SerializableIndex, hnsw::HnswIndex}; +use crate::{IndexSnapshot, IndexType, SerializableIndex, hnsw::{HNSW_MAGIC_BYTES, HnswIndex, types::{LevelGenerator, Node, PointIndexation}}}; + +#[repr(packed)] +#[derive(Serialize, Deserialize)] +pub struct HnswMetadataPack{ + pub ef_construction: usize, + pub data_dimension: Dimension, + pub ef: usize, + pub similarity: Similarity, +} + +#[derive(Serialize, Deserialize)] +pub struct HnswIndexPack { + pub max_connections: usize, + pub max_connections_0: usize, + pub max_layer: usize, + pub points_by_layer: Vec>, + pub nodes: Vec, + pub entry_point: Option, + pub level_scale: f64, +} + impl SerializableIndex for HnswIndex { fn serialize_topology(&self) -> Result, DbError> { - return Err(DbError::SerializationError("not implemented".to_string())); + let mut buffer 
= Vec::new(); + + let nodes: Vec<Node> = self.index.nodes.values().cloned().collect(); + let index_pack = HnswIndexPack { + max_connections: self.index.max_connections, + max_connections_0: self.index.max_connections_0, + max_layer: self.index.max_layer, + points_by_layer: self.index.points_by_layer.clone(), + nodes, + entry_point: self.index.entry_point, + level_scale: self.index.level_generator.level_scale, + }; + + let index_bytes = bincode::serialize(&index_pack).map_err(|e| DbError::SerializationError(e.to_string()))?; + buffer.extend(index_bytes); + + return Ok(buffer); } + fn serialize_metadata(&self) -> Result<Vec<u8>, DbError> { - return Err(DbError::SerializationError("not implemented".to_string())); + let mut buffer = Vec::new(); + let index_pack = HnswMetadataPack { + ef_construction: self.ef_construction, + data_dimension: self.data_dimension, + ef: self.ef, + similarity: self.similarity + }; + + let metadata_bytes = bincode::serialize(&index_pack).map_err(|e| DbError::SerializationError(e.to_string()))?; + buffer.extend(metadata_bytes); + return Ok(buffer); } fn snapshot(&self) -> Result<IndexSnapshot, DbError> { - return Err(DbError::SerializationError("not implemented".to_string())); + let topology_bytes = self.serialize_topology()?; + let metadata_bytes = self.serialize_metadata()?; + Ok(IndexSnapshot { + index_type: crate::IndexType::HNSW, + magic: HNSW_MAGIC_BYTES, + topology_b: topology_bytes, + metadata_b: metadata_bytes, + }) + } + + fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { + // assumes index topology is restored + for id in self.index.nodes.keys() { + let vec = storage.get_vector(*id)?.ok_or(DbError::SerializationError(format!("Failed to locate vector for id: {} in storage", id)))?; + self.cache.insert(*id, vec); + } + Ok(()) } +} + + + +impl HnswIndex { + pub fn deserialize( + IndexSnapshot { + index_type, + magic, + topology_b, + metadata_b, + }: &IndexSnapshot, + ) -> Result<Self, DbError> { + if index_type != &IndexType::HNSW { + return
Err(DbError::SerializationError( + "Invalid index type".to_string(), + )); + } + + if magic != &HNSW_MAGIC_BYTES { + return Err(DbError::SerializationError( + "Invalid magic bytes".to_string(), + )); + } + + let metadata: HnswMetadataPack = bincode::deserialize(metadata_b).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize HNSW Metadata: {}", e)) + })?; + + let index_pack : HnswIndexPack = bincode::deserialize(topology_b).map_err(|e| { + DbError::SerializationError(format!("Failed to deserialize HNSW Index: {}", e)) + })?; + + let mut hnsw_index_restored = PointIndexation { + max_connections: index_pack.max_connections, + max_connections_0: index_pack.max_connections_0, + max_layer: index_pack.max_layer, + points_by_layer: index_pack.points_by_layer, + entry_point: index_pack.entry_point, + nodes: HashMap::new(), + level_generator: LevelGenerator{ + level_scale: index_pack.level_scale + } + }; + + // restore nodes hashmap + for i in index_pack.nodes { + hnsw_index_restored.nodes.insert(i.id, i); + } + + + let hnsw = HnswIndex { + ef_construction : metadata.ef_construction, + data_dimension: metadata.data_dimension, + ef: metadata.ef, + cache: HashMap::new(), + similarity: metadata.similarity, + index: hnsw_index_restored + }; - fn populate_vectors(&mut self, _storage: &dyn StorageEngine) -> Result<(), DbError> { - return Err(DbError::SerializationError("not implemented".to_string())); + Ok(hnsw) } } diff --git a/crates/index/src/hnsw/types.rs b/crates/index/src/hnsw/types.rs index c16e04b..c52d4ff 100644 --- a/crates/index/src/hnsw/types.rs +++ b/crates/index/src/hnsw/types.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use defs::PointId; use rand::Rng; +use serde::{Deserialize, Serialize}; // Compact storage for layered points and adjacency used by `HnswIndex`. 
pub struct PointIndexation { @@ -22,6 +23,7 @@ pub struct PointIndexation { } // Node with highest level and per-level neighbor lists +#[derive(Serialize, Deserialize, Clone)] pub struct Node { pub id: PointId, // Highest level (0-based; level 0 is the base layer) diff --git a/crates/index/src/kd_tree/mod.rs b/crates/index/src/kd_tree/mod.rs index 52aa7b9..4a29b61 100644 --- a/crates/index/src/kd_tree/mod.rs +++ b/crates/index/src/kd_tree/mod.rs @@ -8,4 +8,4 @@ pub mod types; #[cfg(test)] mod tests; -pub const KD_TREE_MAGIC_BYTES: Magic = [0x00, 0x00, 0x00, 0x00]; +pub const KD_TREE_MAGIC_BYTES: Magic = [0x00, 0x01, 0x02, 0x00]; diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs index 32311d2..c8392e6 100644 --- a/crates/snapshot/src/lib.rs +++ b/crates/snapshot/src/lib.rs @@ -15,7 +15,7 @@ use chrono::{DateTime, Local}; use defs::DbError; use flate2::read::GzDecoder; use index::{ - IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, kd_tree::index::KDTree, + IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, hnsw::HnswIndex, kd_tree::index::KDTree }; use semver::Version; use std::{ @@ -281,6 +281,7 @@ let vector_index: Arc<RwLock<dyn VectorIndex>> = match manifest.index_type { IndexType::Flat => Arc::new(RwLock::new(FlatIndex::deserialize(&index_snapshot)?)), IndexType::KDTree => Arc::new(RwLock::new(KDTree::deserialize(&index_snapshot)?)), + IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::deserialize(&index_snapshot)?)), _ => return Err(DbError::SnapshotError("Unsupported index type".to_string())), }; From 9f25ffd3dafea5b73b2a06a176ca0c3a1d720bf8 Mon Sep 17 00:00:00 2001 From: Tanmay Arya Date: Wed, 28 Jan 2026 20:27:11 +0530 Subject: [PATCH 3/4] format code --- Cargo.toml | 2 +- crates/api/Cargo.toml | 2 +- crates/defs/src/types.rs | 2 +- crates/index/Cargo.toml | 4 +-- crates/index/src/hnsw/serialize.rs | 41 ++++++++++++++++++------------ crates/snapshot/src/lib.rs | 3 ++- 6 files changed, 32 insertions(+), 22 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml index ca5b979..3c94654 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,6 @@ grpc = { path = "crates/grpc" } http = { path = "crates/http" } index = { path = "crates/index" } server = { path = "crates/server" } -storage = { path = "crates/storage" } snapshot = { path = "crates/snapshot" } +storage = { path = "crates/storage" } tui = { path = "crates/tui" } diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index fe40536..b6e3e74 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true [dependencies] defs.workspace = true index.workspace = true +snapshot.workspace = true storage.workspace = true tempfile.workspace = true uuid.workspace = true -snapshot.workspace = true diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index f6b7c40..525df26 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -47,7 +47,7 @@ pub struct IndexedVector { pub vector: DenseVector, } -#[derive(Debug, Serialize , Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] pub enum Similarity { Euclidean, Manhattan, diff --git a/crates/index/Cargo.toml b/crates/index/Cargo.toml index 1e92124..b776b7a 100644 --- a/crates/index/Cargo.toml +++ b/crates/index/Cargo.toml @@ -7,9 +7,9 @@ edition.workspace = true license.workspace = true [dependencies] +bincode.workspace = true defs.workspace = true rand.workspace = true -uuid.workspace = true -bincode.workspace = true serde.workspace = true storage.workspace = true +uuid.workspace = true diff --git a/crates/index/src/hnsw/serialize.rs b/crates/index/src/hnsw/serialize.rs index b4706ab..153e913 100644 --- a/crates/index/src/hnsw/serialize.rs +++ b/crates/index/src/hnsw/serialize.rs @@ -4,11 +4,17 @@ use defs::{DbError, Dimension, PointId, Similarity}; use serde::{Deserialize, Serialize}; use storage::StorageEngine; -use crate::{IndexSnapshot, IndexType, SerializableIndex, hnsw::{HNSW_MAGIC_BYTES, 
HnswIndex, types::{LevelGenerator, Node, PointIndexation}}}; +use crate::{ + IndexSnapshot, IndexType, SerializableIndex, + hnsw::{ + HNSW_MAGIC_BYTES, HnswIndex, + types::{LevelGenerator, Node, PointIndexation}, + }, +}; #[repr(packed)] #[derive(Serialize, Deserialize)] -pub struct HnswMetadataPack{ +pub struct HnswMetadataPack { pub ef_construction: usize, pub data_dimension: Dimension, pub ef: usize, @@ -26,7 +32,6 @@ pub struct HnswIndexPack { pub level_scale: f64, } - impl SerializableIndex for HnswIndex { fn serialize_topology(&self) -> Result<Vec<u8>, DbError> { let mut buffer = Vec::new(); @@ -42,7 +47,8 @@ impl SerializableIndex for HnswIndex { level_scale: self.index.level_generator.level_scale, }; - let index_bytes = bincode::serialize(&index_pack).map_err(|e| DbError::SerializationError(e.to_string()))?; + let index_bytes = bincode::serialize(&index_pack) + .map_err(|e| DbError::SerializationError(e.to_string()))?; buffer.extend(index_bytes); return Ok(buffer); @@ -54,10 +60,11 @@ impl SerializableIndex for HnswIndex { ef_construction: self.ef_construction, data_dimension: self.data_dimension, ef: self.ef, - similarity: self.similarity + similarity: self.similarity, }; - let metadata_bytes = bincode::serialize(&index_pack).map_err(|e| DbError::SerializationError(e.to_string()))?; + let metadata_bytes = bincode::serialize(&index_pack) + .map_err(|e| DbError::SerializationError(e.to_string()))?; buffer.extend(metadata_bytes); return Ok(buffer); } @@ -76,15 +83,18 @@ impl SerializableIndex for HnswIndex { fn populate_vectors(&mut self, storage: &dyn StorageEngine) -> Result<(), DbError> { // assumes index topology is restored for id in self.index.nodes.keys() { - let vec = storage.get_vector(*id)?.ok_or(DbError::SerializationError(format!("Failed to locate vector for id: {} in storage", id)))?; + let vec = storage + .get_vector(*id)?
+ .ok_or(DbError::SerializationError(format!( + "Failed to locate vector for id: {} in storage", + id + )))?; self.cache.insert(*id, vec); } Ok(()) } } - - impl HnswIndex { pub fn deserialize( IndexSnapshot { @@ -110,7 +120,7 @@ impl HnswIndex { DbError::SerializationError(format!("Failed to deserialize HNSW Metadata: {}", e)) })?; - let index_pack : HnswIndexPack = bincode::deserialize(topology_b).map_err(|e| { + let index_pack: HnswIndexPack = bincode::deserialize(topology_b).map_err(|e| { DbError::SerializationError(format!("Failed to deserialize HNSW Index: {}", e)) })?; @@ -121,9 +131,9 @@ impl HnswIndex { points_by_layer: index_pack.points_by_layer, entry_point: index_pack.entry_point, nodes: HashMap::new(), - level_generator: LevelGenerator{ - level_scale: index_pack.level_scale - } + level_generator: LevelGenerator { + level_scale: index_pack.level_scale, + }, }; // restore nodes hashmap @@ -131,14 +141,13 @@ impl HnswIndex { hnsw_index_restored.nodes.insert(i.id, i); } - let hnsw = HnswIndex { - ef_construction : metadata.ef_construction, + ef_construction: metadata.ef_construction, data_dimension: metadata.data_dimension, ef: metadata.ef, cache: HashMap::new(), similarity: metadata.similarity, - index: hnsw_index_restored + index: hnsw_index_restored, }; Ok(hnsw) diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs index c8392e6..0833057 100644 --- a/crates/snapshot/src/lib.rs +++ b/crates/snapshot/src/lib.rs @@ -15,7 +15,8 @@ use chrono::{DateTime, Local}; use defs::DbError; use flate2::read::GzDecoder; use index::{ - IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, hnsw::HnswIndex, kd_tree::index::KDTree + IndexSnapshot, IndexType, VectorIndex, flat::index::FlatIndex, hnsw::HnswIndex, + kd_tree::index::KDTree, }; use semver::Version; use std::{ From 756766a2232ab72b271b7f715b9e3aefcbeee224 Mon Sep 17 00:00:00 2001 From: Tanmay Arya Date: Wed, 28 Jan 2026 21:39:52 +0530 Subject: [PATCH 4/4] remove repr(packed) because 
bincode already packs structs --- crates/api/src/lib.rs | 4 ++-- crates/index/src/hnsw/serialize.rs | 5 ++--- crates/snapshot/src/lib.rs | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index ea42a3b..2e07719 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -233,7 +233,7 @@ mod tests { let temp_dir = tempdir().unwrap(); let config = DbConfig { storage_type: StorageType::RocksDb, - index_type: IndexType::HNSW, + index_type: IndexType::Flat, data_path: temp_dir.path().to_path_buf(), dimension: 3, similarity: Similarity::Cosine, @@ -359,7 +359,7 @@ mod tests { // Search with limit 3 let query = vec![0.0, 0.0, 0.0]; - let results = db.search(query, Similarity::Cosine, 3).unwrap(); + let results = db.search(query, Similarity::Euclidean, 3).unwrap(); assert_eq!(results.len(), 3); } diff --git a/crates/index/src/hnsw/serialize.rs b/crates/index/src/hnsw/serialize.rs index 153e913..4bf8b43 100644 --- a/crates/index/src/hnsw/serialize.rs +++ b/crates/index/src/hnsw/serialize.rs @@ -12,7 +12,6 @@ use crate::{ }, }; -#[repr(packed)] #[derive(Serialize, Deserialize)] pub struct HnswMetadataPack { pub ef_construction: usize, @@ -51,7 +50,7 @@ impl SerializableIndex for HnswIndex { .map_err(|e| DbError::SerializationError(e.to_string()))?; buffer.extend(index_bytes); - return Ok(buffer); + Ok(buffer) } fn serialize_metadata(&self) -> Result<Vec<u8>, DbError> { @@ -66,7 +65,7 @@ impl SerializableIndex for HnswIndex { let metadata_bytes = bincode::serialize(&index_pack) .map_err(|e| DbError::SerializationError(e.to_string()))?; buffer.extend(metadata_bytes); - return Ok(buffer); + Ok(buffer) } fn snapshot(&self) -> Result<IndexSnapshot, DbError> { diff --git a/crates/snapshot/src/lib.rs b/crates/snapshot/src/lib.rs index 0833057..bb1c7ab 100644 --- a/crates/snapshot/src/lib.rs +++ b/crates/snapshot/src/lib.rs @@ -283,7 +283,6 @@ impl Snapshot { IndexType::Flat => Arc::new(RwLock::new(FlatIndex::deserialize(&index_snapshot)?)),
IndexType::KDTree => Arc::new(RwLock::new(KDTree::deserialize(&index_snapshot)?)), IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::deserialize(&index_snapshot)?)), - _ => return Err(DbError::SnapshotError("Unsupported index type".to_string())), }; vector_index