Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 296 additions & 1 deletion Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"crates/http",
"crates/tui",
"crates/grpc",
"crates/snapshot",
]

[workspace.package]
Expand Down Expand Up @@ -50,5 +51,6 @@ grpc = { path = "crates/grpc" }
http = { path = "crates/http" }
index = { path = "crates/index" }
server = { path = "crates/server" }
snapshot = { path = "crates/snapshot" }
storage = { path = "crates/storage" }
tui = { path = "crates/tui" }
1 change: 1 addition & 0 deletions crates/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license.workspace = true
[dependencies]
defs.workspace = true
index.workspace = true
snapshot.workspace = true
storage.workspace = true
tempfile.workspace = true
uuid.workspace = true
201 changes: 196 additions & 5 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use defs::{DbError, Dimension, IndexedVector, Similarity};

use defs::{DbError, Dimension, IndexedVector, Similarity, SnapshottableDb};
use defs::{DenseVector, Payload, Point, PointId};
use index::hnsw::HnswIndex;
use std::path::PathBuf;
use index::kd_tree::index::KDTree;
use std::path::{Path, PathBuf};
use tempfile::tempdir;
// use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};

use index::flat::FlatIndex;
use index::flat::index::FlatIndex;
use index::{IndexType, VectorIndex};
use snapshot::Snapshot;
use storage::rocks_db::RocksDbStorage;
use storage::{StorageEngine, StorageType, VectorPage};

Expand Down Expand Up @@ -132,6 +134,31 @@ impl VectorDb {
}
}

impl SnapshottableDb for VectorDb {
fn create_snapshot(&self, dir_path: &Path) -> Result<PathBuf, DbError> {
if !dir_path.is_dir() {
return Err(DbError::SnapshotError(format!(
"Invalid path: {}",
dir_path.display()
)));
}

let index_snapshot = self
.index
.read()
.map_err(|_| DbError::LockError)?
.snapshot()?;

let tempdir = tempdir().unwrap();
let storage_checkpoint = self.storage.checkpoint_at(tempdir.path())?;

let snapshot = Snapshot::new(index_snapshot, storage_checkpoint, self.dimension)?;
let snapshot_path = snapshot.save(dir_path)?;

Ok(snapshot_path)
}
}

