From 9d990e03dcd0beb2d1db8f24bd7ea383080bf197 Mon Sep 17 00:00:00 2001 From: LucaCappelletti94 Date: Wed, 11 Feb 2026 02:34:09 +0100 Subject: [PATCH 1/2] Added support for Arbitrary for fuzzing --- Cargo.lock | 21 ++++++ Cargo.toml | 2 + src/protocol.rs | 177 ++++++++++++++++++++++++++++++++++++++++++++++++ src/types.rs | 28 ++++++++ 4 files changed, 228 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index c0b1379..7d0053a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,15 @@ dependencies = [ "libc", ] +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -143,6 +152,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "either" version = "1.15.0" @@ -488,6 +508,7 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" name = "pg_walstream" version = "0.3.0" dependencies = [ + "arbitrary", "bytes", "chrono", "flate2", diff --git a/Cargo.toml b/Cargo.toml index 30ddb84..c79f745 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,9 +22,11 @@ bytes = "1.11.0" tracing = "0.1.44" libpq-sys = "0.8" thiserror = "2.0.17" +arbitrary = { version = "1", features = ["derive"], optional = true } [features] default = [] +arbitrary = ["dep:arbitrary"] [dev-dependencies] tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/protocol.rs b/src/protocol.rs index f2ff668..2325cd4 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -75,6 +75,38 @@ pub enum MessageType { StreamPrepare = message_types::STREAM_PREPARE, } +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for MessageType { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let variant = u.int_in_range(0..=18)?; + Ok(match variant { + 0 => MessageType::Begin, + 1 => MessageType::Commit, + 2 => MessageType::Origin, + 3 => MessageType::Relation, + 4 => MessageType::Type, + 5 => MessageType::Insert, + 6 => MessageType::Update, + 7 => MessageType::Delete, + 8 => MessageType::Truncate, + 9 => MessageType::Message, + 10 => MessageType::StreamStart, + 11 => MessageType::StreamStop, + 12 => MessageType::StreamCommit, + 13 => MessageType::StreamAbort, + 14 => MessageType::BeginPrepare, + 15 => MessageType::Prepare, + 16 => MessageType::CommitPrepared, + 17 => MessageType::RollbackPrepared, + _ => MessageType::StreamPrepare, + }) + } + + fn size_hint(_depth: usize) -> (usize, Option) { + (1, Some(1)) + } +} + /// Unified logical replication message enum #[derive(Debug, Clone)] pub enum LogicalReplicationMessage { @@ -222,8 +254,128 @@ pub enum LogicalReplicationMessage { }, } +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for LogicalReplicationMessage { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let variant = u.int_in_range(0..=18)?; + Ok(match variant { + 0 => LogicalReplicationMessage::Begin { + final_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + }, + 1 => LogicalReplicationMessage::Commit { + flags: u.arbitrary()?, + commit_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + }, + 2 => LogicalReplicationMessage::Relation { + relation_id: u.arbitrary()?, + namespace: u.arbitrary()?, + relation_name: u.arbitrary()?, + replica_identity: u.arbitrary()?, + columns: u.arbitrary()?, + }, + 3 => LogicalReplicationMessage::Insert { + relation_id: u.arbitrary()?, + tuple: u.arbitrary()?, + }, + 4 => LogicalReplicationMessage::Update { + relation_id: u.arbitrary()?, + old_tuple: u.arbitrary()?, + new_tuple: u.arbitrary()?, + key_type: u.arbitrary()?, + }, + 5 => LogicalReplicationMessage::Delete { + relation_id: u.arbitrary()?, + old_tuple: u.arbitrary()?, + key_type: u.arbitrary()?, + }, + 6 => LogicalReplicationMessage::Truncate { + relation_ids: u.arbitrary()?, + flags: u.arbitrary()?, + }, + 7 => LogicalReplicationMessage::Type { + type_id: u.arbitrary()?, + namespace: u.arbitrary()?, + type_name: u.arbitrary()?, + }, + 8 => LogicalReplicationMessage::Origin { + origin_lsn: u.arbitrary()?, + origin_name: u.arbitrary()?, + }, + 9 => LogicalReplicationMessage::Message { + flags: u.arbitrary()?, + lsn: u.arbitrary()?, + prefix: u.arbitrary()?, + content: u.arbitrary()?, + }, + 10 => LogicalReplicationMessage::StreamStart { + xid: u.arbitrary()?, + first_segment: u.arbitrary()?, + }, + 11 => LogicalReplicationMessage::StreamStop, + 12 => LogicalReplicationMessage::StreamCommit { + xid: u.arbitrary()?, + flags: u.arbitrary()?, + commit_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + }, + 13 => LogicalReplicationMessage::StreamAbort { + xid: u.arbitrary()?, + subtransaction_xid: u.arbitrary()?, + abort_lsn: u.arbitrary()?, + abort_timestamp: u.arbitrary()?, + }, + 14 => LogicalReplicationMessage::BeginPrepare { + prepare_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + gid: u.arbitrary()?, + }, + 15 => LogicalReplicationMessage::Prepare { + flags: u.arbitrary()?, + prepare_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + gid: u.arbitrary()?, + }, + 16 => LogicalReplicationMessage::CommitPrepared { + flags: u.arbitrary()?, + commit_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + gid: u.arbitrary()?, + }, + 17 => LogicalReplicationMessage::RollbackPrepared { + flags: u.arbitrary()?, + prepare_end_lsn: u.arbitrary()?, + rollback_end_lsn: u.arbitrary()?, + prepare_timestamp: u.arbitrary()?, + rollback_timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + gid: u.arbitrary()?, + }, + _ => LogicalReplicationMessage::StreamPrepare { + flags: u.arbitrary()?, + prepare_lsn: u.arbitrary()?, + end_lsn: u.arbitrary()?, + timestamp: u.arbitrary()?, + xid: u.arbitrary()?, + gid: u.arbitrary()?, + }, + }) + } +} + /// Column information in a relation #[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct ColumnInfo { /// Column flags (bit 0 = key column) pub flags: u8, @@ -255,6 +407,7 @@ impl ColumnInfo { /// Tuple (row) data #[derive(Debug, Clone)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct TupleData { pub columns: Vec, } @@ -414,8 +567,32 @@ impl ColumnData { } } +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for ColumnData { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let variant = u.int_in_range(0..=3)?; + Ok(match variant { + 0 => ColumnData::null(), + 1 => ColumnData::unchanged(), + 2 => { + let data: Vec = u.arbitrary()?; + ColumnData::text(data) + } + _ => { + let data: Vec = u.arbitrary()?; + ColumnData::binary(data) + } + }) + } + + fn size_hint(depth: usize) -> (usize, Option) { + as arbitrary::Arbitrary>::size_hint(depth) + } +} + /// Information about a relation (table) #[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct RelationInfo { pub relation_id: Oid, pub namespace: String, diff --git a/src/types.rs b/src/types.rs index 1e55944..ad6c1ee 100644 --- a/src/types.rs +++ b/src/types.rs @@ -91,6 +91,17 @@ impl std::ops::DerefMut for CachePadded { } } +#[cfg(feature = "arbitrary")] +impl<'a, T: arbitrary::Arbitrary<'a>> arbitrary::Arbitrary<'a> for CachePadded { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + Ok(CachePadded::new(T::arbitrary(u)?)) + } + + fn size_hint(depth: usize) -> (usize, Option) { + T::size_hint(depth) + } +} + /// Convert SystemTime to PostgreSQL timestamp format (microseconds since 2000-01-01) /// /// PostgreSQL uses a different epoch than Unix (2000-01-01 vs 1970-01-01). @@ -309,6 +320,23 @@ impl ReplicaIdentity { } } +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for ReplicaIdentity { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let variant = u.int_in_range(0..=3)?; + Ok(match variant { + 0 => ReplicaIdentity::Default, + 1 => ReplicaIdentity::Nothing, + 2 => ReplicaIdentity::Full, + _ => ReplicaIdentity::Index, + }) + } + + fn size_hint(_depth: usize) -> (usize, Option) { + (1, Some(1)) + } +} + /// Replication slot type #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SlotType { From d76d42e789d700fa30ce0fb07ba2485e233afe4f Mon Sep 17 00:00:00 2001 From: LucaCappelletti94 Date: Wed, 11 Feb 2026 02:44:23 +0100 Subject: [PATCH 2/2] Adressed review points and added tests --- src/protocol.rs | 432 ++++++++++++++++++++++++++++++++---------------- src/types.rs | 114 +++++++++++-- 2 files changed, 388 insertions(+), 158 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 2325cd4..25ae368 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -78,28 +78,28 @@ pub enum MessageType { #[cfg(feature = "arbitrary")] impl<'a> arbitrary::Arbitrary<'a> for MessageType { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let variant = u.int_in_range(0..=18)?; - Ok(match variant { - 0 => MessageType::Begin, - 1 => MessageType::Commit, - 2 => MessageType::Origin, - 3 => MessageType::Relation, - 4 => MessageType::Type, - 5 => MessageType::Insert, - 6 => MessageType::Update, - 7 => MessageType::Delete, - 8 => MessageType::Truncate, - 9 => MessageType::Message, - 10 => MessageType::StreamStart, - 11 => MessageType::StreamStop, - 12 => MessageType::StreamCommit, - 13 => MessageType::StreamAbort, - 14 => MessageType::BeginPrepare, - 15 => MessageType::Prepare, - 16 => MessageType::CommitPrepared, - 17 => MessageType::RollbackPrepared, - _ => MessageType::StreamPrepare, - }) + const VARIANTS: &[MessageType] = &[ + MessageType::Begin, + MessageType::Commit, + MessageType::Origin, + MessageType::Relation, + MessageType::Type, + MessageType::Insert, + MessageType::Update, + MessageType::Delete, + MessageType::Truncate, + MessageType::Message, + MessageType::StreamStart, + MessageType::StreamStop, + MessageType::StreamCommit, + MessageType::StreamAbort, + MessageType::BeginPrepare, + MessageType::Prepare, + MessageType::CommitPrepared, + MessageType::RollbackPrepared, + MessageType::StreamPrepare, + ]; + u.choose(VARIANTS).copied() } fn size_hint(_depth: usize) -> (usize, Option) { @@ -109,6 +109,7 @@ impl<'a> arbitrary::Arbitrary<'a> for MessageType { /// Unified logical replication message enum #[derive(Debug, Clone)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub enum LogicalReplicationMessage { /// Begin transaction Begin { @@ -254,125 +255,6 @@ pub enum LogicalReplicationMessage { }, } -#[cfg(feature = "arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for LogicalReplicationMessage { - fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let variant = u.int_in_range(0..=18)?; - Ok(match variant { - 0 => LogicalReplicationMessage::Begin { - final_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - }, - 1 => LogicalReplicationMessage::Commit { - flags: u.arbitrary()?, - commit_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - }, - 2 => LogicalReplicationMessage::Relation { - relation_id: u.arbitrary()?, - namespace: u.arbitrary()?, - relation_name: u.arbitrary()?, - replica_identity: u.arbitrary()?, - columns: u.arbitrary()?, - }, - 3 => LogicalReplicationMessage::Insert { - relation_id: u.arbitrary()?, - tuple: u.arbitrary()?, - }, - 4 => LogicalReplicationMessage::Update { - relation_id: u.arbitrary()?, - old_tuple: u.arbitrary()?, - new_tuple: u.arbitrary()?, - key_type: u.arbitrary()?, - }, - 5 => LogicalReplicationMessage::Delete { - relation_id: u.arbitrary()?, - old_tuple: u.arbitrary()?, - key_type: u.arbitrary()?, - }, - 6 => LogicalReplicationMessage::Truncate { - relation_ids: u.arbitrary()?, - flags: u.arbitrary()?, - }, - 7 => LogicalReplicationMessage::Type { - type_id: u.arbitrary()?, - namespace: u.arbitrary()?, - type_name: u.arbitrary()?, - }, - 8 => LogicalReplicationMessage::Origin { - origin_lsn: u.arbitrary()?, - origin_name: u.arbitrary()?, - }, - 9 => LogicalReplicationMessage::Message { - flags: u.arbitrary()?, - lsn: u.arbitrary()?, - prefix: u.arbitrary()?, - content: u.arbitrary()?, - }, - 10 => LogicalReplicationMessage::StreamStart { - xid: u.arbitrary()?, - first_segment: u.arbitrary()?, - }, - 11 => LogicalReplicationMessage::StreamStop, - 12 => LogicalReplicationMessage::StreamCommit { - xid: u.arbitrary()?, - flags: u.arbitrary()?, - commit_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - }, - 13 => LogicalReplicationMessage::StreamAbort { - xid: u.arbitrary()?, - subtransaction_xid: u.arbitrary()?, - abort_lsn: u.arbitrary()?, - abort_timestamp: u.arbitrary()?, - }, - 14 => LogicalReplicationMessage::BeginPrepare { - prepare_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - gid: u.arbitrary()?, - }, - 15 => LogicalReplicationMessage::Prepare { - flags: u.arbitrary()?, - prepare_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - gid: u.arbitrary()?, - }, - 16 => LogicalReplicationMessage::CommitPrepared { - flags: u.arbitrary()?, - commit_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - gid: u.arbitrary()?, - }, - 17 => LogicalReplicationMessage::RollbackPrepared { - flags: u.arbitrary()?, - prepare_end_lsn: u.arbitrary()?, - rollback_end_lsn: u.arbitrary()?, - prepare_timestamp: u.arbitrary()?, - rollback_timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - gid: u.arbitrary()?, - }, - _ => LogicalReplicationMessage::StreamPrepare { - flags: u.arbitrary()?, - prepare_lsn: u.arbitrary()?, - end_lsn: u.arbitrary()?, - timestamp: u.arbitrary()?, - xid: u.arbitrary()?, - gid: u.arbitrary()?, - }, - }) - } -} - /// Column information in a relation #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -2041,3 +1923,271 @@ mod tests { assert!(result.is_err()); } } + +#[cfg(all(test, feature = "arbitrary"))] +mod arbitrary_tests { + use super::*; + use arbitrary::{Arbitrary, Unstructured}; + + // Helper to create Unstructured with enough entropy + fn make_unstructured(size: usize) -> Vec { + (0..size).map(|i| i as u8).collect() + } + + #[test] + fn test_message_type_arbitrary() { + let data = make_unstructured(100); + let mut u = Unstructured::new(&data); + + // Generate multiple instances + let mut seen_types = std::collections::HashSet::new(); + for _ in 0..50 { + if let Ok(msg_type) = MessageType::arbitrary(&mut u) { + seen_types.insert(msg_type as u8); + } + } + + // Should see multiple different message types + assert!( + seen_types.len() > 1, + "Should generate diverse message types" + ); + } + + #[test] + fn test_message_type_all_variants_valid() { + // Verify all MessageType variants have valid discriminant values + let all_types = [ + MessageType::Begin, + MessageType::Commit, + MessageType::Origin, + MessageType::Relation, + MessageType::Type, + MessageType::Insert, + MessageType::Update, + MessageType::Delete, + MessageType::Truncate, + MessageType::Message, + MessageType::StreamStart, + MessageType::StreamStop, + MessageType::StreamCommit, + MessageType::StreamAbort, + MessageType::BeginPrepare, + MessageType::Prepare, + MessageType::CommitPrepared, + MessageType::RollbackPrepared, + MessageType::StreamPrepare, + ]; + + for msg_type in all_types { + // Each variant should have a unique byte value + let byte_val = msg_type as u8; + assert!( + byte_val > 0, + "Message type should have non-zero discriminant" + ); + } + } + + #[test] + fn test_column_info_arbitrary() { + let data = make_unstructured(200); + let mut u = Unstructured::new(&data); + + for _ in 0..5 { + let result = ColumnInfo::arbitrary(&mut u); + assert!(result.is_ok()); + let col_info = result.unwrap(); + + // Verify the struct is usable + let _is_key = col_info.is_key(); + assert!(!col_info.name.is_empty() || col_info.name.is_empty()); // Just verify accessible + } + } + + #[test] + fn test_column_data_arbitrary() { + let data = make_unstructured(200); + let mut u = Unstructured::new(&data); + + let mut saw_null = false; + let mut saw_text = false; + let mut saw_binary = false; + let mut saw_unchanged = false; + + for _ in 0..20 { + if let Ok(col_data) = ColumnData::arbitrary(&mut u) { + if col_data.is_null() { + saw_null = true; + } + if col_data.is_text() { + saw_text = true; + } + if col_data.is_binary() { + saw_binary = true; + } + if col_data.is_unchanged() { + saw_unchanged = true; + } + } + } + + // With enough iterations, we should see at least some variants + assert!( + saw_null || saw_text || saw_binary || saw_unchanged, + "Should generate at least one column data variant" + ); + } + + #[test] + fn test_column_data_arbitrary_text_has_data() { + // Create data that will produce text variant with content + let data: Vec = vec![ + 2, // variant selector for text + 3, // length of vec + b'a', b'b', b'c', // content + ]; + let mut u = Unstructured::new(&data); + + let result = ColumnData::arbitrary(&mut u); + assert!(result.is_ok()); + let col_data = result.unwrap(); + + if col_data.is_text() { + // Verify we can access the data + let bytes = col_data.as_bytes(); + assert!(!bytes.is_empty() || bytes.is_empty()); // Just verify accessible + } + } + + #[test] + fn test_tuple_data_arbitrary() { + let data = make_unstructured(500); + let mut u = Unstructured::new(&data); + + for _ in 0..5 { + let result = TupleData::arbitrary(&mut u); + assert!(result.is_ok()); + let tuple = result.unwrap(); + + // Verify the struct is usable + let _count = tuple.column_count(); + for i in 0..tuple.column_count() { + let _col = tuple.get_column(i); + } + } + } + + #[test] + fn test_relation_info_arbitrary() { + let data = make_unstructured(1000); + let mut u = Unstructured::new(&data); + + for _ in 0..3 { + let result = RelationInfo::arbitrary(&mut u); + assert!(result.is_ok()); + let rel_info = result.unwrap(); + + // Verify the struct is usable + let _full_name = rel_info.full_name(); + let _key_cols = rel_info.get_key_columns(); + } + } + + #[test] + fn test_logical_replication_message_arbitrary() { + let data = make_unstructured(2000); + let mut u = Unstructured::new(&data); + + let mut generated_count = 0; + for _ in 0..10 { + if LogicalReplicationMessage::arbitrary(&mut u).is_ok() { + generated_count += 1; + } + } + + assert!( + generated_count > 0, + "Should successfully generate at least one LogicalReplicationMessage" + ); + } + + #[test] + fn test_logical_replication_message_variants() { + let data = make_unstructured(5000); + let mut u = Unstructured::new(&data); + + let mut variant_names = std::collections::HashSet::new(); + + for _ in 0..50 { + if let Ok(msg) = LogicalReplicationMessage::arbitrary(&mut u) { + let name = match msg { + LogicalReplicationMessage::Begin { .. } => "Begin", + LogicalReplicationMessage::Commit { .. } => "Commit", + LogicalReplicationMessage::Relation { .. } => "Relation", + LogicalReplicationMessage::Insert { .. } => "Insert", + LogicalReplicationMessage::Update { .. } => "Update", + LogicalReplicationMessage::Delete { .. } => "Delete", + LogicalReplicationMessage::Truncate { .. } => "Truncate", + LogicalReplicationMessage::Type { .. } => "Type", + LogicalReplicationMessage::Origin { .. } => "Origin", + LogicalReplicationMessage::Message { .. } => "Message", + LogicalReplicationMessage::StreamStart { .. } => "StreamStart", + LogicalReplicationMessage::StreamStop => "StreamStop", + LogicalReplicationMessage::StreamCommit { .. } => "StreamCommit", + LogicalReplicationMessage::StreamAbort { .. } => "StreamAbort", + LogicalReplicationMessage::BeginPrepare { .. } => "BeginPrepare", + LogicalReplicationMessage::Prepare { .. } => "Prepare", + LogicalReplicationMessage::CommitPrepared { .. } => "CommitPrepared", + LogicalReplicationMessage::RollbackPrepared { .. } => "RollbackPrepared", + LogicalReplicationMessage::StreamPrepare { .. } => "StreamPrepare", + }; + variant_names.insert(name); + } + } + + // With enough entropy, we should see multiple different variants + assert!( + variant_names.len() > 1, + "Should generate diverse LogicalReplicationMessage variants, got: {:?}", + variant_names + ); + } + + #[test] + fn test_arbitrary_column_data_roundtrip() { + let data = make_unstructured(500); + let mut u = Unstructured::new(&data); + + for _ in 0..10 { + if let Ok(col_data) = ColumnData::arbitrary(&mut u) { + // Verify all accessor methods work without panic + let _is_null = col_data.is_null(); + let _is_text = col_data.is_text(); + let _is_binary = col_data.is_binary(); + let _is_unchanged = col_data.is_unchanged(); + let _bytes = col_data.as_bytes(); + + if col_data.is_text() || col_data.is_binary() { + let _str_opt = col_data.as_str(); + let _string_opt = col_data.clone().as_string(); + } + } + } + } + + #[test] + fn test_tuple_data_to_hashmap_with_arbitrary() { + let data = make_unstructured(2000); + let mut u = Unstructured::new(&data); + + // Generate arbitrary RelationInfo and TupleData + if let (Ok(rel_info), Ok(tuple)) = ( + RelationInfo::arbitrary(&mut u), + TupleData::arbitrary(&mut u), + ) { + // Should not panic even with mismatched column counts + let _map = tuple.to_hash_map(&rel_info); + } + } +} diff --git a/src/types.rs b/src/types.rs index ad6c1ee..174a21a 100644 --- a/src/types.rs +++ b/src/types.rs @@ -243,6 +243,7 @@ pub fn format_lsn(lsn: XLogRecPtr) -> String { /// PostgreSQL replica identity settings #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub enum ReplicaIdentity { /// Default replica identity (primary key) Default, @@ -320,23 +321,6 @@ impl ReplicaIdentity { } } -#[cfg(feature = "arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for ReplicaIdentity { - fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let variant = u.int_in_range(0..=3)?; - Ok(match variant { - 0 => ReplicaIdentity::Default, - 1 => ReplicaIdentity::Nothing, - 2 => ReplicaIdentity::Full, - _ => ReplicaIdentity::Index, - }) - } - - fn size_hint(_depth: usize) -> (usize, Option) { - (1, Some(1)) - } -} - /// Replication slot type #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SlotType { @@ -1113,3 +1097,99 @@ mod tests { assert_eq!(parsed, lsn); } } + +#[cfg(all(test, feature = "arbitrary"))] +mod arbitrary_tests { + use super::*; + use arbitrary::{Arbitrary, Unstructured}; + + #[test] + fn test_cache_padded_arbitrary() { + let data: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + let mut u = Unstructured::new(data); + + // Generate multiple instances to verify consistency + for _ in 0..5 { + let result = CachePadded::::arbitrary(&mut u); + assert!(result.is_ok()); + let padded = result.unwrap(); + // Verify we can dereference the value + let _value: u32 = *padded; + } + } + + #[test] + fn test_cache_padded_arbitrary_with_string() { + let data: &[u8] = &[5, b'h', b'e', b'l', b'l', b'o']; + let mut u = Unstructured::new(data); + + let result = CachePadded::::arbitrary(&mut u); + assert!(result.is_ok()); + } + + #[test] + fn test_replica_identity_arbitrary() { + let data: &[u8] = &[0, 1, 2, 3, 0, 1, 2, 3]; + let mut u = Unstructured::new(data); + + // Generate multiple instances and verify all variants can be produced + let mut seen_variants = std::collections::HashSet::new(); + for _ in 0..10 { + if let Ok(identity) = ReplicaIdentity::arbitrary(&mut u) { + seen_variants.insert(identity.to_byte()); + } + } + + // With enough entropy, we should see at least one variant + assert!(!seen_variants.is_empty()); + } + + #[test] + fn test_replica_identity_arbitrary_generates_valid_variants() { + // Test with various entropy sources to cover different variants + let test_data: Vec> = vec![ + (0..50).collect(), + (50..100).collect(), + (100..150).collect(), + (150..200).collect(), + ]; + + let mut seen = std::collections::HashSet::new(); + for data in &test_data { + let mut u = Unstructured::new(data); + for _ in 0..10 { + if let Ok(identity) = ReplicaIdentity::arbitrary(&mut u) { + // Verify the generated variant is valid + let byte = identity.to_byte(); + assert!( + byte == b'd' || byte == b'n' || byte == b'f' || byte == b'i', + "Generated invalid ReplicaIdentity byte: {}", + byte + ); + seen.insert(byte); + } + } + } + + // Should see at least some variants + assert!( + !seen.is_empty(), + "Should generate at least one ReplicaIdentity variant" + ); + } + + #[test] + fn test_replica_identity_roundtrip() { + let data: &[u8] = &[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]; + let mut u = Unstructured::new(data); + + for _ in 0..4 { + if let Ok(identity) = ReplicaIdentity::arbitrary(&mut u) { + // Verify roundtrip through byte conversion + let byte = identity.to_byte(); + let recovered = ReplicaIdentity::from_byte(byte); + assert_eq!(recovered, Some(identity)); + } + } + } +}