diff --git a/Cargo.lock b/Cargo.lock index 0fe4362..6b41df9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,6 +13,7 @@ dependencies = [ "chrono", "clap", "colour", + "crc64", "env_logger", "failure", "jemallocator", @@ -192,6 +193,12 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +[[package]] +name = "crc64" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2707e3afba5e19b75d582d88bc79237418f2a2a2d673d01cf9b03633b46e98f3" + [[package]] name = "crossbeam-channel" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index cd116b6..d0e41b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,3 +47,4 @@ once_cell = "1.10.0" jemallocator = { path = "components/jemallocator-0.3.2"} spin = "0.9.2" sysinfo = "0.23.5" +crc64 = "2.0.0" diff --git a/src/lib.rs b/src/lib.rs index 3216288..7b8be0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -170,6 +170,8 @@ pub enum CstError { ReplicateCommandsLost(String), #[fail(display = "replica nodeid already exist")] ReplicaNodeAlreadyExist, + #[fail(display = "invalid checksum of snapshot")] + InvalidSnapshotChecksum, } impl From for CstError { diff --git a/src/server.rs b/src/server.rs index 019a35d..79040a8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -96,7 +96,7 @@ impl Server { let addr = format!("{}:{}", c.ip, c.port).parse::().unwrap(); let socket = TcpSocket::new_v4()?; socket.set_reuseaddr(true)?; - socket.set_reuseport(true)?; + socket.set_reuseport(false)?; socket.bind(addr)?; let listener = socket.listen(c.tcp_backlog)?; let server_c = server.clone(); diff --git a/src/snapshot.rs b/src/snapshot.rs index 92c9017..2d87809 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -1,42 +1,43 @@ +use crate::object::Object; +use crate::{Bytes, CstError}; +use crc64::Crc64; +use std::fmt; +use std::io; use std::io::{BufWriter, Write}; use tokio::io::AsyncReadExt; -use crate::{Bytes, CstError}; -use crate::object::Object; - pub struct SnapshotWriter { io: BufWriter, wrote_size: usize, - checksum: u64, + checksum_writter: Crc64, } impl SnapshotWriter { pub fn new(size: usize, io: W) -> Self { - SnapshotWriter{ + SnapshotWriter { io: BufWriter::with_capacity(size, io), wrote_size: 0, - checksum: 0, + checksum_writter: Crc64::new(), } } #[inline] pub fn write_integer(&mut self, i: i64) -> std::io::Result<&mut Self> { - if i < 1<<6 { + if i < 1 << 6 { self.write_bytes([i as u8].as_ref()) - } else if i < 1<<14 { - self.write_bytes(i16::to_be_bytes((i as i16) | 1<<14).as_ref()) - } else if i < 1<<30 { + } else if i < 1 << 14 { + self.write_bytes(i16::to_be_bytes((i as i16) | 1 << 14).as_ref()) + } else if i < 1 << 30 { //self.write_bytes([(1<<7|i>>24) as u8, ((i>>16)&255) as u8, ((i>>8)&255) as u8, (i&255) as u8].as_ref()) - self.write_bytes(i32::to_be_bytes((i as i32) | 1<<31).as_ref()) + self.write_bytes(i32::to_be_bytes((i as i32) | 1 << 31).as_ref()) } else { - self.write_bytes([3<<6 as u8].as_ref())?; + self.write_bytes([3 << 6 as u8].as_ref())?; self.write_bytes(i.to_be_bytes().as_ref()) } } pub fn write_bytes(&mut self, src: &[u8]) -> std::io::Result<&mut Self> { - // TODO need to update checksum - self.checksum = 0; + self.checksum_writter.write(src)?; self.io.write(src).map(|x| { self.wrote_size += x; x @@ -51,16 +52,15 @@ impl SnapshotWriter { } pub fn write_byte(&mut self, d: u8) -> std::io::Result<&mut Self> { - self.write_bytes([d;1].as_ref()) + self.write_bytes([d; 1].as_ref()) } pub fn total_wrote(&self) -> usize { self.wrote_size } - // TODO - pub fn checksum(&self) -> u64 { - self.checksum + pub fn checksum(&mut self) -> u64 { + self.checksum_writter.get() } pub fn flush(&mut self) -> Result<(), CstError> { @@ -70,21 +70,50 @@ impl SnapshotWriter { pub type FileSnapshotLoader = SnapshotLoader; +struct DebugCrc64 { + checksum_writter: Crc64, +} + +impl fmt::Debug for DebugCrc64 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "checksum: {}", self.checksum_writter.get()) + } +} + +impl DebugCrc64 { + pub fn new() -> Self { + DebugCrc64 { + checksum_writter: Crc64::new(), + } + } + + pub fn get(&self) -> u64 { + self.checksum_writter.get() + } + + fn write(&mut self, buf: &[u8]) -> io::Result { + self.checksum_writter.write(buf) + } +} + #[derive(Debug)] pub struct SnapshotLoader { stat: SnapshotLoadProgress, io: T, read_size: usize, + checksum_writter: DebugCrc64, } impl SnapshotLoader - where T: tokio::io::AsyncRead + std::marker::Unpin, +where + T: tokio::io::AsyncRead + std::marker::Unpin, { pub fn new(f: T) -> Self { - SnapshotLoader{ + SnapshotLoader { stat: SnapshotLoadProgress::Begin, io: f, read_size: 0, + checksum_writter: DebugCrc64::new(), } } @@ -100,38 +129,54 @@ impl SnapshotLoader if version.len() != 4 { return Err(CstError::InvalidSnapshot(7)); } - let v = format!("{}.{}.{}.{}", version[0], version[1], version[2], version[3]).into(); + let v = format!( + "{}.{}.{}.{}", + version[0], version[1], version[2], version[3] + ) + .into(); self.stat = SnapshotLoadProgress::Node; return Ok(Some(SnapshotEntry::Version(v))); } SnapshotLoadProgress::Node => { let nodid = self.read_integer().await? as u64; let alias_len = self.read_integer().await? as usize; - let alias = String::from_utf8(self.read_bytes(alias_len).await?).map_err(|_| CstError::InvalidSnapshot(self.read_size)).unwrap(); + let alias = String::from_utf8(self.read_bytes(alias_len).await?) + .map_err(|_| CstError::InvalidSnapshot(self.read_size)) + .unwrap(); let addr_len = self.read_integer().await? as usize; - let addr = String::from_utf8(self.read_bytes(addr_len).await?).map_err(|_| CstError::InvalidSnapshot(self.read_size)).unwrap(); + let addr = String::from_utf8(self.read_bytes(addr_len).await?) + .map_err(|_| CstError::InvalidSnapshot(self.read_size)) + .unwrap(); let uuid = self.read_integer().await? as u64; self.convert_stat().await?; return Ok(Some(SnapshotEntry::Node(nodid, alias, addr, uuid))); - }, + } SnapshotLoadProgress::Replicas(true) => { let add_time = self.read_integer().await? as u64; let nodid = self.read_integer().await? as u64; let alias_len = self.read_integer().await? as usize; - let alias = String::from_utf8(self.read_bytes(alias_len).await?).map_err(|_| CstError::InvalidSnapshot(self.read_size)).unwrap(); + let alias = String::from_utf8(self.read_bytes(alias_len).await?) + .map_err(|_| CstError::InvalidSnapshot(self.read_size)) + .unwrap(); let addr_len = self.read_integer().await? as usize; - let addr = String::from_utf8(self.read_bytes(addr_len).await?).map_err(|_| CstError::InvalidSnapshot(self.read_size)).unwrap(); + let addr = String::from_utf8(self.read_bytes(addr_len).await?) + .map_err(|_| CstError::InvalidSnapshot(self.read_size)) + .unwrap(); let uuid = self.read_integer().await? as u64; self.convert_stat().await?; - return Ok(Some(SnapshotEntry::ReplicaAdd(add_time, nodid, alias, addr, uuid))); - }, + return Ok(Some(SnapshotEntry::ReplicaAdd( + add_time, nodid, alias, addr, uuid, + ))); + } SnapshotLoadProgress::Replicas(false) => { let addr_len = self.read_integer().await? as usize; - let addr = String::from_utf8(self.read_bytes(addr_len).await?).map_err(|_| CstError::InvalidSnapshot(self.read_size)).unwrap(); + let addr = String::from_utf8(self.read_bytes(addr_len).await?) + .map_err(|_| CstError::InvalidSnapshot(self.read_size)) + .unwrap(); let t = self.read_integer().await? as u64; self.convert_stat().await?; return Ok(Some(SnapshotEntry::ReplicaDel(addr, t))); - }, + } SnapshotLoadProgress::Datas(size, current) => { if *current < *size { *current += 1; @@ -140,7 +185,7 @@ impl SnapshotLoader } else { self.convert_stat().await?; } - }, + } SnapshotLoadProgress::Deletes(size, current) => { if *current < *size { *current += 1; @@ -149,7 +194,7 @@ impl SnapshotLoader } else { self.convert_stat().await?; } - }, + } SnapshotLoadProgress::Expires(size, current) => { if *current < *size { *current += 1; @@ -158,12 +203,15 @@ impl SnapshotLoader } else { self.convert_stat().await?; } - }, + } SnapshotLoadProgress::Checksum => { let _checksum = self.read_integer().await?; - // TODO need compare the checksum contained in snapshot with the one we calculated + let checksum = self.checksum_writter.get(); + if (_checksum as u64) != checksum { + return Err(CstError::InvalidSnapshotChecksum); + } self.stat = SnapshotLoadProgress::Finish; - }, + } SnapshotLoadProgress::Finish => { return Ok(None); } @@ -175,13 +223,19 @@ impl SnapshotLoader self.stat = match self.read_byte().await? { SNAPSHOT_FLAG_REPLICA_ADD => SnapshotLoadProgress::Replicas(true), SNAPSHOT_FLAG_REPLICA_REM => SnapshotLoadProgress::Replicas(false), - SNAPSHOT_FLAG_DATAS => SnapshotLoadProgress::Datas(self.read_integer().await? as usize, 0), - SNAPSHOT_FLAG_DELETES => SnapshotLoadProgress::Deletes(self.read_integer().await? as usize, 0), - SNAPSHOT_FLAG_EXPIRES => SnapshotLoadProgress::Expires(self.read_integer().await? as usize, 0), + SNAPSHOT_FLAG_DATAS => { + SnapshotLoadProgress::Datas(self.read_integer().await? as usize, 0) + } + SNAPSHOT_FLAG_DELETES => { + SnapshotLoadProgress::Deletes(self.read_integer().await? as usize, 0) + } + SNAPSHOT_FLAG_EXPIRES => { + SnapshotLoadProgress::Expires(self.read_integer().await? as usize, 0) + } SNAPSHOT_FLAG_CHECKSUM => SnapshotLoadProgress::Checksum, _ => { return Err(CstError::InvalidSnapshot(self.read_size)); - }, + } }; Ok(()) } @@ -190,24 +244,32 @@ impl SnapshotLoader pub async fn read_integer(&mut self) -> Result { let flag = self.read_byte().await?; match (flag >> 6) & 3 { - 0 => Ok(((flag<<2)>>2) as i64), - 1 => Ok(i16::from_be_bytes([flag & ((1<<6)-1), self.read_byte().await?]) as i64), + 0 => Ok(((flag << 2) >> 2) as i64), + 1 => Ok(i16::from_be_bytes([flag & ((1 << 6) - 1), self.read_byte().await?]) as i64), 2 => { let bytes = self.read_bytes(3).await?; - Ok(i32::from_be_bytes([flag & ((1<<6)-1), bytes[0], bytes[1], bytes[2]]) as i64) + Ok( + i32::from_be_bytes([flag & ((1 << 6) - 1), bytes[0], bytes[1], bytes[2]]) + as i64, + ) } 3 => { let bytes = self.read_bytes(8).await?; - Ok(i64::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7]])) + Ok(i64::from_be_bytes([ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + ])) } - _ => unreachable!() + _ => unreachable!(), } } pub async fn read_bytes(&mut self, size: usize) -> Result, CstError> { let mut datas = Vec::with_capacity(size); - unsafe { datas.set_len(size); } + unsafe { + datas.set_len(size); + } self.io.read_exact(&mut datas).await?; + self.checksum_writter.write(&datas)?; Ok(datas) } @@ -243,7 +305,7 @@ pub enum SnapshotEntry { Version(Bytes), Node(u64, String, String, u64), // node_id, node_alias, addr, uuid_he_sent ReplicaAdd(u64, u64, String, String, u64), // (add_time, node_id, node_alias, addr, uuid_he_sent) - ReplicaDel(String, u64), // (addr, del_time) + ReplicaDel(String, u64), // (addr, del_time) Data(Bytes, Object), Expires(Bytes, u64), Deletes(Bytes, u64), @@ -290,30 +352,41 @@ mod test { async fn test_snapshot_bytes() { { - let f = std::fs::OpenOptions::new().create(true).read(true).write(true).truncate(true).open("test_spapshot_bytes").unwrap(); + let f = std::fs::OpenOptions::new() + .create(true) + .read(true) + .write(true) + .truncate(true) + .open("test_spapshot_bytes") + .unwrap(); let mut w = SnapshotWriter::new(2048, f); w.write_bytes(b"CONST"); w.write_bytes(b"DB"); w.write_integer(1); w.write_integer(2); - w.write_integer(1<<13); - w.write_integer(1<<20); - w.write_integer(1<<26); - w.write_integer(1<<30); - w.write_integer(1<<31); + w.write_integer(1 << 13); + w.write_integer(1 << 20); + w.write_integer(1 << 26); + w.write_integer(1 << 30); + w.write_integer(1 << 31); + assert_eq!(9519382692141102896, w.checksum()); } { - let f = tokio::fs::OpenOptions::new().read(true).open("test_spapshot_bytes").await.unwrap(); + let f = tokio::fs::OpenOptions::new() + .read(true) + .open("test_spapshot_bytes") + .await + .unwrap(); let mut r = SnapshotLoader::new(f); assert_eq!(r.read_bytes(5).await.unwrap(), b"CONST"); assert_eq!(r.read_bytes(2).await.unwrap(), b"DB"); assert_eq!(r.read_integer().await.unwrap(), 1); assert_eq!(r.read_integer().await.unwrap(), 2); - assert_eq!(r.read_integer().await.unwrap(), 1<<13); - assert_eq!(r.read_integer().await.unwrap(), 1<<20); - assert_eq!(r.read_integer().await.unwrap(), 1<<26); - assert_eq!(r.read_integer().await.unwrap(), 1<<30); - assert_eq!(r.read_integer().await.unwrap(), 1<<31); + assert_eq!(r.read_integer().await.unwrap(), 1 << 13); + assert_eq!(r.read_integer().await.unwrap(), 1 << 20); + assert_eq!(r.read_integer().await.unwrap(), 1 << 26); + assert_eq!(r.read_integer().await.unwrap(), 1 << 30); + assert_eq!(r.read_integer().await.unwrap(), 1 << 31); } } -} \ No newline at end of file +}