#[derive(Debug)]
pub struct DbConfig {
pub storage_type: StorageType,
Expand All @@ -141,6 +168,28 @@ pub struct DbConfig {
pub similarity: Similarity,
}

#[derive(Debug)]
pub struct DbRestoreConfig {
pub data_path: PathBuf,
pub snapshot_path: PathBuf,
}

impl DbRestoreConfig {
pub fn new(data_path: PathBuf, snapshot_path: PathBuf) -> Self {
Self {
data_path,
snapshot_path,
}
}
}

pub fn restore_from_snapshot(config: &DbRestoreConfig) -> Result<VectorDb, DbError> {
// restore the index from the snapshot
let (storage_engine, index, dimensions) =
Snapshot::load(&config.snapshot_path, &config.data_path)?;
Ok(VectorDb::_new(storage_engine, index, dimensions))
}

pub fn init_api(config: DbConfig) -> Result<VectorDb, DbError> {
// Initialize the storage engine
let storage = match config.storage_type {
Expand All @@ -151,11 +200,11 @@ pub fn init_api(config: DbConfig) -> Result<VectorDb, DbError> {
// Initialize the vector index
let index: Arc<RwLock<dyn VectorIndex>> = match config.index_type {
IndexType::Flat => Arc::new(RwLock::new(FlatIndex::new())),
IndexType::KDTree => Arc::new(RwLock::new(KDTree::build_empty(config.dimension))),
IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::new(
config.similarity,
config.dimension,
))),
_ => Arc::new(RwLock::new(FlatIndex::new())),
};

// Init the db
Expand All @@ -172,8 +221,11 @@ mod tests {

// TODO: Add more exhaustive tests

use std::sync::Mutex;

use super::*;
use defs::ContentType;
use snapshot::{engine::SnapshotEngine, registry::local::LocalRegistry};
use tempfile::{TempDir, tempdir};

// Helper function to create a test database
Expand Down Expand Up @@ -377,4 +429,143 @@ mod tests {
let inserted = db.build_index().unwrap();
assert_eq!(inserted, 10);
}

#[test]
fn test_create_and_load_snapshot() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these tests give an overview of how the snapshots api is to be used

let (old_db, temp_dir) = create_test_db();

let v1 = vec![0.0, 1.0, 2.0];
let v2 = vec![3.0, 4.0, 5.0];
let v3 = vec![6.0, 7.0, 8.0];

let id1 = old_db
.insert(
v1.clone(),
Payload {
content_type: ContentType::Text,
content: "test".to_string(),
},
)
.unwrap();

let id2 = old_db
.insert(
v2.clone(),
Payload {
content_type: ContentType::Text,
content: "test".to_string(),
},
)
.unwrap();

let temp_snapshot_dir = tempdir().unwrap();
let snapshot_path = old_db.create_snapshot(temp_snapshot_dir.path()).unwrap();

// insert v3 after snapshot
let id3 = old_db
.insert(
v3.clone(),
Payload {
content_type: ContentType::Text,
content: "test".to_string(),
},
)
.unwrap();

let reload_config = DbRestoreConfig {
data_path: temp_dir.path().to_path_buf(),
snapshot_path,
};

std::mem::drop(old_db);
let loaded_db = restore_from_snapshot(&reload_config).unwrap();

assert!(loaded_db.get(id1).unwrap_or(None).is_some());
assert!(loaded_db.get(id2).unwrap_or(None).is_some());
assert!(loaded_db.get(id3).unwrap_or(None).is_none()); // v3 was inserted after snapshot was taken

// vector restore check
assert!(loaded_db.get(id1).unwrap().unwrap().vector.unwrap() == v1);
assert!(loaded_db.get(id2).unwrap().unwrap().vector.unwrap() == v2);
}

#[test]
fn test_snapshot_engine() {
let (_db, _temp_dir) = create_test_db();
let db = Arc::new(Mutex::new(_db));

let registry_tempdir = tempdir().unwrap();

let registry = Arc::new(Mutex::new(
LocalRegistry::new(registry_tempdir.path()).unwrap(),
));

let last_k = 4;
let mut se = SnapshotEngine::new(last_k, db.clone(), registry.clone());

let v1 = vec![0.0, 1.0, 2.0];
let v2 = vec![3.0, 4.0, 5.0];
let v3 = vec![6.0, 7.0, 8.0];

let test_vectors = vec![v1.clone(), v2.clone(), v3.clone()];
let mut inserted_ids = Vec::new();

for (i, vector) in test_vectors.clone().into_iter().enumerate() {
se.snapshot().unwrap();
let id = db
.lock()
.unwrap()
.insert(
vector.clone(),
Payload {
content_type: ContentType::Text,
content: format!("{}", i),
},
)
.unwrap();
inserted_ids.push(id);
}
se.snapshot().unwrap();
let snapshots = se.list_alive_snapshots().unwrap();

// asserting these cases:
// snapshot 0 : no vectors
// snapshot 1 : v1
// snapshot 2 : v1, v2
// snapshot 3 : v1, v2, v3

std::mem::drop(db);
std::mem::drop(se);

for (i, snapshot) in snapshots.iter().enumerate() {
let temp_dir = tempdir().unwrap();
let db = restore_from_snapshot(&DbRestoreConfig {
data_path: temp_dir.path().to_path_buf(),
snapshot_path: snapshot.path.clone(),
})
.unwrap();
for j in 0..i {
// test if point is present
assert!(db.get(inserted_ids[j]).unwrap_or(None).is_some());
// test vector restore
assert!(
db.get(inserted_ids[j]).unwrap().unwrap().vector.unwrap() == test_vectors[j]
);
// test payload restore
assert!(
db.get(inserted_ids[j])
.unwrap()
.unwrap()
.payload
.unwrap()
.content
== format!("{}", j)
);
}
for absent_id in inserted_ids.iter().skip(i) {
assert!(db.get(*absent_id).unwrap_or(None).is_none());
}
std::mem::drop(db);
}
}
}
7 changes: 7 additions & 0 deletions crates/defs/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ pub enum DbError {
IndexInitError, //TODO: Change this
UnsupportedSimilarity,
DimensionMismatch,
SnapshotError(String),
StorageInitializationError,
StorageCheckpointError(String),
InvalidMagicBytes(String),
VectorNotFound(uuid::Uuid),
SnapshotRegistryError(String),
StorageEngineError(String),
InvalidDimension { expected: Dimension, got: Dimension },
PointAlreadyExists { id: PointId },
PointNotFound { id: PointId },
Expand Down
6 changes: 6 additions & 0 deletions crates/defs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,10 @@ pub mod types;

// Without re-exports, users would need to write defs::types::SomeType instead of just defs::SomeType. Re-exports simplify the API by flattening the module hierarchy. The * means "everything public" from that module.
pub use error::*;
use std::path::{Path, PathBuf};
pub use types::*;

// hoisted trait so it can be used by the snapshots crate
pub trait SnapshottableDb: Send + Sync {
fn create_snapshot(&self, dir_path: &Path) -> Result<PathBuf, DbError>;
}
4 changes: 3 additions & 1 deletion crates/defs/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub type Dimension = usize;
// Sparse vector implementation not supported yet. Refer lib/sparse/src/common/sparse_vector.rs
pub type DenseVector = Vec<Element>;

pub type Magic = [u8; 4];

pub enum StoredVector {
Dense(DenseVector),
}
Expand Down Expand Up @@ -45,7 +47,7 @@ pub struct IndexedVector {
pub vector: DenseVector,
}

#[derive(Debug, Deserialize, Copy, Clone)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub enum Similarity {
Euclidean,
Manhattan,
Expand Down
3 changes: 3 additions & 0 deletions crates/index/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ edition.workspace = true
license.workspace = true

[dependencies]
bincode.workspace = true
defs.workspace = true
rand.workspace = true
serde.workspace = true
storage.workspace = true
uuid.workspace = true
Loading
Loading