diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index de6802b..7fba1d9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -21,6 +21,9 @@ jobs: with: components: clippy, llvm-tools-preview + - name: Install Protobuf Compiler + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler + - name: Cache dependencies uses: Swatinem/rust-cache@v2 diff --git a/Cargo.lock b/Cargo.lock index f38271e..ebf8689 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,6 +97,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" + [[package]] name = "assert_cmd" version = "2.0.16" @@ -164,6 +170,9 @@ dependencies = [ "futures", "once_cell", "predicates", + "prost", + "prost-build", + "prost-types", "rand", "rmp-serde", "serde", @@ -522,6 +531,16 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "eyre" version = "0.6.12" @@ -532,6 +551,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fixedbitset" version = "0.4.2" @@ -994,6 +1019,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" + [[package]] name = "litemap" version = "0.7.5" @@ -1051,6 +1082,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "multimap" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -1239,6 +1276,16 @@ dependencies = [ "termtree", ] +[[package]] +name = "prettyplease" +version = "0.2.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.94" @@ -1248,6 +1295,58 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +dependencies = [ + "heck", + "itertools", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +dependencies = [ + "prost", +] + [[package]] name = "quote" version = "1.0.40" @@ -1394,6 +1493,19 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustix" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "rustversion" version = "1.0.20" @@ -1604,6 +1716,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "termtree" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 41d44c2..3207ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ color-eyre = "0.6.3" dashmap = "6.1.0" futures = "0.3.31" once_cell = "1.21.3" +prost = "0.13.5" +prost-types = "0.13.5" rand = "0.9.0" rmp-serde = "1.3.0" serde = { version = "1.0.219", features = ["serde_derive"] } @@ -48,6 +50,7 @@ built = { version = "0.7", features = [ "chrono", "semver", ] } +prost-build = "0.13.5" # Workspace configuration [lib] diff --git a/benches/bonka_benchmark.rs b/benches/bonka_benchmark.rs index e17930c..2d8aa5b 100644 --- a/benches/bonka_benchmark.rs +++ b/benches/bonka_benchmark.rs @@ -1,12 +1,13 @@ use bonka::kv::KeyValueStore; use bonka::kv::Value; -use bonka::protocol::{Command, Request, Response}; +use bonka::proto; +use bonka::proto::bonka::{CommandType, ResultType}; use bytes::Bytes; use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use futures::future::join_all; use futures::{SinkExt, StreamExt}; +use prost::Message; use rand::{Rng, distr::Alphanumeric}; -use serde::Serialize; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::net::TcpStream; use tokio_util::codec::{Framed, LengthDelimitedCodec}; @@ -114,20 +115,24 @@ fn bench_protocol(c: &mut Criterion) { // Benchmark request serialization group.bench_function("serialize_request", |b| { b.iter(|| { - let request = Request { + // Create a proto Value + let proto_value = proto::bonka::Value { + value: Some(proto::bonka::value::Value::StringValue( + "test-value".to_string(), + )), + }; + + let request = proto::bonka::Request { id: Some(1), timestamp: get_timestamp(), - command: Command::Set( - "test-key".to_string(), - Value::String("test-value".to_string()), - ), - metadata: None, + command_type: CommandType::CommandSet as i32, + key: Some("test-key".to_string()), + value: Some(proto_value), + metadata: Default::default(), }; - let mut buf = Vec::new(); - request - .serialize(&mut rmp_serde::Serializer::new(&mut buf)) - .unwrap(); + // Serialize using Protocol Buffers + let buf = request.encode_to_vec(); black_box(buf); }) }); @@ -135,39 +140,51 @@ fn bench_protocol(c: &mut Criterion) { // Benchmark response serialization group.bench_function("serialize_response", |b| { b.iter(|| { - let response = Response { + // Create a proto Value + let proto_value = proto::bonka::Value { + value: Some(proto::bonka::value::Value::StringValue( + "test-value".to_string(), + )), + }; + + let response = proto::bonka::Response { id: Some(1), timestamp: get_timestamp(), - result: bonka::protocol::Result::Value(Some(Value::String( - "test-value".to_string(), - ))), - metadata: None, + result_type: ResultType::ResultValue as i32, + value: Some(proto_value), + keys: vec![], + error: None, + metadata: Default::default(), }; - let mut buf = Vec::new(); - response - .serialize(&mut rmp_serde::Serializer::new(&mut buf)) - .unwrap(); + // Serialize using Protocol Buffers + let buf = response.encode_to_vec(); black_box(buf); }) }); // Create a sample serialized request for deserialization benchmark - let sample_request = Request { + let proto_value = proto::bonka::Value { + value: Some(proto::bonka::value::Value::StringValue( + "test-value".to_string(), + )), + }; + + let sample_request = proto::bonka::Request { id: Some(1), timestamp: get_timestamp(), - command: Command::Get("test-key".to_string()), - metadata: None, + command_type: CommandType::CommandGet as i32, + key: Some("test-key".to_string()), + value: Some(proto_value), + metadata: Default::default(), }; - let mut request_buf = Vec::new(); - sample_request - .serialize(&mut rmp_serde::Serializer::new(&mut request_buf)) - .unwrap(); + + let request_buf = sample_request.encode_to_vec(); // Benchmark request deserialization group.bench_function("deserialize_request", |b| { b.iter(|| { - let request: Request = rmp_serde::from_slice(&request_buf).unwrap(); + let request = proto::bonka::Request::decode(request_buf.as_slice()).unwrap(); black_box(request); }) }); @@ -215,24 +232,27 @@ async fn server_benchmark() -> Result> { let key = format!("client{}-key{}", client_id, i); let value = format!("value{}-{}", client_id, i); - // Create a Set command - let request = Request { + // Create a Set command with protobuf + let proto_value = proto::bonka::Value { + value: Some(proto::bonka::value::Value::StringValue(value)), + }; + + let request = proto::bonka::Request { id: Some((client_id * ops_per_client + i) as u64), timestamp: get_timestamp(), - command: Command::Set(key.clone(), Value::String(value)), - metadata: None, + command_type: CommandType::CommandSet as i32, + key: Some(key), + value: Some(proto_value), + metadata: Default::default(), }; // Serialize and send - let mut buf = Vec::new(); - request - .serialize(&mut rmp_serde::Serializer::new(&mut buf)) - .unwrap(); + let buf = request.encode_to_vec(); framed.send(Bytes::from(buf)).await.unwrap(); // Receive response let bytes = framed.next().await.unwrap().unwrap(); - let _response: Response = rmp_serde::from_slice(&bytes).unwrap(); + let _response = proto::bonka::Response::decode(bytes.as_ref()).unwrap(); } }) }) diff --git a/build.rs b/build.rs index 4306bb7..9bee755 100644 --- a/build.rs +++ b/build.rs @@ -1,3 +1,8 @@ fn main() { - built::write_built_file().expect("Failed to acquire build-time information") + built::write_built_file().expect("Failed to acquire build-time information"); + + // Compile protocol buffers + println!("cargo:rerun-if-changed=protocol/bonka.proto"); + prost_build::compile_protos(&["protocol/bonka.proto"], &["protocol"]) + .expect("Failed to compile protobuf definitions"); } diff --git a/protocol/bonka.proto b/protocol/bonka.proto new file mode 100644 index 0000000..b9d55a5 --- /dev/null +++ b/protocol/bonka.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; + +package bonka; + +// Value types that can be stored +message Value { + oneof value { + string string_value = 1; + bytes bytes_value = 2; + int64 int_value = 3; + uint64 uint_value = 4; + double float_value = 5; + bool bool_value = 6; + // Null is represented by not setting any value + } +} + +// Commands that can be executed +enum CommandType { + COMMAND_UNSPECIFIED = 0; // Default value + COMMAND_GET = 1; + COMMAND_SET = 2; + COMMAND_DELETE = 3; + COMMAND_LIST = 4; + COMMAND_EXIT = 5; +} + +// Request message +message Request { + optional uint64 id = 1; + uint64 timestamp = 2; + + // Command details + CommandType command_type = 3; + optional string key = 4; + optional Value value = 5; + + // Optional metadata as key-value pairs + map metadata = 6; +} + +// Response result types +enum ResultType { + RESULT_UNSPECIFIED = 0; // Default value + RESULT_VALUE = 1; + RESULT_SUCCESS = 2; + RESULT_KEYS = 3; + RESULT_ERROR = 4; + RESULT_EXIT = 5; +} + +// Response message +message Response { + optional uint64 id = 1; + uint64 timestamp = 2; + + // Result details + ResultType result_type = 3; + optional Value value = 4; + repeated string keys = 5; // For LIST command + optional string error = 6; // For ERROR result + + // Optional metadata + map metadata = 7; +} \ No newline at end of file diff --git a/src/kv.rs b/src/kv.rs index 93e3ab0..07990e0 100644 --- a/src/kv.rs +++ b/src/kv.rs @@ -1,6 +1,8 @@ use dashmap::DashMap; use serde::{Deserialize, Serialize}; +use crate::proto; + /// The `Value` enum represents the different types of values that can be stored in the key-value store. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Value { @@ -13,6 +15,48 @@ pub enum Value { Null, } +/// Convert the `Value` enum into the protobuf format. +impl From for proto::bonka::Value { + fn from(value: Value) -> Self { + match value { + Value::String(s) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::StringValue(s)), + }, + Value::Bytes(b) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::BytesValue(b.to_vec())), + }, + Value::Int(i) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::IntValue(i)), + }, + Value::UInt(u) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::UintValue(u)), + }, + Value::Float(f) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::FloatValue(f)), + }, + Value::Bool(b) => proto::bonka::Value { + value: Some(proto::bonka::value::Value::BoolValue(b)), + }, + Value::Null => proto::bonka::Value { value: None }, + } + } +} + +/// Convert the protobuf format into the `Value` enum. +impl From for Value { + fn from(value: proto::bonka::Value) -> Self { + match value.value { + Some(proto::bonka::value::Value::StringValue(s)) => Value::String(s), + Some(proto::bonka::value::Value::BytesValue(b)) => Value::Bytes(b.into_boxed_slice()), + Some(proto::bonka::value::Value::IntValue(i)) => Value::Int(i), + Some(proto::bonka::value::Value::UintValue(u)) => Value::UInt(u), + Some(proto::bonka::value::Value::FloatValue(f)) => Value::Float(f), + Some(proto::bonka::value::Value::BoolValue(b)) => Value::Bool(b), + None => Value::Null, + } + } +} + /// Simple key-value store using [`DashMap`](https://docs.rs/dashmap/6.1.0/dashmap/). pub struct KeyValueStore { data: DashMap, diff --git a/src/lib.rs b/src/lib.rs index c2ee241..2063078 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,6 @@ pub mod cli; pub mod constants; pub mod kv; pub mod log; -pub mod protocol; +pub mod proto; pub mod server; pub mod session; diff --git a/src/proto.rs b/src/proto.rs new file mode 100644 index 0000000..45af77d --- /dev/null +++ b/src/proto.rs @@ -0,0 +1,3 @@ +pub mod bonka { + include!(concat!(env!("OUT_DIR"), "/bonka.rs")); +} diff --git a/src/protocol.rs b/src/protocol.rs deleted file mode 100644 index ffbfdb3..0000000 --- a/src/protocol.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::collections::HashMap; - -use serde::{Deserialize, Serialize}; - -use crate::kv::Value; - -pub type Id = u64; -pub type Timestamp = u64; - -/// The `Command` enum represents the different commands that can be executed against the key-value store. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum Command { - Get(String), - Set(String, Value), - Delete(String), - List, - Exit, -} - -/// The `Result` enum represents the different types of results that can be returned from executing a command. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum Result { - Value(Option), - Success, - Keys(Vec), - Error(String), - Exit, -} - -/// The `Metadata` type is a key-value store for additional information associated with a request or response. -pub type Metadata = HashMap; - -/// The `Request` struct represents a request to the key-value store. -/// It contains an optional ID (to correlate with the request), a timestamp, the result of the command execution, -/// and optional metadata. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Request { - pub id: Option, - pub timestamp: Timestamp, - pub command: Command, - pub metadata: Option, -} - -/// The `Response` struct represents a response from the key-value store. -/// It contains an optional ID (to correlate with the request), a timestamp, the result of the command execution, -/// and optional metadata. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Response { - pub id: Option, - pub timestamp: Timestamp, - pub result: Result, - pub metadata: Option, -} - -#[cfg(test)] -mod tests { - use super::*; - use rmp_serde::{Deserializer, Serializer}; - use std::io::Cursor; - - #[test] - fn msgpack_serde() { - let request = Request { - id: Some(1), - timestamp: 1234567890, - command: Command::Set("key".to_string(), Value::String("value".to_string())), - metadata: Some({ - let mut meta = HashMap::new(); - meta.insert("author".to_string(), Value::String("bonka".to_string())); - meta.insert("version".to_string(), Value::UInt(1)); - meta - }), - }; - - // Serialize the request - let mut buf = Vec::new(); - request.serialize(&mut Serializer::new(&mut buf)).unwrap(); - - // Deserialize the request - let mut de = Deserializer::new(Cursor::new(buf)); - let deserialized_request: Request = Deserialize::deserialize(&mut de).unwrap(); - - assert_eq!(request, deserialized_request); - } -} diff --git a/src/server.rs b/src/server.rs index fd0080d..9f28690 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,27 +1,23 @@ use bytes::Bytes; use color_eyre::eyre::{self, Report}; use futures::{SinkExt, StreamExt}; -use serde::Serialize; +use prost::Message; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::net::{TcpListener, TcpStream}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; -// Import your session manager and protocol messages -use crate::kv::KeyValueStore; +use crate::kv::{self, KeyValueStore}; use crate::log; +use crate::proto::bonka::{CommandType, Request, Response, ResultType}; use crate::session::SessionManager; -// Import protocol messages -use crate::protocol::{Command, Request, Response, Result as ProtocolResult}; - -// Server state struct ServerState { session_manager: SessionManager, kv_store: KeyValueStore, } -// Get current timestamp +#[inline(always)] fn get_timestamp() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -29,6 +25,107 @@ fn get_timestamp() -> u64 { .as_secs() } +// =============================================== +// Helpers for creating responses / handling commands +// =============================================== + +#[inline(always)] +fn create_success_response(id: Option) -> Response { + Response { + id, + timestamp: get_timestamp(), + result_type: ResultType::ResultSuccess as i32, + ..Default::default() + } +} + +#[inline(always)] +fn create_value_response(id: Option, value: Option) -> Response { + Response { + id, + timestamp: get_timestamp(), + result_type: ResultType::ResultValue as i32, + value: value.map(|v| v.into()), + ..Default::default() + } +} + +#[inline(always)] +fn create_error_response(id: Option, message: String) -> Response { + Response { + id, + timestamp: get_timestamp(), + result_type: ResultType::ResultError as i32, + error: Some(message), + ..Default::default() + } +} + +#[inline(always)] +fn create_keys_response(id: Option, keys: Vec) -> Response { + Response { + id, + timestamp: get_timestamp(), + result_type: ResultType::ResultKeys as i32, + keys, + ..Default::default() + } +} + +#[inline(always)] +fn create_exit_response(id: Option) -> Response { + Response { + id, + timestamp: get_timestamp(), + result_type: ResultType::ResultExit as i32, + ..Default::default() + } +} + +#[inline(always)] +fn handle_get_command(request: &Request, kv_store: &KeyValueStore) -> Response { + match &request.key { + Some(key) => create_value_response(request.id, kv_store.get(key)), + None => create_error_response(request.id, "Key not provided".to_string()), + } +} + +#[inline(always)] +fn handle_set_command(request: &Request, kv_store: &KeyValueStore) -> Response { + match (&request.key, &request.value) { + (Some(key), Some(value)) => { + kv_store.set(key.clone(), value.clone().into()); + create_success_response(request.id) + } + (None, _) => create_error_response(request.id, "Key not provided".to_string()), + (_, None) => create_error_response(request.id, "Value not provided".to_string()), + } +} + +#[inline(always)] +fn handle_delete_command(request: &Request, kv_store: &KeyValueStore) -> Response { + match &request.key { + Some(key) => { + if kv_store.delete(key) { + create_success_response(request.id) + } else { + create_error_response(request.id, format!("Key '{}' not found", key)) + } + } + None => create_error_response(request.id, "Key not provided".to_string()), + } +} + +#[inline(always)] +fn handle_list_command(request: &Request, kv_store: &KeyValueStore) -> Response { + let keys = kv_store.list(); + create_keys_response(request.id, keys) +} + +// =============================================== +// End response helpers +// =============================================== + /// Run the server pub async fn run(host: impl Into, port: u16) -> Result<(), Report> { // Get the address to bind to @@ -105,7 +202,7 @@ async fn handle_client( match result { Ok(bytes) => { // Deserialize request using MessagePack - let request: Request = match rmp_serde::from_slice(&bytes) { + let request: Request = match Request::decode(bytes.as_ref()) { Ok(req) => req, Err(e) => { log::error!("Failed to deserialize request: {}", e); @@ -114,8 +211,9 @@ async fn handle_client( let error_response = Response { id: None, timestamp: get_timestamp(), - result: ProtocolResult::Error("Invalid request format".to_string()), - metadata: None, + result_type: ResultType::ResultError as i32, + error: Some("Invalid request format".to_string()), + ..Default::default() }; send_response(&mut framed, &error_response).await?; @@ -136,7 +234,7 @@ async fn handle_client( send_response(&mut framed, &response).await?; // Check if client is exiting - if matches!(response.result, ProtocolResult::Exit) { + if response.result_type() == ResultType::ResultExit { log::info!("Client {} requested exit", addr); break; } @@ -165,34 +263,13 @@ async fn process_command(request: Request, state: &Arc>) -> R let server_state = state.lock().unwrap(); let kv_store = &server_state.kv_store; - let result = match request.command { - Command::Get(key) => { - let value = kv_store.get(&key); - ProtocolResult::Value(value) - } - Command::Set(key, value) => { - kv_store.set(key, value); - ProtocolResult::Success - } - Command::Delete(key) => { - if kv_store.delete(&key) { - ProtocolResult::Success - } else { - ProtocolResult::Error(format!("Key '{}' not found", key)) - } - } - Command::List => { - let keys = kv_store.list(); - ProtocolResult::Keys(keys) - } - Command::Exit => ProtocolResult::Exit, - }; - - Response { - id: request.id, // Echo back the request ID for correlation - timestamp: get_timestamp(), - result, - metadata: None, // We could add server metadata here if needed + match request.command_type() { + CommandType::CommandGet => handle_get_command(&request, kv_store), + CommandType::CommandSet => handle_set_command(&request, kv_store), + CommandType::CommandDelete => handle_delete_command(&request, kv_store), + CommandType::CommandList => handle_list_command(&request, kv_store), + CommandType::CommandExit => create_exit_response(request.id), + _ => create_error_response(request.id, "Unknown command".to_string()), } } @@ -203,20 +280,19 @@ async fn send_response( framed: &mut Framed, response: &Response, ) -> Result<(), Box> { - // Serialize response using MessagePack - let mut buf = Vec::new(); - response.serialize(&mut rmp_serde::Serializer::new(&mut buf))?; - + // Serialize response using protobuf + let encoded = response.encode_to_vec(); // Send the response - framed.send(Bytes::from(buf)).await?; + framed.send(Bytes::from(encoded)).await?; Ok(()) } #[cfg(test)] mod tests { use super::*; - - use crate::kv::Value; + use crate::kv; + use crate::proto::bonka::{self, CommandType, ResultType}; + use prost::Message; // Test server setup and teardown struct TestServer { @@ -278,27 +354,33 @@ mod tests { } } + // Helper function to create a protobuf Value from a kv::Value + fn create_proto_value(value: kv::Value) -> bonka::Value { + value.into() + } + // Helper function to send a command and get response async fn send_command( framed: &mut Framed, - command: Command, - ) -> Response { - let request = Request { - id: None, + command_type: CommandType, + key: Option, + value: Option, + ) -> bonka::Response { + let request = bonka::Request { + id: Some(1), // Use a test ID timestamp: get_timestamp(), - command, - metadata: None, + command_type: command_type as i32, + key, + value, + metadata: Default::default(), }; - // Serialize request - let mut buf = Vec::new(); - request - .serialize(&mut rmp_serde::Serializer::new(&mut buf)) - .expect("Failed to serialize request"); + // Serialize request using protobuf + let encoded = request.encode_to_vec(); // Send request framed - .send(Bytes::from(buf)) + .send(Bytes::from(encoded)) .await .expect("Failed to send request"); @@ -310,7 +392,7 @@ mod tests { .expect("Failed to receive response"); // Deserialize response - rmp_serde::from_slice(&bytes).expect("Failed to deserialize response") + bonka::Response::decode(bytes.as_ref()).expect("Failed to deserialize response") } #[tokio::test] @@ -322,24 +404,39 @@ mod tests { let mut client = connect_client(&server.host, server.port).await; // Set a key - let set_cmd = Command::Set( - "test-key".to_string(), - Value::String("test-value".to_string()), - ); - let set_response = send_command(&mut client, set_cmd).await; + let set_response = send_command( + &mut client, + CommandType::CommandSet, + Some("test-key".to_string()), + Some(create_proto_value(kv::Value::String( + "test-value".to_string(), + ))), + ) + .await; // Check set was successful - assert!(matches!(set_response.result, ProtocolResult::Success)); + assert_eq!(set_response.result_type(), ResultType::ResultSuccess); // Get the key - let get_cmd = Command::Get("test-key".to_string()); - let get_response = send_command(&mut client, get_cmd).await; + let get_response = send_command( + &mut client, + CommandType::CommandGet, + Some("test-key".to_string()), + None, + ) + .await; // Check the value is correct - if let ProtocolResult::Value(Some(Value::String(value))) = get_response.result { + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_some()); + let proto_value = get_response.value.unwrap(); + + // Convert to a kv::Value and check + let kv_value: kv::Value = proto_value.into(); + if let kv::Value::String(value) = kv_value { assert_eq!(value, "test-value"); } else { - panic!("Expected Value::String, got {:?}", get_response.result); + panic!("Expected String value, got {:?}", kv_value); } // Clean up @@ -355,29 +452,38 @@ mod tests { let mut client = connect_client(&server.host, server.port).await; // Set a key - let set_cmd = Command::Set("delete-key".to_string(), Value::String("value".to_string())); - let _ = send_command(&mut client, set_cmd).await; + let _ = send_command( + &mut client, + CommandType::CommandSet, + Some("delete-key".to_string()), + Some(create_proto_value(kv::Value::String("value".to_string()))), + ) + .await; // Delete the key - let delete_cmd = Command::Delete("delete-key".to_string()); - let delete_response = send_command(&mut client, delete_cmd).await; + let delete_response = send_command( + &mut client, + CommandType::CommandDelete, + Some("delete-key".to_string()), + None, + ) + .await; // Check delete was successful - assert!(matches!(delete_response.result, ProtocolResult::Success)); + assert_eq!(delete_response.result_type(), ResultType::ResultSuccess); // Try to get the deleted key - let get_cmd = Command::Get("delete-key".to_string()); - let get_response = send_command(&mut client, get_cmd).await; + let get_response = send_command( + &mut client, + CommandType::CommandGet, + Some("delete-key".to_string()), + None, + ) + .await; // Key should not exist - if let ProtocolResult::Value(value) = get_response.result { - assert!(value.is_none()); - } else { - panic!( - "Expected ProtocolResult::Value(None), got {:?}", - get_response.result - ); - } + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_none()); // Clean up server.stop().await; @@ -392,18 +498,18 @@ mod tests { let mut client = connect_client(&server.host, server.port).await; // Try to delete a non-existent key - let delete_cmd = Command::Delete("nonexistent-key".to_string()); - let delete_response = send_command(&mut client, delete_cmd).await; + let delete_response = send_command( + &mut client, + CommandType::CommandDelete, + Some("nonexistent-key".to_string()), + None, + ) + .await; // Should get an error - if let ProtocolResult::Error(err) = delete_response.result { - assert!(err.contains("not found")); - } else { - panic!( - "Expected ProtocolResult::Error, got {:?}", - delete_response.result - ); - } + assert_eq!(delete_response.result_type(), ResultType::ResultError); + assert!(delete_response.error.is_some()); + assert!(delete_response.error.unwrap().contains("not found")); // Clean up server.stop().await; @@ -420,28 +526,31 @@ mod tests { // Add multiple keys let keys = vec!["key1", "key2", "key3"]; for key in &keys { - let set_cmd = Command::Set(key.to_string(), Value::String(format!("value-{}", key))); - let _ = send_command(&mut client, set_cmd).await; + let _ = send_command( + &mut client, + CommandType::CommandSet, + Some(key.to_string()), + Some(create_proto_value(kv::Value::String(format!( + "value-{}", + key + )))), + ) + .await; } // List all keys - let list_cmd = Command::List; - let list_response = send_command(&mut client, list_cmd).await; + let list_response = send_command(&mut client, CommandType::CommandList, None, None).await; // Check that all our keys are listed - if let ProtocolResult::Keys(response_keys) = list_response.result { - // Convert Vec to Vec<&str> for easier comparison - let response_keys_str: Vec<&str> = response_keys.iter().map(|s| s.as_str()).collect(); + assert_eq!(list_response.result_type(), ResultType::ResultKeys); + assert!(!list_response.keys.is_empty()); - // Check each key is present - for key in keys { - assert!(response_keys_str.contains(&key)); - } - } else { - panic!( - "Expected ProtocolResult::Keys, got {:?}", - list_response.result - ); + // Convert Vec to Vec<&str> for easier comparison + let response_keys_str: Vec<&str> = list_response.keys.iter().map(|s| s.as_str()).collect(); + + // Check each key is present + for key in keys { + assert!(response_keys_str.contains(&key)); } // Clean up @@ -456,36 +565,46 @@ mod tests { // Connect client let mut client = connect_client(&server.host, server.port).await; - // Test different value types + // Setup test cases with KV values let test_values = vec![ - ("string-key", Value::String("string-value".to_string())), - ("int-key", Value::Int(42)), - ("float-key", Value::Float(3.14)), - ("bool-key", Value::Bool(true)), - ("null-key", Value::Null), - // You could add more complex types like arrays and maps if your Value enum supports them + ("string-key", kv::Value::String("string-value".to_string())), + ("int-key", kv::Value::Int(42)), + ("float-key", kv::Value::Float(1.337)), + ("bool-key", kv::Value::Bool(true)), + ("null-key", kv::Value::Null), ]; // Set each value for (key, value) in &test_values { - let set_cmd = Command::Set(key.to_string(), value.clone()); - let set_response = send_command(&mut client, set_cmd).await; - assert!(matches!(set_response.result, ProtocolResult::Success)); + let proto_value = create_proto_value(value.clone()); + + let set_response = send_command( + &mut client, + CommandType::CommandSet, + Some(key.to_string()), + Some(proto_value), + ) + .await; + + assert_eq!(set_response.result_type(), ResultType::ResultSuccess); } // Get and verify each value for (key, expected_value) in test_values { - let get_cmd = Command::Get(key.to_string()); - let get_response = send_command(&mut client, get_cmd).await; - - if let ProtocolResult::Value(Some(value)) = get_response.result { - assert_eq!(value, expected_value); - } else { - panic!( - "Expected value for key {}, got {:?}", - key, get_response.result - ); - } + let get_response = send_command( + &mut client, + CommandType::CommandGet, + Some(key.to_string()), + None, + ) + .await; + + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_some()); + + // Convert to a kv::Value and compare + let kv_value: kv::Value = get_response.value.unwrap().into(); + assert_eq!(kv_value, expected_value); } // Clean up @@ -516,20 +635,29 @@ mod tests { let value = format!("value{}-{}", client_id, i); // Set a key - let set_cmd = Command::Set(key.clone(), Value::String(value.clone())); - let set_response = send_command(&mut client, set_cmd).await; - assert!(matches!(set_response.result, ProtocolResult::Success)); + let set_response = send_command( + &mut client, + CommandType::CommandSet, + Some(key.clone()), + Some(create_proto_value(kv::Value::String(value.clone()))), + ) + .await; + + assert_eq!(set_response.result_type(), ResultType::ResultSuccess); // Get the key back - let get_cmd = Command::Get(key); - let get_response = send_command(&mut client, get_cmd).await; + let get_response = + send_command(&mut client, CommandType::CommandGet, Some(key), None) + .await; - if let ProtocolResult::Value(Some(Value::String(response_value))) = - get_response.result - { + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_some()); + + let kv_value: kv::Value = get_response.value.unwrap().into(); + if let kv::Value::String(response_value) = kv_value { assert_eq!(response_value, value); } else { - panic!("Expected Value::String, got {:?}", get_response.result); + panic!("Expected String value, got {:?}", kv_value); } } }) @@ -543,24 +671,26 @@ mod tests { // Connect a new client to verify all keys are present let mut verification_client = connect_client(&server.host, server.port).await; - let list_cmd = Command::List; - let list_response = send_command(&mut verification_client, list_cmd).await; - - if let ProtocolResult::Keys(keys) = list_response.result { - assert_eq!(keys.len(), client_count * operations_per_client); + let list_response = send_command( + &mut verification_client, + CommandType::CommandList, + None, + None, + ) + .await; + + assert_eq!(list_response.result_type(), ResultType::ResultKeys); + assert_eq!( + list_response.keys.len() as u32, + client_count * operations_per_client + ); - // Verify each expected key exists - for client_id in 0..client_count { - for i in 0..operations_per_client { - let expected_key = format!("client{}-key{}", client_id, i); - assert!(keys.contains(&expected_key)); - } + // Verify each expected key exists + for client_id in 0..client_count { + for i in 0..operations_per_client { + let expected_key = format!("client{}-key{}", client_id, i); + assert!(list_response.keys.contains(&expected_key)); } - } else { - panic!( - "Expected ProtocolResult::Keys, got {:?}", - list_response.result - ); } // Clean up @@ -576,11 +706,10 @@ mod tests { let mut client = connect_client(&server.host, server.port).await; // Send exit command - let exit_cmd = Command::Exit; - let exit_response = send_command(&mut client, exit_cmd).await; + let exit_response = send_command(&mut client, CommandType::CommandExit, None, None).await; // Check exit response - assert!(matches!(exit_response.result, ProtocolResult::Exit)); + assert_eq!(exit_response.result_type(), ResultType::ResultExit); // Try to send another command, should fail as connection should be closed let buf = client.next().await; @@ -598,7 +727,7 @@ mod tests { // Connect client let mut client = connect_client(&server.host, server.port).await; - // Send invalid data (not a valid MessagePack serialized Request) + // Send invalid data (not a valid protobuf Request) let invalid_data = Bytes::from(vec![0, 1, 2, 3]); client .send(invalid_data) @@ -611,10 +740,11 @@ mod tests { .await .expect("No response received") .expect("Failed to receive response"); - let response: Response = - rmp_serde::from_slice(&response_bytes).expect("Failed to deserialize error response"); - assert!(matches!(response.result, ProtocolResult::Error(_))); + let response = bonka::Response::decode(response_bytes.as_ref()) + .expect("Failed to deserialize error response"); + + assert_eq!(response.result_type(), ResultType::ResultError); // Clean up server.stop().await; @@ -629,38 +759,60 @@ mod tests { let mut client1 = connect_client(&server.host, server.port).await; // Set a key using first client - let set_cmd = Command::Set( - "session-key".to_string(), - Value::String("session-value".to_string()), - ); - let set_response = send_command(&mut client1, set_cmd).await; - assert!(matches!(set_response.result, ProtocolResult::Success)); + let set_response = send_command( + &mut client1, + CommandType::CommandSet, + Some("session-key".to_string()), + Some(create_proto_value(kv::Value::String( + "session-value".to_string(), + ))), + ) + .await; + + assert_eq!(set_response.result_type(), ResultType::ResultSuccess); // Connect second client let mut client2 = connect_client(&server.host, server.port).await; // Get the key using second client (should be visible to all clients) - let get_cmd = Command::Get("session-key".to_string()); - let get_response = send_command(&mut client2, get_cmd).await; - - if let ProtocolResult::Value(Some(Value::String(value))) = get_response.result { + let get_response = send_command( + &mut client2, + CommandType::CommandGet, + Some("session-key".to_string()), + None, + ) + .await; + + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_some()); + + let kv_value: kv::Value = get_response.value.unwrap().into(); + if let kv::Value::String(value) = kv_value { assert_eq!(value, "session-value"); } else { - panic!("Expected Value::String, got {:?}", get_response.result); + panic!("Expected String value, got {:?}", kv_value); } // Close first client with Exit command - let exit_cmd = Command::Exit; - let _ = send_command(&mut client1, exit_cmd).await; + let _ = send_command(&mut client1, CommandType::CommandExit, None, None).await; // Key should still be accessible from second client - let get_cmd = Command::Get("session-key".to_string()); - let get_response = send_command(&mut client2, get_cmd).await; - - if let ProtocolResult::Value(Some(Value::String(value))) = get_response.result { + let get_response = send_command( + &mut client2, + CommandType::CommandGet, + Some("session-key".to_string()), + None, + ) + .await; + + assert_eq!(get_response.result_type(), ResultType::ResultValue); + assert!(get_response.value.is_some()); + + let kv_value: kv::Value = get_response.value.unwrap().into(); + if let kv::Value::String(value) = kv_value { assert_eq!(value, "session-value"); } else { - panic!("Expected Value::String, got {:?}", get_response.result); + panic!("Expected String value, got {:?}", kv_value); } // Clean up