diff --git a/Cargo.toml b/Cargo.toml index e77815e..9ab7fa3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,17 @@ members = [ "examples/syscall-cycles", "examples/std-smoke", "examples/c-smoke/rust", + "examples/parallel-keccak", + "examples/batched-signatures", + "examples/merkle-tree", + "examples/parallel-mergesort", + "examples/matrix-multiply", + "examples/prefix-sum", + "examples/fft", + "examples/wavelet-transform", + "examples/graph-bfs", + "examples/block-compression", + "examples/polynomial-eval", ] resolver = "2" diff --git a/examples/batched-signatures/Cargo.toml b/examples/batched-signatures/Cargo.toml new file mode 100644 index 0000000..5308f69 --- /dev/null +++ b/examples/batched-signatures/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "batched-signatures" +publish = false +version.workspace = true +edition.workspace = true +description = "Batched Ed25519 signature verification demo" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/batched-signatures/src/lib.rs b/examples/batched-signatures/src/lib.rs new file mode 100644 index 0000000..6b5235c --- /dev/null +++ b/examples/batched-signatures/src/lib.rs @@ -0,0 +1,166 @@ +//! Simplified Ed25519-like signature verification. +//! +//! This is a toy implementation for demonstration purposes. +//! In production, use a proper cryptographic library. +//! +//! The implementation focuses on exercising the computation patterns +//! without full cryptographic security. + +#![no_std] + +/// A simplified "public key" (32 bytes) +pub type PublicKey = [u8; 32]; + +/// A simplified "signature" (64 bytes) +pub type Signature = [u8; 64]; + +/// A message to verify +pub type Message<'a> = &'a [u8]; + +/// Verification result +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VerifyResult { + Valid, + Invalid, +} + +/// Simple hash function for demonstration (not cryptographically secure!) +fn simple_hash(data: &[u8]) -> [u8; 32] { + let mut hash = [0u8; 32]; + let mut acc: u64 = 0x5555555555555555; + + for (i, &byte) in data.iter().enumerate() { + acc = acc.wrapping_mul(31).wrapping_add(byte as u64); + acc ^= acc.rotate_left(13); + hash[i % 32] ^= (acc & 0xFF) as u8; + acc = acc.wrapping_add((i as u64).wrapping_mul(17)); + } + + // Final mixing + for i in 0..32 { + acc = acc.wrapping_mul(0x5851F42D4C957F2D); + acc ^= acc >> 33; + hash[i] ^= (acc & 0xFF) as u8; + } + + hash +} + +/// Generate a deterministic "signature" for testing. +/// This is NOT real Ed25519 - just a demo to exercise computation patterns. 
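+///
+/// Layout sketch (as implemented below): `sig[..32] = H(sk || msg)` and
+/// `sig[32..] = H(r || pk || msg)`, with messages truncated to 32 bytes by
+/// this toy scheme. A minimal round trip:
+///
+/// ```ignore
+/// let sk = [7u8; 32];
+/// let sig = sign_message(&sk, b"hello");
+/// let pk = derive_public_key(&sk);
+/// assert_eq!(verify_signature(&pk, b"hello", &sig), VerifyResult::Valid);
+/// ```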
+pub fn sign_message(secret_key: &[u8; 32], message: &[u8]) -> Signature { + let mut sig = [0u8; 64]; + + // First 32 bytes: hash of secret_key || message + let mut combined = [0u8; 64]; + combined[..32].copy_from_slice(secret_key); + let msg_len = core::cmp::min(message.len(), 32); + combined[32..32 + msg_len].copy_from_slice(&message[..msg_len]); + + let r = simple_hash(&combined); + sig[..32].copy_from_slice(&r); + + // Second 32 bytes: hash of r || public_key || message + let public_key = derive_public_key(secret_key); + let mut combined2 = [0u8; 96]; + combined2[..32].copy_from_slice(&r); + combined2[32..64].copy_from_slice(&public_key); + let msg_len2 = core::cmp::min(message.len(), 32); + combined2[64..64 + msg_len2].copy_from_slice(&message[..msg_len2]); + + let s = simple_hash(&combined2); + sig[32..].copy_from_slice(&s); + + sig +} + +/// Derive "public key" from secret key (simplified) +pub fn derive_public_key(secret_key: &[u8; 32]) -> PublicKey { + simple_hash(secret_key) +} + +/// Verify a signature against a public key and message. +/// Returns Valid if the signature matches, Invalid otherwise. +pub fn verify_signature( + public_key: &PublicKey, + message: &[u8], + signature: &Signature, +) -> VerifyResult { + // Reconstruct expected signature components + let r = &signature[..32]; + let s = &signature[32..]; + + // Recompute s' = hash(r || public_key || message) + let mut combined = [0u8; 96]; + combined[..32].copy_from_slice(r); + combined[32..64].copy_from_slice(public_key); + let msg_len = core::cmp::min(message.len(), 32); + combined[64..64 + msg_len].copy_from_slice(&message[..msg_len]); + + let expected_s = simple_hash(&combined); + + // Check if s matches expected + if s == expected_s { + VerifyResult::Valid + } else { + VerifyResult::Invalid + } +} + +/// Batch verify multiple signatures (single-threaded baseline) +pub fn batch_verify( + public_keys: &[PublicKey], + messages: &[&[u8]], + signatures: &[Signature], + results: &mut [VerifyResult], +) { + let n = core::cmp::min( + core::cmp::min(public_keys.len(), messages.len()), + core::cmp::min(signatures.len(), results.len()), + ); + + for i in 0..n { + results[i] = verify_signature(&public_keys[i], messages[i], &signatures[i]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sign_and_verify() { + let secret_key = [0x42u8; 32]; + let public_key = derive_public_key(&secret_key); + let message = b"hello world"; + + let signature = sign_message(&secret_key, message); + let result = verify_signature(&public_key, message, &signature); + + assert_eq!(result, VerifyResult::Valid); + } + + #[test] + fn test_invalid_signature() { + let secret_key = [0x42u8; 32]; + let public_key = derive_public_key(&secret_key); + let message = b"hello world"; + let wrong_message = b"wrong message"; + + let signature = sign_message(&secret_key, message); + let result = verify_signature(&public_key, wrong_message, &signature); + + assert_eq!(result, VerifyResult::Invalid); + } + + #[test] + fn test_deterministic() { + let secret_key = [0x42u8; 32]; + let message = b"test message"; + + let sig1 = sign_message(&secret_key, message); + let sig2 = sign_message(&secret_key, message); + + assert_eq!(sig1, sig2); + } +} diff --git a/examples/batched-signatures/src/main.rs b/examples/batched-signatures/src/main.rs new file mode 100644 index 0000000..2b9fcad --- /dev/null +++ b/examples/batched-signatures/src/main.rs @@ -0,0 +1,96 @@ +//! Batched signature verification demo. +//! +//! 
Demonstrates verifying multiple signatures using ZeroOS.
+//! Each signature verification can be handled by a separate thread.
+
+#![cfg_attr(target_os = "none", no_std)]
+#![no_main]
+
+use batched_signatures::{
+    batch_verify, derive_public_key, sign_message, verify_signature,
+    PublicKey, VerifyResult,
+};
+
+cfg_if::cfg_if! {
+    if #[cfg(target_os = "none")] {
+        use platform::println;
+    } else {
+        use std::println;
+    }
+}
+
+/// Number of signatures to verify in the batch
+const BATCH_SIZE: usize = 8;
+
+/// Generate test keypairs deterministically
+fn generate_test_keypair(seed: u8) -> ([u8; 32], PublicKey) {
+    let mut secret = [0u8; 32];
+    for i in 0..32 {
+        secret[i] = seed.wrapping_add(i as u8).wrapping_mul(17);
+    }
+    let public = derive_public_key(&secret);
+    (secret, public)
+}
+
+#[unsafe(no_mangle)]
+fn main() -> ! {
+    debug::writeln!("[batched-signatures] Starting signature verification demo");
+
+    // Generate test data
+    let mut public_keys = [[0u8; 32]; BATCH_SIZE];
+    let mut signatures = [[0u8; 64]; BATCH_SIZE];
+    let mut results = [VerifyResult::Invalid; BATCH_SIZE];
+
+    // Messages to sign/verify
+    let messages: [&[u8]; BATCH_SIZE] = [
+        b"transaction_0_transfer_100",
+        b"transaction_1_approve_token",
+        b"transaction_2_stake_amount",
+        b"transaction_3_withdraw_eth",
+        b"transaction_4_swap_tokens",
+        b"transaction_5_add_liquidity",
+        b"transaction_6_vote_proposal",
+        b"transaction_7_claim_reward",
+    ];
+
+    // Generate keypairs and sign messages
+    debug::writeln!("[batched-signatures] Generating {} signatures", BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        let (secret, public) = generate_test_keypair(i as u8);
+        public_keys[i] = public;
+        signatures[i] = sign_message(&secret, messages[i]);
+    }
+
+    // Batch verify all signatures
+    debug::writeln!("[batched-signatures] Verifying signatures...");
+    batch_verify(&public_keys, &messages, &signatures, &mut results);
+
+    // Report results
+    let mut valid_count = 0;
+    for (i, result) in results.iter().enumerate() {
+        let status = match result {
+            VerifyResult::Valid => {
+                valid_count += 1;
+                "VALID"
+            }
+            VerifyResult::Invalid => "INVALID",
+        };
+        println!("sig[{}] = {}", i, status);
+    }
+
+    println!("Verified {}/{} signatures as valid", valid_count, BATCH_SIZE);
+
+    // Test with an invalid signature (wrong message)
+    debug::writeln!("[batched-signatures] Testing invalid signature detection...");
+    let invalid_result = verify_signature(&public_keys[0], b"wrong_message", &signatures[0]);
+    if invalid_result == VerifyResult::Invalid {
+        println!("Invalid signature correctly rejected");
+    } else {
+        println!("ERROR: Invalid signature was accepted!");
+    }
+
+    debug::writeln!("[batched-signatures] Demo complete!");
+    platform::exit(0)
+}
diff --git a/examples/block-compression/Cargo.toml b/examples/block-compression/Cargo.toml
new file mode 100644
index 0000000..b2bc4a1
--- /dev/null
+++ b/examples/block-compression/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "block-compression"
+publish = false
+version.workspace = true
+edition.workspace = true
+description = "Parallel block-wise LZ/RLE compression"
+
+[dependencies]
+platform.workspace = true
+debug.workspace = true
+cfg-if.workspace = true
+
+[features]
+default = ["with-spike"]
+
+debug = ["platform/debug"]
+
+std = [
+    "platform/std",
+    "platform/vfs-device-console",
+    "platform/memory",
+    "platform/bounds-checks",
+]
+
+with-spike = ["platform/with-spike"]
+# with-jolt = ["platform/with-jolt"] #
Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/block-compression/src/lib.rs b/examples/block-compression/src/lib.rs new file mode 100644 index 0000000..61ac9c7 --- /dev/null +++ b/examples/block-compression/src/lib.rs @@ -0,0 +1,406 @@ +//! Parallel Block Compression Implementation +//! +//! Block-wise LZ77 and RLE compression for parallel execution. +//! Each block can be compressed independently. + +#![no_std] + +/// Maximum block size for compression +pub const BLOCK_SIZE: usize = 256; +/// Maximum output size (worst case: slight expansion) +pub const MAX_OUTPUT: usize = BLOCK_SIZE + 64; +/// Maximum match length for LZ77 +pub const MAX_MATCH_LEN: usize = 15; +/// Maximum look-back distance +pub const MAX_DISTANCE: usize = 255; + +/// Compression token types +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Token { + /// Literal byte + Literal(u8), + /// Match reference: (distance back, length) + Match { distance: u8, length: u8 }, +} + +/// RLE-compressed block +#[derive(Clone)] +pub struct RleBlock { + /// Compressed data: (count, value) pairs + pub data: [(u8, u8); MAX_OUTPUT], + /// Number of pairs used + pub len: usize, + /// Original uncompressed size + pub original_size: usize, +} + +impl RleBlock { + pub fn new() -> Self { + Self { + data: [(0, 0); MAX_OUTPUT], + len: 0, + original_size: 0, + } + } + + /// Compression ratio (compressed / original) + pub fn ratio(&self) -> f32 { + if self.original_size == 0 { + return 1.0; + } + (self.len * 2) as f32 / self.original_size as f32 + } +} + +impl Default for RleBlock { + fn default() -> Self { + Self::new() + } +} + +/// LZ77-compressed block +#[derive(Clone)] +pub struct Lz77Block { + /// Compressed tokens + pub tokens: [Token; MAX_OUTPUT], + /// Number of tokens + pub len: usize, + /// Original size + pub original_size: usize, +} + +impl Lz77Block { + pub fn new() -> Self { + Self { + tokens: [Token::Literal(0); MAX_OUTPUT], + len: 0, + original_size: 0, + } + } + + /// Approximate compressed size in bytes + pub fn compressed_size(&self) -> usize { + let mut size = 0; + for token in &self.tokens[..self.len] { + match token { + Token::Literal(_) => size += 2, // flag + byte + Token::Match { .. } => size += 2, // distance + length + } + } + size + } +} + +impl Default for Lz77Block { + fn default() -> Self { + Self::new() + } +} + +/// Run-Length Encoding compression. +/// Simple but effective for data with repeated values. +pub fn rle_compress(input: &[u8]) -> RleBlock { + let mut result = RleBlock::new(); + result.original_size = input.len(); + + if input.is_empty() { + return result; + } + + let mut i = 0; + while i < input.len() { + let value = input[i]; + let mut count = 1u8; + + // Count consecutive identical bytes + while (i + count as usize) < input.len() + && input[i + count as usize] == value + && count < 255 + { + count += 1; + } + + result.data[result.len] = (count, value); + result.len += 1; + i += count as usize; + } + + result +} + +/// RLE decompression. +pub fn rle_decompress(compressed: &RleBlock, output: &mut [u8]) -> usize { + let mut pos = 0; + + for i in 0..compressed.len { + let (count, value) = compressed.data[i]; + for _ in 0..count { + if pos < output.len() { + output[pos] = value; + pos += 1; + } + } + } + + pos +} + +/// Find longest match in sliding window for LZ77. 
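+/// Returns `Some((distance, length))` for the best match of at least 3 bytes
+/// within the last `MAX_DISTANCE` positions, or `None`. Matches may overlap
+/// the current position, which is what lets long runs compress well;
+/// e.g. in `[1, 2, 3, 1, 2, 3]` at `pos = 3` this finds `(distance = 3, length = 3)`.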
+fn find_match(data: &[u8], pos: usize, window_start: usize) -> Option<(u8, u8)> { + if pos >= data.len() { + return None; + } + + let mut best_distance = 0u8; + let mut best_length = 0u8; + + // Search backwards in window + let search_start = if pos > MAX_DISTANCE { + pos - MAX_DISTANCE + } else { + window_start + }; + + for match_pos in search_start..pos { + let mut length = 0usize; + + // Count matching bytes + while pos + length < data.len() + && data[match_pos + length] == data[pos + length] + && length < MAX_MATCH_LEN + { + length += 1; + } + + if length > best_length as usize && length >= 3 { + best_distance = (pos - match_pos) as u8; + best_length = length as u8; + } + } + + if best_length >= 3 { + Some((best_distance, best_length)) + } else { + None + } +} + +/// LZ77 compression with sliding window. +pub fn lz77_compress(input: &[u8]) -> Lz77Block { + let mut result = Lz77Block::new(); + result.original_size = input.len(); + + let mut pos = 0; + + while pos < input.len() { + if let Some((distance, length)) = find_match(input, pos, 0) { + result.tokens[result.len] = Token::Match { distance, length }; + pos += length as usize; + } else { + result.tokens[result.len] = Token::Literal(input[pos]); + pos += 1; + } + result.len += 1; + } + + result +} + +/// LZ77 decompression. +pub fn lz77_decompress(compressed: &Lz77Block, output: &mut [u8]) -> usize { + let mut pos = 0; + + for i in 0..compressed.len { + match compressed.tokens[i] { + Token::Literal(byte) => { + if pos < output.len() { + output[pos] = byte; + pos += 1; + } + } + Token::Match { distance, length } => { + let start = pos - distance as usize; + for j in 0..length as usize { + if pos < output.len() { + output[pos] = output[start + j]; + pos += 1; + } + } + } + } + } + + pos +} + +/// Block-based compression (parallel-friendly). +/// Each block is compressed independently. +pub fn compress_blocks( + input: &[u8], + block_size: usize, + rle_blocks: &mut [RleBlock], + lz77_blocks: &mut [Lz77Block], +) -> usize { + let num_blocks = (input.len() + block_size - 1) / block_size; + + // Each block can be compressed independently by a different thread + for i in 0..num_blocks { + let start = i * block_size; + let end = core::cmp::min(start + block_size, input.len()); + let block = &input[start..end]; + + rle_blocks[i] = rle_compress(block); + lz77_blocks[i] = lz77_compress(block); + } + + num_blocks +} + +/// Simple byte histogram (useful for entropy estimation). +pub fn histogram(data: &[u8], hist: &mut [u32; 256]) { + for h in hist.iter_mut() { + *h = 0; + } + for &byte in data { + hist[byte as usize] += 1; + } +} + +/// Estimate entropy (bits per byte) from histogram. +/// Lower entropy = more compressible. +pub fn estimate_entropy(hist: &[u32; 256], total: usize) -> u32 { + if total == 0 { + return 0; + } + + // Fixed-point entropy calculation (scaled by 1000) + let mut entropy: u64 = 0; + let total_u64 = total as u64; + + for &count in hist.iter() { + if count > 0 { + let p = (count as u64 * 1000) / total_u64; + if p > 0 { + // Approximate -p*log2(p) using lookup or approximation + // log2(p/1000) ≈ log2(p) - 10 + // Simple approximation: entropy contribution ~ p * (10 - log2(p)) + let log_approx = 10 - (64 - p.leading_zeros()) as u64; + entropy += p * log_approx; + } + } + } + + (entropy / 1000) as u32 +} + +/// Delta encoding (for sequences with gradual changes). +/// Encodes differences between consecutive values. 
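+///
+/// A quick worked example, matching the unit test below (wrapping
+/// arithmetic, so -1 becomes 255):
+///
+/// ```ignore
+/// let mut out = [0u8; 5];
+/// delta_encode(&[10, 12, 15, 14, 16], &mut out);
+/// assert_eq!(out, [10, 2, 3, 255, 2]);
+/// ```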
+pub fn delta_encode(input: &[u8], output: &mut [u8]) -> usize { + if input.is_empty() { + return 0; + } + + output[0] = input[0]; + for i in 1..input.len() { + output[i] = input[i].wrapping_sub(input[i - 1]); + } + input.len() +} + +/// Delta decoding. +pub fn delta_decode(input: &[u8], output: &mut [u8]) -> usize { + if input.is_empty() { + return 0; + } + + output[0] = input[0]; + for i in 1..input.len() { + output[i] = output[i - 1].wrapping_add(input[i]); + } + input.len() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rle_simple() { + let input = [1, 1, 1, 1, 2, 2, 3]; + let compressed = rle_compress(&input); + + assert_eq!(compressed.len, 3); + assert_eq!(compressed.data[0], (4, 1)); + assert_eq!(compressed.data[1], (2, 2)); + assert_eq!(compressed.data[2], (1, 3)); + } + + #[test] + fn test_rle_roundtrip() { + let input = [5, 5, 5, 10, 10, 15, 15, 15, 15]; + let compressed = rle_compress(&input); + + let mut output = [0u8; 16]; + let len = rle_decompress(&compressed, &mut output); + + assert_eq!(&output[..len], &input); + } + + #[test] + fn test_lz77_literal() { + // No repeats = all literals + let input = [1, 2, 3, 4, 5]; + let compressed = lz77_compress(&input); + + assert_eq!(compressed.len, 5); + for (i, token) in compressed.tokens[..5].iter().enumerate() { + assert_eq!(*token, Token::Literal(input[i])); + } + } + + #[test] + fn test_lz77_match() { + // Repeated pattern should create matches + let input = [1, 2, 3, 1, 2, 3, 1, 2, 3]; + let compressed = lz77_compress(&input); + + // First 3 bytes are literals, then matches + assert!(compressed.len < input.len()); + } + + #[test] + fn test_lz77_roundtrip() { + let input = [1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 5, 6, 5, 6]; + let compressed = lz77_compress(&input); + + let mut output = [0u8; 32]; + let len = lz77_decompress(&compressed, &mut output); + + assert_eq!(&output[..len], &input); + } + + #[test] + fn test_delta_encoding() { + let input = [10, 12, 15, 14, 16]; + let mut encoded = [0u8; 5]; + let mut decoded = [0u8; 5]; + + delta_encode(&input, &mut encoded); + assert_eq!(encoded, [10, 2, 3, 255, 2]); // 255 = -1 as u8 + + delta_decode(&encoded, &mut decoded); + assert_eq!(decoded, input); + } + + #[test] + fn test_histogram() { + let data = [1, 1, 1, 2, 2, 3]; + let mut hist = [0u32; 256]; + histogram(&data, &mut hist); + + assert_eq!(hist[1], 3); + assert_eq!(hist[2], 2); + assert_eq!(hist[3], 1); + assert_eq!(hist[0], 0); + } +} diff --git a/examples/block-compression/src/main.rs b/examples/block-compression/src/main.rs new file mode 100644 index 0000000..e1f36cc --- /dev/null +++ b/examples/block-compression/src/main.rs @@ -0,0 +1,193 @@ +//! Block Compression Example +//! +//! Demonstrates parallel block-wise compression. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use block_compression::{ + compress_blocks, delta_decode, delta_encode, estimate_entropy, histogram, + lz77_compress, lz77_decompress, rle_compress, rle_decompress, + Lz77Block, RleBlock, BLOCK_SIZE, +}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +#[unsafe(no_mangle)] +fn main() -> ! 
{ + println!("=== Block Compression Example ==="); + + // Test 1: RLE with repetitive data + println!("\nTest 1: RLE Compression"); + let repetitive: [u8; 32] = [ + 1, 1, 1, 1, 1, 1, 1, 1, // 8 ones + 2, 2, 2, 2, // 4 twos + 3, 3, 3, 3, 3, 3, // 6 threes + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // 14 fours + ]; + + let rle = rle_compress(&repetitive); + println!(" Original size: {} bytes", repetitive.len()); + println!(" RLE pairs: {}", rle.len); + println!(" Compressed size: {} bytes (pairs * 2)", rle.len * 2); + + // Verify roundtrip + let mut rle_output = [0u8; 64]; + let rle_len = rle_decompress(&rle, &mut rle_output); + let rle_match = rle_output[..rle_len] == repetitive; + println!(" Roundtrip: {}", if rle_match { "PASS" } else { "FAIL" }); + + // Test 2: RLE with non-repetitive data + println!("\nTest 2: RLE on Random Data"); + let random: [u8; 16] = [1, 5, 2, 8, 3, 7, 4, 6, 9, 0, 1, 2, 3, 4, 5, 6]; + + let rle_random = rle_compress(&random); + println!(" Original: {} bytes", random.len()); + println!(" RLE pairs: {} (expansion!)", rle_random.len); + println!(" RLE not effective for random data"); + + // Test 3: LZ77 with repeating pattern + println!("\nTest 3: LZ77 Compression"); + let pattern: [u8; 24] = [ + 1, 2, 3, 4, 5, 6, + 1, 2, 3, 4, 5, 6, // repeat + 1, 2, 3, 4, 5, 6, // repeat + 7, 8, 9, 10, 11, 12, + ]; + + let lz77 = lz77_compress(&pattern); + println!(" Original size: {} bytes", pattern.len()); + println!(" LZ77 tokens: {}", lz77.len); + println!(" Approx compressed: {} bytes", lz77.compressed_size()); + + // Count match vs literal tokens + let mut matches = 0; + let mut literals = 0; + for token in &lz77.tokens[..lz77.len] { + match token { + block_compression::Token::Match { .. } => matches += 1, + block_compression::Token::Literal(_) => literals += 1, + } + } + println!(" Literals: {}, Matches: {}", literals, matches); + + // Verify roundtrip + let mut lz77_output = [0u8; 64]; + let lz77_len = lz77_decompress(&lz77, &mut lz77_output); + let lz77_match = lz77_output[..lz77_len] == pattern; + println!(" Roundtrip: {}", if lz77_match { "PASS" } else { "FAIL" }); + + // Test 4: Delta encoding for gradual data + println!("\nTest 4: Delta Encoding"); + let gradual: [u8; 8] = [100, 102, 105, 107, 110, 108, 112, 115]; + let mut delta_encoded = [0u8; 8]; + let mut delta_decoded = [0u8; 8]; + + delta_encode(&gradual, &mut delta_encoded); + println!(" Original: {:?}", &gradual[..4]); + println!(" Delta encoded: {:?}", &delta_encoded[..4]); + + // Delta encoding reduces magnitude, making RLE/LZ77 more effective + let delta_rle = rle_compress(&delta_encoded); + println!(" RLE after delta: {} pairs", delta_rle.len); + + delta_decode(&delta_encoded, &mut delta_decoded); + let delta_match = delta_decoded == gradual; + println!(" Roundtrip: {}", if delta_match { "PASS" } else { "FAIL" }); + + // Test 5: Histogram and entropy + println!("\nTest 5: Entropy Analysis"); + + // Low entropy (repetitive) + let mut hist = [0u32; 256]; + histogram(&repetitive, &mut hist); + let entropy_rep = estimate_entropy(&hist, repetitive.len()); + println!(" Repetitive data entropy: ~{} bits/byte", entropy_rep); + + // Higher entropy (more varied) + histogram(&random, &mut hist); + let entropy_rand = estimate_entropy(&hist, random.len()); + println!(" Random data entropy: ~{} bits/byte", entropy_rand); + + // Test 6: Block-based compression (parallel-friendly) + println!("\nTest 6: Block-Based Compression"); + + // Create data with different characteristics per block + let mut multiblock = [0u8; 128]; 
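+    // Four 32-byte blocks with deliberately different characteristics, so
+    // each block's compressibility differs (see the per-block results below).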
+ // Block 0: repetitive + for i in 0..32 { + multiblock[i] = 5; + } + // Block 1: pattern + for i in 0..32 { + multiblock[32 + i] = (i % 4) as u8; + } + // Block 2: gradual + for i in 0..32 { + multiblock[64 + i] = i as u8; + } + // Block 3: mixed + for i in 0..32 { + multiblock[96 + i] = ((i * 7) % 256) as u8; + } + + let mut rle_blocks = [RleBlock::new(), RleBlock::new(), RleBlock::new(), RleBlock::new()]; + let mut lz77_blocks = [Lz77Block::new(), Lz77Block::new(), Lz77Block::new(), Lz77Block::new()]; + + let num_blocks = compress_blocks(&multiblock, 32, &mut rle_blocks, &mut lz77_blocks); + println!(" Total blocks: {}", num_blocks); + + for i in 0..num_blocks { + println!(" Block {}: RLE={} pairs, LZ77={} tokens", + i, rle_blocks[i].len, lz77_blocks[i].len); + } + + // Test 7: Compression comparison + println!("\nTest 7: Algorithm Comparison"); + + // Test different data patterns + let patterns: [(&str, [u8; 16]); 4] = [ + ("All zeros", [0; 16]), + ("Ascending", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + ("Repeating", [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]), + ("Mixed", [1, 1, 1, 2, 3, 4, 4, 4, 5, 6, 6, 7, 8, 8, 8, 8]), + ]; + + for (name, data) in &patterns { + let rle = rle_compress(data); + let lz77 = lz77_compress(data); + println!(" {}: RLE={} pairs, LZ77={} tokens", + name, rle.len, lz77.len); + } + + // Test 8: Verify block independence + println!("\nTest 8: Block Independence Verification"); + let block_a: [u8; 8] = [1, 1, 1, 1, 2, 2, 2, 2]; + let block_b: [u8; 8] = [3, 3, 3, 3, 4, 4, 4, 4]; + + // Compress separately + let rle_a = rle_compress(&block_a); + let rle_b = rle_compress(&block_b); + + // Compress together + let mut combined = [0u8; 16]; + combined[..8].copy_from_slice(&block_a); + combined[8..].copy_from_slice(&block_b); + + let mut rle_combined = [RleBlock::new(), RleBlock::new()]; + let mut lz_combined = [Lz77Block::new(), Lz77Block::new()]; + compress_blocks(&combined, 8, &mut rle_combined, &mut lz_combined); + + let independent = rle_a.len == rle_combined[0].len && rle_b.len == rle_combined[1].len; + println!(" Blocks compressed independently: {}", if independent { "PASS" } else { "FAIL" }); + + println!("\n=== Block Compression Example Complete ==="); + + platform::exit(0) +} diff --git a/examples/fft/Cargo.toml b/examples/fft/Cargo.toml new file mode 100644 index 0000000..887c576 --- /dev/null +++ b/examples/fft/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "fft" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel FFT with stage partitioning (Cooley-Tukey)" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/fft/src/lib.rs b/examples/fft/src/lib.rs new file mode 100644 index 0000000..fc8600d --- /dev/null +++ b/examples/fft/src/lib.rs @@ -0,0 +1,325 @@ +//! Parallel FFT Implementation (Cooley-Tukey) +//! +//! Stage-partitioned FFT for parallel execution. +//! Each stage operates on independent butterfly pairs. + +#![no_std] + +/// Complex number representation using fixed-point arithmetic. 
+/// Uses Q16.16 format for deterministic computation.
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub struct Complex {
+    pub re: i32, // Q16.16 fixed-point
+    pub im: i32, // Q16.16 fixed-point
+}
+
+impl Complex {
+    pub const SCALE: i32 = 1 << 16;
+
+    pub const fn new(re: i32, im: i32) -> Self {
+        Self { re, im }
+    }
+
+    pub const fn from_int(re: i32) -> Self {
+        Self {
+            re: re << 16,
+            im: 0,
+        }
+    }
+
+    /// Fixed-point multiplication
+    pub fn mul(self, other: Self) -> Self {
+        let re = ((self.re as i64 * other.re as i64) >> 16)
+            - ((self.im as i64 * other.im as i64) >> 16);
+        let im = ((self.re as i64 * other.im as i64) >> 16)
+            + ((self.im as i64 * other.re as i64) >> 16);
+        Self {
+            re: re as i32,
+            im: im as i32,
+        }
+    }
+
+    pub fn add(self, other: Self) -> Self {
+        Self {
+            re: self.re + other.re,
+            im: self.im + other.im,
+        }
+    }
+
+    pub fn sub(self, other: Self) -> Self {
+        Self {
+            re: self.re - other.re,
+            im: self.im - other.im,
+        }
+    }
+
+    /// Approximate magnitude squared (for validation)
+    pub fn mag_squared(self) -> i64 {
+        (self.re as i64 * self.re as i64 + self.im as i64 * self.im as i64) >> 16
+    }
+}
+
+/// Pre-computed twiddle factors for FFT stages.
+/// For N-point FFT: W_N^k = exp(-2πik/N)
+pub struct TwiddleTable<const N: usize> {
+    pub factors: [Complex; N],
+}
+
+impl<const N: usize> TwiddleTable<N> {
+    /// Create twiddle factor table using integer approximation.
+    /// Uses pre-computed sine table for determinism.
+    pub fn new() -> Self {
+        let mut factors = [Complex::new(0, 0); N];
+
+        // Pre-computed sin/cos values for common angles
+        // sin(2πk/N) approximated using Taylor series or lookup
+        for k in 0..N {
+            // For small N, use lookup table approach
+            // cos(2πk/N) and -sin(2πk/N) in Q16.16
+            let (cos_val, sin_val) = trig_lookup(k, N);
+            factors[k] = Complex::new(cos_val, -sin_val);
+        }
+
+        Self { factors }
+    }
+
+    pub fn get(&self, k: usize) -> Complex {
+        self.factors[k % N]
+    }
+}
+
+impl<const N: usize> Default for TwiddleTable<N> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Integer trigonometry lookup using CORDIC-style approximation.
+/// Returns (cos, sin) in Q16.16 format.
+fn trig_lookup(k: usize, n: usize) -> (i32, i32) {
+    if n == 0 {
+        return (Complex::SCALE, 0);
+    }
+
+    // Normalize angle to [0, 4) representing quadrants
+    let angle_frac = (k * 4) / n;
+    let sub_angle = (k * 4) % n;
+
+    // Base sin/cos for quarter rotation
+    // Using simple linear interpolation for small FFTs
+    let t = (sub_angle as i64 * Complex::SCALE as i64) / n as i64;
+    let t = t as i32;
+
+    // Approximate cos and sin for angle in [0, π/2]
+    // cos(x) ≈ 1 - x²/2, sin(x) ≈ x (for small x)
+    // For better accuracy, use quadrant-aware computation
+    let (base_cos, base_sin) = match angle_frac % 4 {
+        0 => {
+            // First quadrant: angle = t * π/2
+            let c = Complex::SCALE - ((t as i64 * t as i64) >> 18) as i32;
+            let s = (t as i64 * 102944 >> 16) as i32; // π/2 * scale ≈ 102944
+            (c, s)
+        }
+        1 => {
+            // Second quadrant
+            let c = -((t as i64 * 102944 >> 16) as i32);
+            let s = Complex::SCALE - ((t as i64 * t as i64) >> 18) as i32;
+            (c, s)
+        }
+        2 => {
+            // Third quadrant
+            let c = -(Complex::SCALE - ((t as i64 * t as i64) >> 18) as i32);
+            let s = -((t as i64 * 102944 >> 16) as i32);
+            (c, s)
+        }
+        3 => {
+            // Fourth quadrant
+            let c = (t as i64 * 102944 >> 16) as i32;
+            let s = -(Complex::SCALE - ((t as i64 * t as i64) >> 18) as i32);
+            (c, s)
+        }
+        _ => unreachable!(),
+    };
+
+    (base_cos, base_sin)
+}
+
+/// Bit-reverse permutation index for FFT input reordering.
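+/// e.g. with 3 bits, index 1 (0b001) maps to 4 (0b100), mirroring the unit
+/// tests below:
+///
+/// ```ignore
+/// assert_eq!(bit_reverse(0b001, 3), 0b100);
+/// assert_eq!(bit_reverse(0b110, 3), 0b011);
+/// ```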
+pub fn bit_reverse(mut x: usize, bits: u32) -> usize { + let mut result = 0; + for _ in 0..bits { + result = (result << 1) | (x & 1); + x >>= 1; + } + result +} + +/// In-place bit-reversal permutation of the input array. +pub fn bit_reverse_permute(data: &mut [Complex]) { + let n = data.len(); + // Use integer log2 via trailing_zeros (n is power of 2) + let bits = n.trailing_zeros(); + + for i in 0..n { + let j = bit_reverse(i, bits); + if i < j { + data.swap(i, j); + } + } +} + +/// Single butterfly operation: the core FFT building block. +#[inline] +pub fn butterfly(a: &mut Complex, b: &mut Complex, twiddle: Complex) { + let t = twiddle.mul(*b); + let new_a = a.add(t); + let new_b = a.sub(t); + *a = new_a; + *b = new_b; +} + +/// FFT stage computation. +/// Each stage processes butterflies with a specific stride. +/// This is the parallelizable unit - different groups within a stage are independent. +pub fn fft_stage(data: &mut [Complex], stage: u32, twiddles: &[Complex]) { + let n = data.len(); + let butterflies_per_group = 1 << stage; + let group_size = butterflies_per_group * 2; + let num_groups = n / group_size; + + // Each group can be processed independently (parallel-friendly) + for group in 0..num_groups { + let group_start = group * group_size; + + for k in 0..butterflies_per_group { + let i = group_start + k; + let j = i + butterflies_per_group; + + // Twiddle factor index: k * (N / group_size) + let twiddle_idx = k * (n / group_size); + let twiddle = twiddles[twiddle_idx % twiddles.len()]; + + // Use split_at_mut to get non-overlapping mutable references + let (left, right) = data.split_at_mut(j); + butterfly(&mut left[i], &mut right[0], twiddle); + } + } +} + +/// Complete FFT computation using Cooley-Tukey decimation-in-time. +/// Parallel-friendly: each stage can be partitioned across workers. +pub fn fft(data: &mut [Complex], twiddles: &[Complex]) { + let n = data.len(); + assert!(n.is_power_of_two(), "FFT size must be power of 2"); + + // Use integer log2 via trailing_zeros (n is power of 2) + let num_stages = n.trailing_zeros(); + + // Step 1: Bit-reverse permutation + bit_reverse_permute(data); + + // Step 2: Process each stage + // TODO: With threading, parallelize within each stage + for stage in 0..num_stages { + fft_stage(data, stage, twiddles); + } +} + +/// Inverse FFT (IFFT) - conjugate input, FFT, conjugate output, scale. +pub fn ifft(data: &mut [Complex], twiddles: &[Complex]) { + let n = data.len(); + + // Conjugate input + for x in data.iter_mut() { + x.im = -x.im; + } + + // Forward FFT + fft(data, twiddles); + + // Conjugate and scale output + for x in data.iter_mut() { + x.im = -x.im; + x.re /= n as i32; + x.im /= n as i32; + } +} + +/// Batch FFT processing - multiple independent FFTs. +/// Each FFT is completely independent (embarrassingly parallel). 
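+///
+/// A minimal usage sketch (mirrors the demo in `main.rs`):
+///
+/// ```ignore
+/// let twiddles = TwiddleTable::<8>::new();
+/// let mut batches = [[Complex::from_int(1); 8]; 4];
+/// batch_fft(&mut batches, &twiddles);
+/// // Each batch's DC bin now holds the sum of its 8 inputs.
+/// assert_eq!(batches[0][0].re >> 16, 8);
+/// ```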
+pub fn batch_fft<const N: usize>(
+    batches: &mut [[Complex; N]],
+    twiddles: &TwiddleTable<N>,
+) {
+    // Each batch can be processed by a different thread
+    for batch in batches.iter_mut() {
+        fft(batch, &twiddles.factors);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_complex_arithmetic() {
+        let a = Complex::from_int(3);
+        let b = Complex::from_int(4);
+
+        let sum = a.add(b);
+        assert_eq!(sum.re, 7 << 16);
+        assert_eq!(sum.im, 0);
+
+        let diff = a.sub(b);
+        assert_eq!(diff.re, -1 << 16);
+    }
+
+    #[test]
+    fn test_bit_reverse() {
+        assert_eq!(bit_reverse(0b000, 3), 0b000);
+        assert_eq!(bit_reverse(0b001, 3), 0b100);
+        assert_eq!(bit_reverse(0b010, 3), 0b010);
+        assert_eq!(bit_reverse(0b011, 3), 0b110);
+        assert_eq!(bit_reverse(0b100, 3), 0b001);
+    }
+
+    #[test]
+    fn test_small_fft() {
+        // 4-point FFT of [1, 1, 1, 1] should give [4, 0, 0, 0]
+        let mut data = [
+            Complex::from_int(1),
+            Complex::from_int(1),
+            Complex::from_int(1),
+            Complex::from_int(1),
+        ];
+
+        let twiddles = TwiddleTable::<4>::new();
+        fft(&mut data, &twiddles.factors);
+
+        // First element should be sum = 4
+        assert!((data[0].re - (4 << 16)).abs() < 1000);
+        // Other elements should be near zero
+        assert!(data[1].mag_squared() < 10000);
+        assert!(data[2].mag_squared() < 10000);
+        assert!(data[3].mag_squared() < 10000);
+    }
+
+    #[test]
+    fn test_impulse_response() {
+        // FFT of impulse [1, 0, 0, 0] should give all ones
+        let mut data = [
+            Complex::from_int(1),
+            Complex::from_int(0),
+            Complex::from_int(0),
+            Complex::from_int(0),
+        ];
+
+        let twiddles = TwiddleTable::<4>::new();
+        fft(&mut data, &twiddles.factors);
+
+        // All elements should be approximately 1
+        for c in &data {
+            assert!((c.re - Complex::SCALE).abs() < 5000);
+        }
+    }
+}
diff --git a/examples/fft/src/main.rs b/examples/fft/src/main.rs
new file mode 100644
index 0000000..4914722
--- /dev/null
+++ b/examples/fft/src/main.rs
@@ -0,0 +1,135 @@
+//! FFT Example - Parallel Fast Fourier Transform
+//!
+//! Demonstrates stage-partitioned Cooley-Tukey FFT.
+
+#![cfg_attr(target_os = "none", no_std)]
+#![no_main]
+
+use fft::{batch_fft, Complex, TwiddleTable};
+
+cfg_if::cfg_if! {
+    if #[cfg(target_os = "none")] {
+        use platform::println;
+    } else {
+        use std::println;
+    }
+}
+
+#[unsafe(no_mangle)]
+fn main() -> !
{ + println!("=== FFT Example ==="); + + // Create twiddle factor table for 16-point FFT + let twiddles = TwiddleTable::<16>::new(); + + // Test 1: Simple DC signal (all ones) + println!("\nTest 1: DC Signal [1,1,1,...,1]"); + let mut dc_signal: [Complex; 16] = [Complex::from_int(1); 16]; + + fft::fft(&mut dc_signal, &twiddles.factors); + + println!(" Result[0] (DC): re={}", dc_signal[0].re >> 16); + println!(" Expected: 16 (sum of inputs)"); + + // Test 2: Impulse signal + println!("\nTest 2: Impulse Signal [1,0,0,...,0]"); + let mut impulse: [Complex; 16] = [Complex::from_int(0); 16]; + impulse[0] = Complex::from_int(1); + + fft::fft(&mut impulse, &twiddles.factors); + + println!(" All bins should be ~1 (flat spectrum)"); + let mut all_ones = true; + for (i, c) in impulse.iter().enumerate() { + let mag = (c.re >> 16).abs(); + if mag < 1 { + all_ones = false; + } + if i < 4 { + println!(" Result[{}]: re={}", i, c.re >> 16); + } + } + println!(" Flat spectrum: {}", if all_ones { "PASS" } else { "FAIL" }); + + // Test 3: Simple sinusoid (alternating pattern = Nyquist) + println!("\nTest 3: Nyquist Signal [1,-1,1,-1,...]"); + let mut nyquist: [Complex; 16] = [Complex::from_int(0); 16]; + for i in 0..16 { + nyquist[i] = if i % 2 == 0 { + Complex::from_int(1) + } else { + Complex::from_int(-1) + }; + } + + fft::fft(&mut nyquist, &twiddles.factors); + + println!(" Result[0] (DC): {}", nyquist[0].re >> 16); + println!(" Result[8] (Nyquist): {}", nyquist[8].re >> 16); + println!(" Expected: DC=0, Nyquist=16"); + + // Test 4: Batch FFT (parallel-friendly) + println!("\nTest 4: Batch FFT (4 independent transforms)"); + + let twiddles_8 = TwiddleTable::<8>::new(); + let mut batches: [[Complex; 8]; 4] = [[Complex::from_int(0); 8]; 4]; + + // Initialize each batch differently + for (batch_idx, batch) in batches.iter_mut().enumerate() { + for (i, c) in batch.iter_mut().enumerate() { + *c = Complex::from_int(((batch_idx + 1) * (i + 1)) as i32); + } + } + + batch_fft(&mut batches, &twiddles_8); + + for (batch_idx, batch) in batches.iter().enumerate() { + println!( + " Batch {}: DC component = {}", + batch_idx, + batch[0].re >> 16 + ); + } + + // Test 5: Round-trip FFT -> IFFT + println!("\nTest 5: FFT -> IFFT Round Trip"); + let original = [ + Complex::from_int(1), + Complex::from_int(2), + Complex::from_int(3), + Complex::from_int(4), + Complex::from_int(5), + Complex::from_int(6), + Complex::from_int(7), + Complex::from_int(8), + ]; + + let mut data = original; + let twiddles_8 = TwiddleTable::<8>::new(); + + fft::fft(&mut data, &twiddles_8.factors); + println!(" After FFT, DC = {}", data[0].re >> 16); + + fft::ifft(&mut data, &twiddles_8.factors); + + println!(" After IFFT:"); + let mut round_trip_ok = true; + for i in 0..8 { + let expected = (i + 1) as i32; + let actual = data[i].re >> 16; + if (actual - expected).abs() > 1 { + round_trip_ok = false; + } + if i < 4 { + println!(" [{}]: expected={}, actual={}", i, expected, actual); + } + } + println!( + " Round-trip: {}", + if round_trip_ok { "PASS" } else { "FAIL" } + ); + + println!("\n=== FFT Example Complete ==="); + + platform::exit(0) +} diff --git a/examples/graph-bfs/Cargo.toml b/examples/graph-bfs/Cargo.toml new file mode 100644 index 0000000..eb34486 --- /dev/null +++ b/examples/graph-bfs/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "graph-bfs" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel graph BFS with frontier slicing" + +[dependencies] +platform.workspace = true +debug.workspace = true 
+cfg-if.workspace = true
+
+[features]
+default = ["with-spike"]
+
+debug = ["platform/debug"]
+
+std = [
+    "platform/std",
+    "platform/vfs-device-console",
+    "platform/memory",
+    "platform/bounds-checks",
+]
+
+with-spike = ["platform/with-spike"]
+# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists
+
+[target.'cfg(target_os = "none")'.dependencies]
+platform = { workspace = true, features = ["memory"] }
diff --git a/examples/graph-bfs/src/lib.rs b/examples/graph-bfs/src/lib.rs
new file mode 100644
index 0000000..af78662
--- /dev/null
+++ b/examples/graph-bfs/src/lib.rs
@@ -0,0 +1,396 @@
+//! Parallel Graph BFS Implementation
+//!
+//! Frontier-sliced BFS for parallel execution.
+//! Each slice of the frontier can be processed independently.
+
+#![no_std]
+
+/// Maximum number of vertices supported
+pub const MAX_VERTICES: usize = 64;
+/// Maximum number of edges supported
+pub const MAX_EDGES: usize = 256;
+
+/// Compact graph representation using adjacency lists.
+/// Edges are stored contiguously, with offsets per vertex.
+pub struct Graph {
+    /// Number of vertices
+    pub num_vertices: usize,
+    /// Number of edges
+    pub num_edges: usize,
+    /// Offset into edges array for each vertex (CSR format)
+    pub offsets: [usize; MAX_VERTICES + 1],
+    /// Edge destinations (packed adjacency lists)
+    pub edges: [usize; MAX_EDGES],
+}
+
+impl Graph {
+    pub const fn new() -> Self {
+        Self {
+            num_vertices: 0,
+            num_edges: 0,
+            offsets: [0; MAX_VERTICES + 1],
+            edges: [0; MAX_EDGES],
+        }
+    }
+
+    /// Add a directed edge from `from` to `to`.
+    /// Edges must be added grouped by source vertex, in ascending order;
+    /// `from_edges` is the more robust bulk-construction API.
+    pub fn add_edge(&mut self, from: usize, to: usize) {
+        assert!(from < MAX_VERTICES && to < MAX_VERTICES);
+        assert!(self.num_edges < MAX_EDGES);
+
+        self.edges[self.num_edges] = to;
+        self.num_edges += 1;
+
+        // Keep the CSR offsets consistent: every vertex after `from`
+        // now starts one edge later.
+        for v in (from + 1)..=MAX_VERTICES {
+            self.offsets[v] = self.num_edges;
+        }
+    }
+
+    /// Get neighbors of a vertex.
+    pub fn neighbors(&self, v: usize) -> &[usize] {
+        let start = self.offsets[v];
+        let end = self.offsets[v + 1];
+        &self.edges[start..end]
+    }
+
+    /// Build graph from edge list (simpler API).
+    pub fn from_edges(num_vertices: usize, edges: &[(usize, usize)]) -> Self {
+        let mut graph = Self::new();
+        graph.num_vertices = num_vertices;
+
+        // Count edges per vertex
+        let mut counts = [0usize; MAX_VERTICES];
+        for &(from, _to) in edges {
+            counts[from] += 1;
+        }
+
+        // Compute offsets (prefix sum)
+        let mut offset = 0;
+        for v in 0..num_vertices {
+            graph.offsets[v] = offset;
+            offset += counts[v];
+        }
+        graph.offsets[num_vertices] = offset;
+        graph.num_edges = offset;
+
+        // Fill in edges (need to track current position per vertex)
+        let mut current = [0usize; MAX_VERTICES];
+        for v in 0..num_vertices {
+            current[v] = graph.offsets[v];
+        }
+
+        for &(from, to) in edges {
+            let pos = current[from];
+            graph.edges[pos] = to;
+            current[from] += 1;
+        }
+
+        graph
+    }
+}
+
+impl Default for Graph {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// BFS state and results.
+pub struct BfsResult {
+    /// Distance from source (-1 if unreachable)
+    pub distance: [i32; MAX_VERTICES],
+    /// Parent in BFS tree (-1 for source or unreachable)
+    pub parent: [i32; MAX_VERTICES],
+    /// Number of vertices reached
+    pub num_reached: usize,
+}
+
+impl BfsResult {
+    pub fn new() -> Self {
+        Self {
+            distance: [-1; MAX_VERTICES],
+            parent: [-1; MAX_VERTICES],
+            num_reached: 0,
+        }
+    }
+}
+
+impl Default for BfsResult {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Frontier representation for BFS.
+/// Double-buffered for level-synchronous BFS.
+pub struct Frontier {
+    /// Current frontier vertices
+    pub current: [usize; MAX_VERTICES],
+    pub current_len: usize,
+    /// Next frontier vertices
+    pub next: [usize; MAX_VERTICES],
+    pub next_len: usize,
+}
+
+impl Frontier {
+    pub fn new() -> Self {
+        Self {
+            current: [0; MAX_VERTICES],
+            current_len: 0,
+            next: [0; MAX_VERTICES],
+            next_len: 0,
+        }
+    }
+
+    pub fn add_to_next(&mut self, v: usize) {
+        self.next[self.next_len] = v;
+        self.next_len += 1;
+    }
+
+    pub fn swap(&mut self) {
+        // Promote next to current (copy, then reset next)
+        for i in 0..self.next_len {
+            self.current[i] = self.next[i];
+        }
+        self.current_len = self.next_len;
+        self.next_len = 0;
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.current_len == 0
+    }
+}
+
+impl Default for Frontier {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Process a slice of the frontier (parallel-friendly).
+/// Each slice can be processed by a different thread.
+/// Returns vertices to add to next frontier.
+/// Output arrays are sized by MAX_EDGES: a level can discover the same
+/// vertex through several edges, so the update count is bounded by the
+/// number of edges scanned, not by the number of vertices.
+pub fn process_frontier_slice(
+    graph: &Graph,
+    frontier_slice: &[usize],
+    visited: &[bool; MAX_VERTICES],
+    current_distance: i32,
+) -> ([usize; MAX_EDGES], usize, [(usize, i32, i32); MAX_EDGES], usize) {
+    let mut new_vertices = [0usize; MAX_EDGES];
+    let mut new_count = 0;
+    let mut updates = [(0usize, 0i32, 0i32); MAX_EDGES]; // (vertex, distance, parent)
+    let mut update_count = 0;
+
+    // Process each vertex in this slice
+    for &v in frontier_slice {
+        // Visit all neighbors
+        for &neighbor in graph.neighbors(v) {
+            if !visited[neighbor] {
+                // Found unvisited vertex
+                new_vertices[new_count] = neighbor;
+                new_count += 1;
+
+                updates[update_count] = (neighbor, current_distance + 1, v as i32);
+                update_count += 1;
+            }
+        }
+    }
+
+    (new_vertices, new_count, updates, update_count)
+}
+
+/// Standard BFS from a source vertex.
+pub fn bfs(graph: &Graph, source: usize) -> BfsResult {
+    let mut result = BfsResult::new();
+    let mut visited = [false; MAX_VERTICES];
+    let mut frontier = Frontier::new();
+
+    // Initialize source
+    result.distance[source] = 0;
+    visited[source] = true;
+    frontier.current[0] = source;
+    frontier.current_len = 1;
+    result.num_reached = 1;
+
+    let mut current_distance = 0;
+
+    // Process level by level
+    while !frontier.is_empty() {
+        // TODO: With threading, partition frontier into slices
+        // Each slice can be processed independently
+        let (new_verts, new_count, updates, update_count) =
+            process_frontier_slice(graph, &frontier.current[..frontier.current_len], &visited, current_distance);
+
+        // Apply updates (this part needs synchronization with threading)
+        for i in 0..update_count {
+            let (v, dist, parent) = updates[i];
+            if !visited[v] {
+                visited[v] = true;
+                result.distance[v] = dist;
+                result.parent[v] = parent;
+                frontier.add_to_next(v);
+                result.num_reached += 1;
+            }
+        }
+
+        // Avoid unused warning
+        let _ = (new_verts, new_count);
+
+        frontier.swap();
+        current_distance += 1;
+    }
+
+    result
+}
+
+/// Multi-source BFS (useful for connected components).
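+/// All sources start at distance 0, so the result gives each vertex's
+/// distance to its nearest source. For example:
+///
+/// ```ignore
+/// // Two components: 0 -> 1 and 2 -> 3; seed one source in each.
+/// let graph = Graph::from_edges(4, &[(0, 1), (2, 3)]);
+/// let result = multi_source_bfs(&graph, &[0, 2]);
+/// assert_eq!(result.num_reached, 4);
+/// ```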
+pub fn multi_source_bfs(graph: &Graph, sources: &[usize]) -> BfsResult { + let mut result = BfsResult::new(); + let mut visited = [false; MAX_VERTICES]; + let mut frontier = Frontier::new(); + + // Initialize all sources + for &source in sources { + result.distance[source] = 0; + visited[source] = true; + frontier.current[frontier.current_len] = source; + frontier.current_len += 1; + result.num_reached += 1; + } + + let mut current_distance = 0; + + while !frontier.is_empty() { + let (_, _, updates, update_count) = + process_frontier_slice(graph, &frontier.current[..frontier.current_len], &visited, current_distance); + + for i in 0..update_count { + let (v, dist, parent) = updates[i]; + if !visited[v] { + visited[v] = true; + result.distance[v] = dist; + result.parent[v] = parent; + frontier.add_to_next(v); + result.num_reached += 1; + } + } + + frontier.swap(); + current_distance += 1; + } + + result +} + +/// Batch BFS from multiple independent sources. +/// Each BFS is completely independent (embarrassingly parallel). +pub fn batch_bfs(graph: &Graph, sources: &[usize], results: &mut [BfsResult]) { + assert_eq!(sources.len(), results.len()); + + // Each BFS can be done by a different thread + for (source, result) in sources.iter().zip(results.iter_mut()) { + *result = bfs(graph, *source); + } +} + +/// Reconstruct path from source to target using BFS result. +pub fn reconstruct_path(result: &BfsResult, target: usize, path: &mut [usize; MAX_VERTICES]) -> usize { + if result.distance[target] < 0 { + return 0; // Unreachable + } + + let mut len = 0; + let mut current = target; + + // Build path backwards + while result.parent[current] >= 0 { + path[len] = current; + len += 1; + current = result.parent[current] as usize; + } + path[len] = current; // Add source + len += 1; + + // Reverse path + for i in 0..len / 2 { + path.swap(i, len - 1 - i); + } + + len +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_graph_construction() { + let edges = [(0, 1), (0, 2), (1, 3), (2, 3)]; + let graph = Graph::from_edges(4, &edges); + + assert_eq!(graph.num_vertices, 4); + assert_eq!(graph.num_edges, 4); + assert_eq!(graph.neighbors(0), &[1, 2]); + assert_eq!(graph.neighbors(1), &[3]); + } + + #[test] + fn test_simple_bfs() { + // Linear graph: 0 -> 1 -> 2 -> 3 + let edges = [(0, 1), (1, 2), (2, 3)]; + let graph = Graph::from_edges(4, &edges); + + let result = bfs(&graph, 0); + + assert_eq!(result.distance[0], 0); + assert_eq!(result.distance[1], 1); + assert_eq!(result.distance[2], 2); + assert_eq!(result.distance[3], 3); + assert_eq!(result.num_reached, 4); + } + + #[test] + fn test_bfs_tree() { + // Tree: 0 + // / \ + // 1 2 + // / \ + // 3 4 + let edges = [(0, 1), (0, 2), (1, 3), (1, 4)]; + let graph = Graph::from_edges(5, &edges); + + let result = bfs(&graph, 0); + + assert_eq!(result.distance[0], 0); + assert_eq!(result.distance[1], 1); + assert_eq!(result.distance[2], 1); + assert_eq!(result.distance[3], 2); + assert_eq!(result.distance[4], 2); + } + + #[test] + fn test_unreachable() { + // Disconnected: 0 -> 1, 2 -> 3 + let edges = [(0, 1), (2, 3)]; + let graph = Graph::from_edges(4, &edges); + + let result = bfs(&graph, 0); + + assert_eq!(result.distance[0], 0); + assert_eq!(result.distance[1], 1); + assert_eq!(result.distance[2], -1); // Unreachable + assert_eq!(result.distance[3], -1); + assert_eq!(result.num_reached, 2); + } + + #[test] + fn test_path_reconstruction() { + let edges = [(0, 1), (1, 2), (2, 3)]; + let graph = Graph::from_edges(4, &edges); + + let result = 
bfs(&graph, 0); + + let mut path = [0usize; MAX_VERTICES]; + let len = reconstruct_path(&result, 3, &mut path); + + assert_eq!(len, 4); + assert_eq!(&path[..len], &[0, 1, 2, 3]); + } +} diff --git a/examples/graph-bfs/src/main.rs b/examples/graph-bfs/src/main.rs new file mode 100644 index 0000000..bb82a9a --- /dev/null +++ b/examples/graph-bfs/src/main.rs @@ -0,0 +1,157 @@ +//! Graph BFS Example +//! +//! Demonstrates frontier-sliced parallel BFS. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use graph_bfs::{batch_bfs, bfs, multi_source_bfs, reconstruct_path, BfsResult, Graph, MAX_VERTICES}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +#[unsafe(no_mangle)] +fn main() -> ! { + println!("=== Graph BFS Example ==="); + + // Test 1: Simple linear graph + println!("\nTest 1: Linear Graph (0->1->2->3->4)"); + let linear_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]; + let linear_graph = Graph::from_edges(5, &linear_edges); + + let result = bfs(&linear_graph, 0); + println!(" Source: 0"); + for i in 0..5 { + println!(" Distance to {}: {}", i, result.distance[i]); + } + println!(" Vertices reached: {}", result.num_reached); + + // Test 2: Binary tree graph + println!("\nTest 2: Binary Tree"); + // 0 + // / \ + // 1 2 + // / \ / \ + // 3 4 5 6 + let tree_edges = [ + (0, 1), (0, 2), + (1, 3), (1, 4), + (2, 5), (2, 6), + ]; + let tree_graph = Graph::from_edges(7, &tree_edges); + + let tree_result = bfs(&tree_graph, 0); + println!(" Level 0 (root): vertex 0, dist={}", tree_result.distance[0]); + println!(" Level 1: vertices 1,2, dist={},{}", tree_result.distance[1], tree_result.distance[2]); + println!(" Level 2: vertices 3-6, dist={},{},{},{}", + tree_result.distance[3], tree_result.distance[4], + tree_result.distance[5], tree_result.distance[6]); + + // Test 3: Path reconstruction + println!("\nTest 3: Path Reconstruction"); + let mut path = [0usize; MAX_VERTICES]; + let path_len = reconstruct_path(&tree_result, 6, &mut path); + // Build path string since we don't have print! 
without newline + println!(" Path from 0 to 6: {} -> {} -> {}", path[0], path[1], path[2]); + + // Test 4: Disconnected graph + println!("\nTest 4: Disconnected Graph"); + // Two components: 0-1-2 and 3-4-5 + let disconnected_edges = [ + (0, 1), (1, 2), + (3, 4), (4, 5), + ]; + let disconnected_graph = Graph::from_edges(6, &disconnected_edges); + + let disc_result = bfs(&disconnected_graph, 0); + println!(" From source 0:"); + println!(" Reachable: {} vertices", disc_result.num_reached); + println!(" Distance to 2: {}", disc_result.distance[2]); + println!(" Distance to 5: {} (unreachable)", disc_result.distance[5]); + + // Test 5: Multi-source BFS + println!("\nTest 5: Multi-Source BFS"); + let sources = [0, 3]; + let multi_result = multi_source_bfs(&disconnected_graph, &sources); + println!(" Sources: 0 and 3"); + println!(" All vertices reached: {}", multi_result.num_reached); + for i in 0..6 { + println!(" Distance to {}: {}", i, multi_result.distance[i]); + } + + // Test 6: Dense graph (complete graph K5) + println!("\nTest 6: Complete Graph K5"); + let mut complete_edges = [(0usize, 0usize); 20]; + let mut idx = 0; + for i in 0..5 { + for j in 0..5 { + if i != j { + complete_edges[idx] = (i, j); + idx += 1; + } + } + } + let complete_graph = Graph::from_edges(5, &complete_edges); + + let complete_result = bfs(&complete_graph, 0); + println!(" From any vertex, all others at distance 1:"); + let all_dist_one = (1..5).all(|i| complete_result.distance[i] == 1); + println!(" Verified: {}", if all_dist_one { "PASS" } else { "FAIL" }); + + // Test 7: Batch BFS (parallel-friendly) + println!("\nTest 7: Batch BFS (independent searches)"); + let cycle_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]; // 5-cycle + let cycle_graph = Graph::from_edges(5, &cycle_edges); + + let batch_sources = [0, 1, 2]; + let mut batch_results = [BfsResult::new(), BfsResult::new(), BfsResult::new()]; + + batch_bfs(&cycle_graph, &batch_sources, &mut batch_results); + + for (i, result) in batch_results.iter().enumerate() { + println!(" BFS from {}: max_dist={}", + batch_sources[i], + result.distance.iter().take(5).max().unwrap_or(&0)); + } + + // Test 8: Larger graph for performance + println!("\nTest 8: Grid Graph 8x8"); + // Create 8x8 grid graph (4-connected) + let mut grid_edges = [(0usize, 0usize); 256]; + let mut edge_count = 0; + for row in 0..8 { + for col in 0..8 { + let v = row * 8 + col; + // Right neighbor + if col < 7 { + grid_edges[edge_count] = (v, v + 1); + edge_count += 1; + } + // Down neighbor + if row < 7 { + grid_edges[edge_count] = (v, v + 8); + edge_count += 1; + } + } + } + let grid_graph = Graph::from_edges(64, &grid_edges[..edge_count]); + + let grid_result = bfs(&grid_graph, 0); + println!(" Source: (0,0), Target: (7,7)"); + println!(" Distance: {} (Manhattan distance)", grid_result.distance[63]); + println!(" Expected: 14 (7 right + 7 down)"); + + // Reconstruct path + let mut grid_path = [0usize; MAX_VERTICES]; + let grid_path_len = reconstruct_path(&grid_result, 63, &mut grid_path); + println!(" Path length: {} vertices", grid_path_len); + + println!("\n=== Graph BFS Example Complete ==="); + + platform::exit(0) +} diff --git a/examples/matrix-multiply/Cargo.toml b/examples/matrix-multiply/Cargo.toml new file mode 100644 index 0000000..e13b6c4 --- /dev/null +++ b/examples/matrix-multiply/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "matrix-multiply" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel matrix multiplication with row/column 
blocks" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/matrix-multiply/src/lib.rs b/examples/matrix-multiply/src/lib.rs new file mode 100644 index 0000000..8932711 --- /dev/null +++ b/examples/matrix-multiply/src/lib.rs @@ -0,0 +1,134 @@ +//! Matrix multiplication implementation with block-parallel structure. +//! +//! Matrices are stored in row-major order. The computation is structured +//! so that row blocks can be computed independently by different threads. + +#![no_std] + +/// Matrix dimension (NxN matrices) +pub const DIM: usize = 16; + +/// Matrix type: DIM x DIM array of i32 +pub type Matrix = [[i32; DIM]; DIM]; + +/// Initialize a matrix with zeros +pub fn zero_matrix() -> Matrix { + [[0i32; DIM]; DIM] +} + +/// Initialize a matrix with deterministic test values +pub fn init_matrix(seed: u32) -> Matrix { + let mut m = zero_matrix(); + let mut s = seed; + for i in 0..DIM { + for j in 0..DIM { + // Simple LCG for deterministic values + s = s.wrapping_mul(1103515245).wrapping_add(12345); + m[i][j] = ((s >> 16) % 100) as i32; + } + } + m +} + +/// Standard matrix multiplication: C = A * B +pub fn matmul(a: &Matrix, b: &Matrix, c: &mut Matrix) { + for i in 0..DIM { + for j in 0..DIM { + let mut sum = 0i32; + for k in 0..DIM { + sum = sum.wrapping_add(a[i][k].wrapping_mul(b[k][j])); + } + c[i][j] = sum; + } + } +} + +/// Compute a single row block of the result matrix. +/// This function computes rows [start_row, end_row) of C = A * B. +/// Can be called independently by different threads. +pub fn matmul_row_block( + a: &Matrix, + b: &Matrix, + c: &mut Matrix, + start_row: usize, + end_row: usize, +) { + let end = core::cmp::min(end_row, DIM); + for i in start_row..end { + for j in 0..DIM { + let mut sum = 0i32; + for k in 0..DIM { + sum = sum.wrapping_add(a[i][k].wrapping_mul(b[k][j])); + } + c[i][j] = sum; + } + } +} + +/// Parallel-friendly block multiplication. +/// Divides the computation into `num_blocks` row blocks. 
+pub fn matmul_blocked(a: &Matrix, b: &Matrix, c: &mut Matrix, num_blocks: usize) { + let rows_per_block = (DIM + num_blocks - 1) / num_blocks; + + for block in 0..num_blocks { + let start_row = block * rows_per_block; + let end_row = core::cmp::min(start_row + rows_per_block, DIM); + matmul_row_block(a, b, c, start_row, end_row); + } +} + +/// Compute checksum of a matrix (for verification) +pub fn matrix_checksum(m: &Matrix) -> i64 { + let mut sum = 0i64; + for i in 0..DIM { + for j in 0..DIM { + sum = sum.wrapping_add(m[i][j] as i64); + } + } + sum +} + +/// Check if two matrices are equal +pub fn matrices_equal(a: &Matrix, b: &Matrix) -> bool { + for i in 0..DIM { + for j in 0..DIM { + if a[i][j] != b[i][j] { + return false; + } + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matmul_identity() { + let mut identity = zero_matrix(); + for i in 0..DIM { + identity[i][i] = 1; + } + + let a = init_matrix(42); + let mut c = zero_matrix(); + matmul(&a, &identity, &mut c); + + assert!(matrices_equal(&a, &c)); + } + + #[test] + fn test_blocked_equals_standard() { + let a = init_matrix(123); + let b = init_matrix(456); + + let mut c_std = zero_matrix(); + let mut c_blk = zero_matrix(); + + matmul(&a, &b, &mut c_std); + matmul_blocked(&a, &b, &mut c_blk, 4); + + assert!(matrices_equal(&c_std, &c_blk)); + } +} diff --git a/examples/matrix-multiply/src/main.rs b/examples/matrix-multiply/src/main.rs new file mode 100644 index 0000000..328ab6f --- /dev/null +++ b/examples/matrix-multiply/src/main.rs @@ -0,0 +1,63 @@ +//! Parallel matrix multiplication demo. +//! +//! Demonstrates block-based matrix multiplication where each row block +//! can be computed independently by a different ZeroOS thread. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use matrix_multiply::{ + init_matrix, matmul, matmul_blocked, matrices_equal, matrix_checksum, zero_matrix, DIM, +}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +/// Number of parallel blocks (each can be a separate thread) +const NUM_BLOCKS: usize = 4; + +#[no_mangle] +fn main() -> ! 
{
+    debug::writeln!("[matrix-multiply] Starting matrix multiplication demo");
+    debug::writeln!("[matrix-multiply] Matrix size: {}x{}, Blocks: {}", DIM, DIM, NUM_BLOCKS);
+
+    // Initialize matrices
+    let a = init_matrix(0x12345678);
+    let b = init_matrix(0xDEADBEEF);
+
+    println!("Matrix A[0][0..4]: [{}, {}, {}, {}]", a[0][0], a[0][1], a[0][2], a[0][3]);
+    println!("Matrix B[0][0..4]: [{}, {}, {}, {}]", b[0][0], b[0][1], b[0][2], b[0][3]);
+
+    // Compute using standard algorithm
+    debug::writeln!("[matrix-multiply] Computing standard matmul...");
+    let mut c_std = zero_matrix();
+    matmul(&a, &b, &mut c_std);
+
+    // Compute using blocked algorithm
+    debug::writeln!("[matrix-multiply] Computing blocked matmul ({} blocks)...", NUM_BLOCKS);
+    let mut c_blk = zero_matrix();
+    matmul_blocked(&a, &b, &mut c_blk, NUM_BLOCKS);
+
+    // Verify results match
+    if matrices_equal(&c_std, &c_blk) {
+        println!("Verification: PASSED (standard == blocked)");
+        debug::writeln!("[matrix-multiply] Verification PASSED");
+    } else {
+        println!("Verification: FAILED (mismatch!)");
+        debug::writeln!("[matrix-multiply] Verification FAILED");
+    }
+
+    // Print result sample and checksum
+    println!("Result C[0][0..4]: [{}, {}, {}, {}]", c_std[0][0], c_std[0][1], c_std[0][2], c_std[0][3]);
+
+    let checksum = matrix_checksum(&c_std);
+    println!("Result checksum: {}", checksum);
+
+    debug::writeln!("[matrix-multiply] Demo complete!");
+    platform::exit(0)
+}
diff --git a/examples/merkle-tree/Cargo.toml b/examples/merkle-tree/Cargo.toml
new file mode 100644
index 0000000..d9e8dfd
--- /dev/null
+++ b/examples/merkle-tree/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "merkle-tree"
+publish = false
+version.workspace = true
+edition.workspace = true
+description = "Parallel Merkle tree construction"
+
+[dependencies]
+platform.workspace = true
+debug.workspace = true
+cfg-if.workspace = true
+
+[features]
+default = ["with-spike"]
+
+debug = ["platform/debug"]
+
+std = [
+    "platform/std",
+    "platform/vfs-device-console",
+    "platform/memory",
+    "platform/bounds-checks",
+]
+
+with-spike = ["platform/with-spike"]
+# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists
+
+[target.'cfg(target_os = "none")'.dependencies]
+platform = { workspace = true, features = ["memory"] }
diff --git a/examples/merkle-tree/src/lib.rs b/examples/merkle-tree/src/lib.rs
new file mode 100644
index 0000000..09511e5
--- /dev/null
+++ b/examples/merkle-tree/src/lib.rs
@@ -0,0 +1,372 @@
+//! Parallel Merkle Tree Construction
+//!
+//! Level-by-level tree construction suitable for parallel execution.
+//! Each level's hash computations are independent.
+
+#![no_std]
+
+/// Maximum tree depth (supports up to 2^16 = 65536 leaves)
+pub const MAX_DEPTH: usize = 16;
+/// Maximum number of leaves
+pub const MAX_LEAVES: usize = 1 << MAX_DEPTH;
+/// Hash output size (256 bits = 32 bytes)
+pub const HASH_SIZE: usize = 32;
+
+/// Simple hash type (32-byte array)
+pub type Hash = [u8; HASH_SIZE];
+
+/// Zero hash constant
+pub const ZERO_HASH: Hash = [0u8; HASH_SIZE];
+
+/// Simple hash function (for demonstration).
+/// Uses a simplified mixing algorithm - replace with Keccak/SHA256 for production.
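+///
+/// Added note: the mix is deliberately order-sensitive, so `hash_pair(a, b)`
+/// and `hash_pair(b, a)` differ (see `test_hash_pair_order_matters` below).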
+fn hash_pair(left: &Hash, right: &Hash) -> Hash {
+    let mut result = [0u8; HASH_SIZE];
+
+    // Simple mixing: XOR, rotate, and add
+    for i in 0..HASH_SIZE {
+        // Mix left and right with position-dependent rotation
+        let l = left[i];
+        let r = right[(i + 7) % HASH_SIZE];
+        let mixed = l.wrapping_add(r).wrapping_add(i as u8);
+
+        // Additional mixing pass
+        result[i] = mixed.rotate_left(3) ^ left[(i + 13) % HASH_SIZE];
+    }
+
+    // Second pass for better avalanche
+    for i in 0..HASH_SIZE {
+        result[i] = result[i]
+            .wrapping_add(result[(i + 1) % HASH_SIZE])
+            .rotate_right(2);
+    }
+
+    result
+}
+
+/// Hash a single leaf value.
+fn hash_leaf(data: &[u8]) -> Hash {
+    let mut result = [0u8; HASH_SIZE];
+
+    // Domain separation prefix for leaves
+    result[0] = 0x00;
+
+    // Simple hash of input data
+    for (i, &byte) in data.iter().enumerate() {
+        let idx = (i % (HASH_SIZE - 1)) + 1;
+        result[idx] = result[idx].wrapping_add(byte).rotate_left(i as u32 % 8);
+    }
+
+    // Finalization mixing
+    for i in 0..HASH_SIZE {
+        result[i] = result[i] ^ result[(i + 17) % HASH_SIZE];
+    }
+
+    result
+}
+
+/// Merkle tree node containing hash.
+#[derive(Clone, Copy)]
+pub struct Node {
+    pub hash: Hash,
+}
+
+impl Node {
+    pub const fn empty() -> Self {
+        Self { hash: ZERO_HASH }
+    }
+
+    pub fn from_hash(hash: Hash) -> Self {
+        Self { hash }
+    }
+
+    pub fn from_data(data: &[u8]) -> Self {
+        Self {
+            hash: hash_leaf(data),
+        }
+    }
+}
+
+impl Default for Node {
+    fn default() -> Self {
+        Self::empty()
+    }
+}
+
+/// Merkle tree with level-by-level storage.
+/// Supports parallel construction at each level.
+pub struct MerkleTree<const N: usize> {
+    /// Tree levels: level[0] = leaves, level[depth] = root
+    /// Each level i has N / 2^i nodes
+    levels: [[Node; N]; MAX_DEPTH + 1],
+    /// Number of leaves (must be power of 2)
+    pub num_leaves: usize,
+    /// Tree depth
+    pub depth: usize,
+}
+
+impl<const N: usize> MerkleTree<N> {
+    pub fn new() -> Self {
+        Self {
+            levels: [[Node::empty(); N]; MAX_DEPTH + 1],
+            num_leaves: 0,
+            depth: 0,
+        }
+    }
+
+    /// Build tree from leaf data.
+    /// Parallel-friendly: within each level, the pair hashes are independent.
+    pub fn build(&mut self, leaves: &[Hash]) {
+        let n = leaves.len();
+        assert!(n.is_power_of_two() && n <= N);
+
+        self.num_leaves = n;
+        // Use integer log2 via trailing_zeros (n is power of 2)
+        self.depth = n.trailing_zeros() as usize;
+
+        // Level 0: copy leaves
+        for (i, hash) in leaves.iter().enumerate() {
+            self.levels[0][i] = Node::from_hash(*hash);
+        }
+
+        // Build each level from the previous
+        // TODO: Each level's hash computations are independent
+        let mut level_size = n;
+        for level in 1..=self.depth {
+            level_size /= 2;
+
+            // Each pair computation is independent (parallel-friendly)
+            for i in 0..level_size {
+                let left = &self.levels[level - 1][i * 2].hash;
+                let right = &self.levels[level - 1][i * 2 + 1].hash;
+                self.levels[level][i] = Node::from_hash(hash_pair(left, right));
+            }
+        }
+    }
+
+    /// Get the root hash.
+    pub fn root(&self) -> Hash {
+        self.levels[self.depth][0].hash
+    }
+
+    /// Generate Merkle proof for leaf at index.
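+    ///
+    /// Illustrative round trip (an added sketch):
+    /// ```
+    /// use merkle_tree::MerkleTree;
+    /// let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
+    /// let mut tree = MerkleTree::<4>::new();
+    /// tree.build(&leaves);
+    /// let proof = tree.proof(2);
+    /// assert!(proof.verify(&leaves[2], &tree.root()));
+    /// ```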
+    pub fn proof(&self, leaf_index: usize) -> MerkleProof {
+        assert!(leaf_index < self.num_leaves);
+
+        let mut proof = MerkleProof::new();
+        proof.leaf_index = leaf_index;
+        proof.depth = self.depth;
+
+        let mut idx = leaf_index;
+
+        for level in 0..self.depth {
+            // Sibling is the other child of our parent
+            let sibling_idx = idx ^ 1; // XOR with 1 flips last bit
+            proof.siblings[level] = self.levels[level][sibling_idx].hash;
+            idx /= 2;
+        }
+
+        proof
+    }
+
+    /// Get node at specific position.
+    pub fn get_node(&self, level: usize, index: usize) -> &Node {
+        &self.levels[level][index]
+    }
+}
+
+impl<const N: usize> Default for MerkleTree<N> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Merkle proof containing sibling hashes.
+#[derive(Clone)]
+pub struct MerkleProof {
+    /// Sibling hashes from leaf to root (excluding root)
+    pub siblings: [Hash; MAX_DEPTH],
+    /// Index of the leaf being proved
+    pub leaf_index: usize,
+    /// Depth of the tree
+    pub depth: usize,
+}
+
+impl MerkleProof {
+    pub fn new() -> Self {
+        Self {
+            siblings: [ZERO_HASH; MAX_DEPTH],
+            leaf_index: 0,
+            depth: 0,
+        }
+    }
+
+    /// Verify proof against expected root.
+    pub fn verify(&self, leaf: &Hash, expected_root: &Hash) -> bool {
+        let mut current = *leaf;
+        let mut idx = self.leaf_index;
+
+        for level in 0..self.depth {
+            let sibling = &self.siblings[level];
+
+            // Order depends on whether we're left or right child
+            current = if idx % 2 == 0 {
+                hash_pair(&current, sibling)
+            } else {
+                hash_pair(sibling, &current)
+            };
+
+            idx /= 2;
+        }
+
+        current == *expected_root
+    }
+}
+
+impl Default for MerkleProof {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Build multiple Merkle trees in batch.
+/// Each tree is completely independent (embarrassingly parallel).
+pub fn batch_build<const N: usize>(
+    leaf_sets: &[[Hash; N]],
+    trees: &mut [MerkleTree<N>],
+    leaf_count: usize,
+) {
+    assert_eq!(leaf_sets.len(), trees.len());
+
+    // Each tree can be built by a different thread
+    for (leaves, tree) in leaf_sets.iter().zip(trees.iter_mut()) {
+        tree.build(&leaves[..leaf_count]);
+    }
+}
+
+/// Verify multiple proofs in batch.
+/// Each verification is independent (embarrassingly parallel).
+pub fn batch_verify(
+    proofs: &[MerkleProof],
+    leaves: &[Hash],
+    roots: &[Hash],
+    results: &mut [bool],
+) {
+    assert_eq!(proofs.len(), leaves.len());
+    assert_eq!(proofs.len(), roots.len());
+    assert_eq!(proofs.len(), results.len());
+
+    // Each verification can be done by a different thread
+    for (i, proof) in proofs.iter().enumerate() {
+        results[i] = proof.verify(&leaves[i], &roots[i]);
+    }
+}
+
+/// Compute hashes for a level in parallel preparation.
+/// Returns array of hash pairs that can be computed independently.
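+///
+/// Added sketch of the intended split: each returned pair could be hashed by a
+/// different worker before the results are written into the next level.
+/// ```
+/// use merkle_tree::{prepare_level_hashes, Node};
+/// let level = [Node::empty(); 4];
+/// let pairs = prepare_level_hashes(&level, 2); // pairs[0] and pairs[1] are independent
+/// let _ = pairs;
+/// ```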
+pub fn prepare_level_hashes<const N: usize>(
+    prev_level: &[Node; N],
+    level_size: usize,
+) -> [(Hash, Hash); N] {
+    let mut pairs = [(ZERO_HASH, ZERO_HASH); N];
+
+    for i in 0..level_size {
+        pairs[i] = (
+            prev_level[i * 2].hash,
+            prev_level[i * 2 + 1].hash,
+        );
+    }
+
+    pairs
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn make_leaf(value: u8) -> Hash {
+        let mut h = ZERO_HASH;
+        h[0] = value;
+        hash_leaf(&h)
+    }
+
+    #[test]
+    fn test_hash_pair_deterministic() {
+        let a = [1u8; HASH_SIZE];
+        let b = [2u8; HASH_SIZE];
+
+        let h1 = hash_pair(&a, &b);
+        let h2 = hash_pair(&a, &b);
+
+        assert_eq!(h1, h2);
+    }
+
+    #[test]
+    fn test_hash_pair_order_matters() {
+        let a = [1u8; HASH_SIZE];
+        let b = [2u8; HASH_SIZE];
+
+        let h1 = hash_pair(&a, &b);
+        let h2 = hash_pair(&b, &a);
+
+        assert_ne!(h1, h2);
+    }
+
+    #[test]
+    fn test_simple_tree() {
+        let leaves = [make_leaf(1), make_leaf(2), make_leaf(3), make_leaf(4)];
+
+        let mut tree = MerkleTree::<4>::new();
+        tree.build(&leaves);
+
+        assert_eq!(tree.num_leaves, 4);
+        assert_eq!(tree.depth, 2);
+
+        // Root should be non-zero
+        assert_ne!(tree.root(), ZERO_HASH);
+    }
+
+    #[test]
+    fn test_proof_verification() {
+        let leaves = [make_leaf(1), make_leaf(2), make_leaf(3), make_leaf(4)];
+
+        let mut tree = MerkleTree::<4>::new();
+        tree.build(&leaves);
+
+        let root = tree.root();
+
+        // Verify proof for each leaf
+        for i in 0..4 {
+            let proof = tree.proof(i);
+            assert!(proof.verify(&leaves[i], &root));
+        }
+    }
+
+    #[test]
+    fn test_invalid_proof() {
+        let leaves = [make_leaf(1), make_leaf(2), make_leaf(3), make_leaf(4)];
+
+        let mut tree = MerkleTree::<4>::new();
+        tree.build(&leaves);
+
+        let root = tree.root();
+        let proof = tree.proof(0);
+
+        // Wrong leaf should fail
+        let wrong_leaf = make_leaf(99);
+        assert!(!proof.verify(&wrong_leaf, &root));
+    }
+
+    #[test]
+    fn test_tree_determinism() {
+        let leaves = [make_leaf(1), make_leaf(2), make_leaf(3), make_leaf(4)];
+
+        let mut tree1 = MerkleTree::<4>::new();
+        let mut tree2 = MerkleTree::<4>::new();
+
+        tree1.build(&leaves);
+        tree2.build(&leaves);
+
+        assert_eq!(tree1.root(), tree2.root());
+    }
+}
diff --git a/examples/merkle-tree/src/main.rs b/examples/merkle-tree/src/main.rs
new file mode 100644
index 0000000..7861a11
--- /dev/null
+++ b/examples/merkle-tree/src/main.rs
@@ -0,0 +1,178 @@
+//! Merkle Tree Example
+//!
+//! Demonstrates parallel Merkle tree construction and proof verification.
+
+#![cfg_attr(target_os = "none", no_std)]
+#![no_main]
+
+use merkle_tree::{batch_verify, Hash, MerkleProof, MerkleTree, ZERO_HASH};
+
+cfg_if::cfg_if! {
+    if #[cfg(target_os = "none")] {
+        use platform::println;
+    } else {
+        use std::println;
+    }
+}
+
+/// Create a test leaf hash from a simple value.
+fn make_leaf(value: u8) -> Hash {
+    let mut h = ZERO_HASH;
+    h[0] = value;
+    // Simple additional mixing
+    for i in 1..32 {
+        h[i] = h[i - 1].wrapping_add(value).wrapping_mul(17);
+    }
+    h
+}
+
+/// Print first few bytes of a hash.
+fn print_hash_prefix(h: &Hash) {
+    println!(
+        " {:02x}{:02x}{:02x}{:02x}...",
+        h[0], h[1], h[2], h[3]
+    );
+}
+
+#[unsafe(no_mangle)]
+fn main() -> !
{ + println!("=== Merkle Tree Example ==="); + + // Test 1: Simple 4-leaf tree + println!("\nTest 1: 4-Leaf Tree"); + let leaves_4 = [make_leaf(1), make_leaf(2), make_leaf(3), make_leaf(4)]; + + let mut tree_4 = MerkleTree::<4>::new(); + tree_4.build(&leaves_4); + + println!(" Leaves: 4"); + println!(" Depth: {}", tree_4.depth); + println!(" Root hash:"); + print_hash_prefix(&tree_4.root()); + + // Test 2: Proof generation and verification + println!("\nTest 2: Proof Verification"); + let root = tree_4.root(); + + let mut all_valid = true; + for i in 0..4 { + let proof = tree_4.proof(i); + let valid = proof.verify(&leaves_4[i], &root); + if !valid { + all_valid = false; + } + println!(" Leaf {} proof valid: {}", i, valid); + } + println!(" All proofs valid: {}", if all_valid { "PASS" } else { "FAIL" }); + + // Test 3: Invalid proof detection + println!("\nTest 3: Invalid Proof Detection"); + let proof_0 = tree_4.proof(0); + let wrong_leaf = make_leaf(99); + let invalid = !proof_0.verify(&wrong_leaf, &root); + println!(" Wrong leaf rejected: {}", if invalid { "PASS" } else { "FAIL" }); + + // Wrong root test + let wrong_root = make_leaf(88); + let invalid_root = !proof_0.verify(&leaves_4[0], &wrong_root); + println!(" Wrong root rejected: {}", if invalid_root { "PASS" } else { "FAIL" }); + + // Test 4: Larger tree (16 leaves) + println!("\nTest 4: 16-Leaf Tree"); + let mut leaves_16 = [ZERO_HASH; 16]; + for i in 0..16 { + leaves_16[i] = make_leaf(i as u8); + } + + let mut tree_16 = MerkleTree::<16>::new(); + tree_16.build(&leaves_16); + + println!(" Leaves: 16"); + println!(" Depth: {}", tree_16.depth); + println!(" Root hash:"); + print_hash_prefix(&tree_16.root()); + + // Verify random leaf + let proof_10 = tree_16.proof(10); + let valid_10 = proof_10.verify(&leaves_16[10], &tree_16.root()); + println!(" Proof for leaf 10: {}", if valid_10 { "PASS" } else { "FAIL" }); + + // Test 5: Tree determinism + println!("\nTest 5: Determinism"); + let mut tree_16_copy = MerkleTree::<16>::new(); + tree_16_copy.build(&leaves_16); + + let deterministic = tree_16.root() == tree_16_copy.root(); + println!(" Same input -> same root: {}", if deterministic { "PASS" } else { "FAIL" }); + + // Test 6: Different leaves produce different roots + println!("\nTest 6: Collision Resistance"); + let mut leaves_different = leaves_16; + leaves_different[0] = make_leaf(255); // Change one leaf + + let mut tree_different = MerkleTree::<16>::new(); + tree_different.build(&leaves_different); + + let no_collision = tree_16.root() != tree_different.root(); + println!(" One changed leaf changes root: {}", if no_collision { "PASS" } else { "FAIL" }); + + // Test 7: Batch proof verification (parallel-friendly) + println!("\nTest 7: Batch Verification (parallel-friendly)"); + let proofs: [MerkleProof; 4] = [ + tree_16.proof(0), + tree_16.proof(5), + tree_16.proof(10), + tree_16.proof(15), + ]; + let batch_leaves = [leaves_16[0], leaves_16[5], leaves_16[10], leaves_16[15]]; + let batch_roots = [tree_16.root(), tree_16.root(), tree_16.root(), tree_16.root()]; + let mut results = [false; 4]; + + batch_verify(&proofs, &batch_leaves, &batch_roots, &mut results); + + let batch_ok = results.iter().all(|&r| r); + println!(" Verified 4 proofs in batch"); + println!(" All valid: {}", if batch_ok { "PASS" } else { "FAIL" }); + + // Test 8: Proof size analysis + println!("\nTest 8: Proof Size Analysis"); + let proof_size_bits = tree_16.depth * 256; // Each sibling is 256 bits + let proof_size_bytes = tree_16.depth * 32; + println!(" Tree 
depth: {}", tree_16.depth); + println!(" Proof size: {} siblings = {} bits = {} bytes", + tree_16.depth, proof_size_bits, proof_size_bytes); + + // Test 9: Level-by-level inspection + println!("\nTest 9: Tree Structure"); + println!(" Level 0 (leaves): {} nodes", tree_16.num_leaves); + let mut level_size = tree_16.num_leaves; + for level in 1..=tree_16.depth { + level_size /= 2; + println!(" Level {} (internal): {} nodes", level, level_size); + } + + // Show a path through the tree + println!("\n Path for leaf 5:"); + let mut idx = 5usize; + for level in 0..=tree_16.depth { + let node = tree_16.get_node(level, idx); + println!(" Level {}, idx {}: {:02x}{:02x}{:02x}{:02x}...", + level, idx, node.hash[0], node.hash[1], node.hash[2], node.hash[3]); + idx /= 2; + } + + // Test 10: Edge case - 2 leaves (minimum tree) + println!("\nTest 10: Minimum Tree (2 leaves)"); + let leaves_2 = [make_leaf(100), make_leaf(200)]; + let mut tree_2 = MerkleTree::<2>::new(); + tree_2.build(&leaves_2); + + println!(" Depth: {}", tree_2.depth); + let proof_min = tree_2.proof(0); + let valid_min = proof_min.verify(&leaves_2[0], &tree_2.root()); + println!(" Proof works: {}", if valid_min { "PASS" } else { "FAIL" }); + + println!("\n=== Merkle Tree Example Complete ==="); + + platform::exit(0) +} diff --git a/examples/parallel-keccak/Cargo.toml b/examples/parallel-keccak/Cargo.toml new file mode 100644 index 0000000..7f82f8b --- /dev/null +++ b/examples/parallel-keccak/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "parallel-keccak" +publish = false +version.workspace = true +edition.workspace = true +description = "Multi-threaded batch Keccak/SHA3 hasher demo" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/parallel-keccak/src/lib.rs b/examples/parallel-keccak/src/lib.rs new file mode 100644 index 0000000..214541c --- /dev/null +++ b/examples/parallel-keccak/src/lib.rs @@ -0,0 +1,160 @@ +//! Keccak-f[1600] permutation implementation. +//! +//! This is a minimal, no_std-compatible implementation of the Keccak +//! permutation used in SHA3 and other sponge constructions. 
+
+#![no_std]
+
+/// Keccak-f[1600] state: 5x5 array of 64-bit lanes (1600 bits total)
+pub type KeccakState = [[u64; 5]; 5];
+
+/// Round constants for Keccak-f[1600]
+const RC: [u64; 24] = [
+    0x0000000000000001,
+    0x0000000000008082,
+    0x800000000000808a,
+    0x8000000080008000,
+    0x000000000000808b,
+    0x0000000080000001,
+    0x8000000080008081,
+    0x8000000000008009,
+    0x000000000000008a,
+    0x0000000000000088,
+    0x0000000080008009,
+    0x000000008000000a,
+    0x000000008000808b,
+    0x800000000000008b,
+    0x8000000000008089,
+    0x8000000000008003,
+    0x8000000000008002,
+    0x8000000000000080,
+    0x000000000000800a,
+    0x800000008000000a,
+    0x8000000080008081,
+    0x8000000000008080,
+    0x0000000080000001,
+    0x8000000080008008,
+];
+
+/// Rotation offsets for the rho step, indexed RHO[x][y] to match the
+/// state's A[x][y] layout (standard Keccak offsets)
+const RHO: [[u32; 5]; 5] = [
+    [0, 36, 3, 41, 18],
+    [1, 44, 10, 45, 2],
+    [62, 6, 43, 15, 61],
+    [28, 55, 25, 21, 56],
+    [27, 20, 39, 8, 14],
+];
+
+/// Perform the Keccak-f[1600] permutation on the state
+pub fn keccak_f(state: &mut KeccakState) {
+    for round in 0..24 {
+        // θ (theta) step
+        let mut c = [0u64; 5];
+        for x in 0..5 {
+            c[x] = state[x][0] ^ state[x][1] ^ state[x][2] ^ state[x][3] ^ state[x][4];
+        }
+        let mut d = [0u64; 5];
+        for x in 0..5 {
+            d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1);
+        }
+        for x in 0..5 {
+            for y in 0..5 {
+                state[x][y] ^= d[x];
+            }
+        }
+
+        // ρ (rho) and π (pi) steps combined
+        let mut b = [[0u64; 5]; 5];
+        for x in 0..5 {
+            for y in 0..5 {
+                b[y][(2 * x + 3 * y) % 5] = state[x][y].rotate_left(RHO[x][y]);
+            }
+        }
+
+        // χ (chi) step
+        for x in 0..5 {
+            for y in 0..5 {
+                state[x][y] = b[x][y] ^ ((!b[(x + 1) % 5][y]) & b[(x + 2) % 5][y]);
+            }
+        }
+
+        // ι (iota) step
+        state[0][0] ^= RC[round];
+    }
+}
+
+/// Initialize state from a message block (simplified: just XOR first bytes)
+pub fn absorb_block(state: &mut KeccakState, block: &[u8]) {
+    // Rate for SHA3-256 is 1088 bits = 136 bytes = 17 lanes.
+    // Round up so a trailing partial lane is absorbed too.
+    let lanes = core::cmp::min((block.len() + 7) / 8, 17);
+    for i in 0..lanes {
+        let x = i % 5;
+        let y = i / 5;
+        let mut lane_bytes = [0u8; 8];
+        let start = i * 8;
+        let end = core::cmp::min(start + 8, block.len());
+        lane_bytes[..end - start].copy_from_slice(&block[start..end]);
+        state[x][y] ^= u64::from_le_bytes(lane_bytes);
+    }
+}
+
+/// Extract hash output from state (256 bits for SHA3-256)
+pub fn squeeze_256(state: &KeccakState) -> [u8; 32] {
+    let mut output = [0u8; 32];
+    for i in 0..4 {
+        let x = i % 5;
+        let y = i / 5;
+        let lane_bytes = state[x][y].to_le_bytes();
+        output[i * 8..(i + 1) * 8].copy_from_slice(&lane_bytes);
+    }
+    output
+}
+
+/// Simple SHA3-256 hash of a single block (for demo purposes)
+pub fn sha3_256_simple(data: &[u8]) -> [u8; 32] {
+    let mut state: KeccakState = [[0u64; 5]; 5];
+
+    // Absorb (simplified: single block, no proper padding)
+    absorb_block(&mut state, data);
+
+    // Apply permutation
+    keccak_f(&mut state);
+
+    // Squeeze
+    squeeze_256(&state)
+}
+
+/// Batch hash multiple messages (single-threaded baseline)
+pub fn batch_hash(messages: &[&[u8]], outputs: &mut [[u8; 32]]) {
+    for (i, msg) in messages.iter().enumerate() {
+        outputs[i] = sha3_256_simple(msg);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_keccak_f_deterministic() {
+        let mut state1: KeccakState = [[0u64; 5]; 5];
+        let mut state2: KeccakState = [[0u64; 5]; 5];
+        state1[0][0] = 0x123456789ABCDEF0;
+        state2[0][0] = 0x123456789ABCDEF0;
+
+        keccak_f(&mut state1);
+        keccak_f(&mut state2);
+
+        assert_eq!(state1, state2);
+    }
+
+    #[test]
+    fn test_sha3_simple() {
+        let data =
b"hello world"; + let hash1 = sha3_256_simple(data); + let hash2 = sha3_256_simple(data); + assert_eq!(hash1, hash2); + // Hash should be non-zero + assert!(hash1.iter().any(|&b| b != 0)); + } +} diff --git a/examples/parallel-keccak/src/main.rs b/examples/parallel-keccak/src/main.rs new file mode 100644 index 0000000..7651c09 --- /dev/null +++ b/examples/parallel-keccak/src/main.rs @@ -0,0 +1,68 @@ +//! Parallel Keccak/SHA3 batch hasher demo. +//! +//! Demonstrates multi-threaded hashing using ZeroOS cooperative threads. +//! Each worker thread processes a subset of message blocks independently. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use parallel_keccak::{sha3_256_simple, KeccakState, keccak_f}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +/// Number of messages to hash in the batch +const BATCH_SIZE: usize = 8; + +/// Test messages (deterministic for reproducibility) +const TEST_MESSAGES: [&[u8]; BATCH_SIZE] = [ + b"message_0_hello_world", + b"message_1_foo_bar_baz", + b"message_2_test_data_x", + b"message_3_zkvm_rocks!", + b"message_4_zeroos_demo", + b"message_5_jolt_prover", + b"message_6_keccak_hash", + b"message_7_final_block", +]; + +#[no_mangle] +fn main() -> ! { + debug::writeln!("[parallel-keccak] Starting batch hash demo"); + + // Single-threaded batch hash for now + // TODO: Add multi-threaded version using ZeroOS threads + let mut outputs = [[0u8; 32]; BATCH_SIZE]; + + for (i, msg) in TEST_MESSAGES.iter().enumerate() { + debug::writeln!("[parallel-keccak] Hashing message {}", i); + outputs[i] = sha3_256_simple(msg); + } + + // Print results (first 8 bytes of each hash) + for (i, hash) in outputs.iter().enumerate() { + println!( + "hash[{}] = {:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}...", + i, + hash[0], hash[1], hash[2], hash[3], + hash[4], hash[5], hash[6], hash[7] + ); + } + + // Verify determinism: hash the same message twice + let check1 = sha3_256_simple(b"determinism_check"); + let check2 = sha3_256_simple(b"determinism_check"); + if check1 == check2 { + debug::writeln!("[parallel-keccak] Determinism check PASSED"); + } else { + debug::writeln!("[parallel-keccak] Determinism check FAILED"); + } + + debug::writeln!("[parallel-keccak] Demo complete!"); + platform::exit(0) +} diff --git a/examples/parallel-mergesort/Cargo.toml b/examples/parallel-mergesort/Cargo.toml new file mode 100644 index 0000000..e4d35d9 --- /dev/null +++ b/examples/parallel-mergesort/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "parallel-mergesort" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel merge sort with divide-and-conquer threading" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/parallel-mergesort/src/lib.rs b/examples/parallel-mergesort/src/lib.rs new file mode 100644 index 0000000..22b1db0 --- /dev/null +++ b/examples/parallel-mergesort/src/lib.rs @@ -0,0 +1,152 @@ +//! Merge sort implementation with parallel-friendly structure. +//! +//! 
The algorithm naturally decomposes into independent sub-problems, +//! making it ideal for demonstrating thread-based parallelism. + +#![no_std] + +/// Merge two sorted slices into the output buffer +pub fn merge(left: &[u32], right: &[u32], output: &mut [u32]) { + let mut i = 0; + let mut j = 0; + let mut k = 0; + + while i < left.len() && j < right.len() { + if left[i] <= right[j] { + output[k] = left[i]; + i += 1; + } else { + output[k] = right[j]; + j += 1; + } + k += 1; + } + + // Copy remaining elements from left + while i < left.len() { + output[k] = left[i]; + i += 1; + k += 1; + } + + // Copy remaining elements from right + while j < right.len() { + output[k] = right[j]; + j += 1; + k += 1; + } +} + +/// Single-threaded merge sort (in-place using auxiliary buffer) +pub fn merge_sort(arr: &mut [u32], aux: &mut [u32]) { + let n = arr.len(); + if n <= 1 { + return; + } + + let mid = n / 2; + + // Recursively sort halves + merge_sort(&mut arr[..mid], &mut aux[..mid]); + merge_sort(&mut arr[mid..], &mut aux[mid..]); + + // Merge into auxiliary buffer + merge(&arr[..mid], &arr[mid..], aux); + + // Copy back + arr.copy_from_slice(&aux[..n]); +} + +/// Sort independent segments (preparation for parallel merge) +/// Each segment can be sorted by a different thread +pub fn sort_segments(arr: &mut [u32], aux: &mut [u32], num_segments: usize) { + let n = arr.len(); + let segment_size = (n + num_segments - 1) / num_segments; + + for i in 0..num_segments { + let start = i * segment_size; + let end = core::cmp::min(start + segment_size, n); + if start < n { + let seg_len = end - start; + merge_sort(&mut arr[start..end], &mut aux[start..start + seg_len]); + } + } +} + +/// Merge sorted segments pairwise +pub fn merge_segments(arr: &mut [u32], aux: &mut [u32], num_segments: usize) { + let n = arr.len(); + let segment_size = (n + num_segments - 1) / num_segments; + + // Pairwise merge until all segments are combined + let mut current_segments = num_segments; + let mut current_size = segment_size; + + while current_segments > 1 { + let pairs = (current_segments + 1) / 2; + + for p in 0..pairs { + let left_start = p * 2 * current_size; + let left_end = core::cmp::min(left_start + current_size, n); + let right_start = left_end; + let right_end = core::cmp::min(right_start + current_size, n); + + if right_start < n { + // Merge two adjacent segments + merge( + &arr[left_start..left_end], + &arr[right_start..right_end], + &mut aux[left_start..right_end], + ); + arr[left_start..right_end].copy_from_slice(&aux[left_start..right_end]); + } + } + + current_segments = pairs; + current_size *= 2; + } +} + +/// Check if array is sorted +pub fn is_sorted(arr: &[u32]) -> bool { + for i in 1..arr.len() { + if arr[i - 1] > arr[i] { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merge() { + let left = [1, 3, 5, 7]; + let right = [2, 4, 6, 8]; + let mut output = [0u32; 8]; + merge(&left, &right, &mut output); + assert_eq!(output, [1, 2, 3, 4, 5, 6, 7, 8]); + } + + #[test] + fn test_merge_sort() { + let mut arr = [5, 2, 8, 1, 9, 3, 7, 4, 6]; + let mut aux = [0u32; 9]; + merge_sort(&mut arr, &mut aux); + assert!(is_sorted(&arr)); + } + + #[test] + fn test_sort_segments() { + let mut arr = [8, 4, 2, 6, 1, 5, 3, 7]; + let mut aux = [0u32; 8]; + sort_segments(&mut arr, &mut aux, 4); + // Each segment of 2 should be sorted + assert!(arr[0] <= arr[1]); + assert!(arr[2] <= arr[3]); + assert!(arr[4] <= arr[5]); + assert!(arr[6] <= arr[7]); + } +} diff --git 
a/examples/parallel-mergesort/src/main.rs b/examples/parallel-mergesort/src/main.rs new file mode 100644 index 0000000..c8ff784 --- /dev/null +++ b/examples/parallel-mergesort/src/main.rs @@ -0,0 +1,72 @@ +//! Parallel merge sort demo. +//! +//! Demonstrates divide-and-conquer sorting using independent segment sorting, +//! which can be parallelized across ZeroOS cooperative threads. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use parallel_mergesort::{is_sorted, merge_segments, sort_segments}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +/// Array size (keep small for ~1M cycle budget) +const ARRAY_SIZE: usize = 64; + +/// Number of parallel segments +const NUM_SEGMENTS: usize = 4; + +/// Generate deterministic test data +fn generate_test_data(arr: &mut [u32]) { + let mut seed: u32 = 0x12345678; + for i in 0..arr.len() { + // Simple LCG for deterministic pseudo-random values + seed = seed.wrapping_mul(1103515245).wrapping_add(12345); + arr[i] = (seed >> 16) & 0xFFFF; + } +} + +#[no_mangle] +fn main() -> ! { + debug::writeln!("[parallel-mergesort] Starting merge sort demo"); + debug::writeln!("[parallel-mergesort] Array size: {}, Segments: {}", ARRAY_SIZE, NUM_SEGMENTS); + + // Allocate arrays + let mut arr = [0u32; ARRAY_SIZE]; + let mut aux = [0u32; ARRAY_SIZE]; + + // Generate test data + generate_test_data(&mut arr); + + // Print first few elements before sorting + println!("Before: [{}, {}, {}, {}, ...]", arr[0], arr[1], arr[2], arr[3]); + + // Phase 1: Sort independent segments (parallelizable) + debug::writeln!("[parallel-mergesort] Sorting {} segments...", NUM_SEGMENTS); + sort_segments(&mut arr, &mut aux, NUM_SEGMENTS); + + // Phase 2: Merge sorted segments + debug::writeln!("[parallel-mergesort] Merging segments..."); + merge_segments(&mut arr, &mut aux, NUM_SEGMENTS); + + // Print first few elements after sorting + println!("After: [{}, {}, {}, {}, ...]", arr[0], arr[1], arr[2], arr[3]); + + // Verify sorted + if is_sorted(&arr) { + println!("Sort verification: PASSED"); + debug::writeln!("[parallel-mergesort] Sort verification PASSED"); + } else { + println!("Sort verification: FAILED"); + debug::writeln!("[parallel-mergesort] Sort verification FAILED"); + } + + debug::writeln!("[parallel-mergesort] Demo complete!"); + platform::exit(0) +} diff --git a/examples/polynomial-eval/Cargo.toml b/examples/polynomial-eval/Cargo.toml new file mode 100644 index 0000000..3b457f8 --- /dev/null +++ b/examples/polynomial-eval/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "polynomial-eval" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel polynomial batch evaluation" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/polynomial-eval/src/lib.rs b/examples/polynomial-eval/src/lib.rs new file mode 100644 index 0000000..fd5bad0 --- /dev/null +++ b/examples/polynomial-eval/src/lib.rs @@ -0,0 +1,356 @@ +//! Parallel Polynomial Evaluation +//! +//! Batch polynomial evaluation for parallel execution. 
+//! Each polynomial/point evaluation is independent. +//! +//! Useful for: +//! - Polynomial commitment schemes +//! - Reed-Solomon encoding +//! - Lagrange interpolation evaluation +//! - KZG proofs + +#![no_std] + +/// Maximum polynomial degree supported +pub const MAX_DEGREE: usize = 64; +/// Maximum number of evaluation points +pub const MAX_POINTS: usize = 64; +/// Prime field modulus (small prime for demo) +pub const MODULUS: i64 = 0x7FFFFFFF; // 2^31 - 1 (Mersenne prime) + +/// Polynomial represented by coefficients. +/// p(x) = coeffs[0] + coeffs[1]*x + coeffs[2]*x^2 + ... +#[derive(Clone)] +pub struct Polynomial { + pub coeffs: [i64; MAX_DEGREE], + pub degree: usize, +} + +impl Polynomial { + pub const fn new() -> Self { + Self { + coeffs: [0; MAX_DEGREE], + degree: 0, + } + } + + /// Create polynomial from coefficient slice. + pub fn from_coeffs(coeffs: &[i64]) -> Self { + let mut poly = Self::new(); + let len = core::cmp::min(coeffs.len(), MAX_DEGREE); + for (i, &c) in coeffs.iter().take(len).enumerate() { + poly.coeffs[i] = c % MODULUS; + } + poly.degree = len.saturating_sub(1); + poly + } + + /// Evaluate polynomial at point x using Horner's method. + /// O(n) where n is degree. + pub fn eval(&self, x: i64) -> i64 { + let mut result = 0i64; + + // Horner's method: p(x) = c[0] + x*(c[1] + x*(c[2] + ...)) + for i in (0..=self.degree).rev() { + result = (result.wrapping_mul(x) + self.coeffs[i]) % MODULUS; + if result < 0 { + result += MODULUS; + } + } + + result + } + + /// Evaluate at multiple points (parallel-friendly). + /// Each point evaluation is independent. + pub fn eval_many(&self, points: &[i64], results: &mut [i64]) { + assert!(points.len() <= results.len()); + + // Each evaluation can be done by a different thread + for (i, &x) in points.iter().enumerate() { + results[i] = self.eval(x); + } + } + + /// Add two polynomials. + pub fn add(&self, other: &Self) -> Self { + let mut result = Self::new(); + let max_deg = core::cmp::max(self.degree, other.degree); + + for i in 0..=max_deg { + let a = if i <= self.degree { self.coeffs[i] } else { 0 }; + let b = if i <= other.degree { other.coeffs[i] } else { 0 }; + result.coeffs[i] = (a + b) % MODULUS; + } + result.degree = max_deg; + + result + } + + /// Multiply polynomial by scalar. + pub fn scale(&self, scalar: i64) -> Self { + let mut result = Self::new(); + result.degree = self.degree; + + for i in 0..=self.degree { + result.coeffs[i] = (self.coeffs[i].wrapping_mul(scalar)) % MODULUS; + } + + result + } +} + +impl Default for Polynomial { + fn default() -> Self { + Self::new() + } +} + +/// Batch polynomial evaluation. +/// Evaluate multiple polynomials at multiple points. +/// Each (poly, point) pair is independent (embarrassingly parallel). +pub fn batch_eval( + polys: &[Polynomial], + points: &[i64], + results: &mut [[i64; MAX_POINTS]], +) { + assert!(polys.len() <= results.len()); + assert!(points.len() <= MAX_POINTS); + + // Each polynomial can be evaluated by a different thread + for (i, poly) in polys.iter().enumerate() { + // Each point within a polynomial can also be parallelized + for (j, &x) in points.iter().enumerate() { + results[i][j] = poly.eval(x); + } + } +} + +/// Lagrange basis polynomial L_i(x) at evaluation point. 
+/// L_i(x) = ∏_{j≠i} (x - x_j) / (x_i - x_j) +pub fn lagrange_basis(points: &[i64], i: usize, x: i64) -> i64 { + let mut numerator: i64 = 1; + let mut denominator: i64 = 1; + + for (j, &x_j) in points.iter().enumerate() { + if j != i { + numerator = (numerator.wrapping_mul(x - x_j)) % MODULUS; + denominator = (denominator.wrapping_mul(points[i] - x_j)) % MODULUS; + } + } + + // Modular division: numerator * denominator^(-1) + // Using Fermat's little theorem: a^(-1) = a^(p-2) mod p + let inv = mod_pow(denominator, MODULUS - 2); + (numerator.wrapping_mul(inv)) % MODULUS +} + +/// Modular exponentiation using binary method. +pub fn mod_pow(base: i64, mut exp: i64) -> i64 { + let mut result = 1i64; + let mut base = base % MODULUS; + + while exp > 0 { + if exp & 1 == 1 { + result = (result.wrapping_mul(base)) % MODULUS; + } + exp >>= 1; + base = (base.wrapping_mul(base)) % MODULUS; + } + + if result < 0 { + result + MODULUS + } else { + result + } +} + +/// Lagrange interpolation at point x. +/// Given (x_i, y_i) pairs, compute p(x) where p interpolates all points. +pub fn lagrange_interpolate(xs: &[i64], ys: &[i64], x: i64) -> i64 { + assert_eq!(xs.len(), ys.len()); + + let mut result = 0i64; + + // Each term can be computed independently (parallel-friendly) + for (i, &y_i) in ys.iter().enumerate() { + let basis = lagrange_basis(xs, i, x); + result = (result + y_i.wrapping_mul(basis)) % MODULUS; + } + + if result < 0 { + result + MODULUS + } else { + result + } +} + +/// Multi-point Lagrange interpolation. +/// Interpolate at multiple evaluation points (parallel-friendly). +pub fn lagrange_interpolate_many( + xs: &[i64], + ys: &[i64], + eval_points: &[i64], + results: &mut [i64], +) { + assert_eq!(xs.len(), ys.len()); + assert!(eval_points.len() <= results.len()); + + // Each evaluation point is independent + for (i, &x) in eval_points.iter().enumerate() { + results[i] = lagrange_interpolate(xs, ys, x); + } +} + +/// Polynomial multiplication (convolution). +/// Useful for combining polynomials. +pub fn poly_mul(a: &Polynomial, b: &Polynomial) -> Polynomial { + let mut result = Polynomial::new(); + let result_degree = a.degree + b.degree; + + if result_degree >= MAX_DEGREE { + // Overflow protection + result.degree = MAX_DEGREE - 1; + } else { + result.degree = result_degree; + } + + // Standard convolution - can be parallelized + for i in 0..=a.degree { + for j in 0..=b.degree { + if i + j < MAX_DEGREE { + let product = (a.coeffs[i].wrapping_mul(b.coeffs[j])) % MODULUS; + result.coeffs[i + j] = (result.coeffs[i + j] + product) % MODULUS; + } + } + } + + result +} + +/// Reed-Solomon encoding: evaluate polynomial at consecutive powers of generator. +/// p(g^0), p(g^1), p(g^2), ..., p(g^(n-1)) +pub fn rs_encode(poly: &Polynomial, generator: i64, num_points: usize, output: &mut [i64]) { + assert!(num_points <= output.len()); + + let mut g_power = 1i64; + + // Each evaluation is independent (parallel-friendly) + for result in output.iter_mut().take(num_points) { + *result = poly.eval(g_power); + g_power = (g_power.wrapping_mul(generator)) % MODULUS; + } +} + +/// Compute polynomial derivative. 
+/// If p(x) = sum(c_i * x^i), then p'(x) = sum(i * c_i * x^(i-1)) +pub fn derivative(poly: &Polynomial) -> Polynomial { + let mut result = Polynomial::new(); + + if poly.degree == 0 { + return result; + } + + result.degree = poly.degree - 1; + for i in 1..=poly.degree { + result.coeffs[i - 1] = (poly.coeffs[i].wrapping_mul(i as i64)) % MODULUS; + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_constant_poly() { + let poly = Polynomial::from_coeffs(&[5]); + assert_eq!(poly.eval(0), 5); + assert_eq!(poly.eval(100), 5); + } + + #[test] + fn test_linear_poly() { + // p(x) = 2 + 3x + let poly = Polynomial::from_coeffs(&[2, 3]); + assert_eq!(poly.eval(0), 2); + assert_eq!(poly.eval(1), 5); + assert_eq!(poly.eval(10), 32); + } + + #[test] + fn test_quadratic_poly() { + // p(x) = 1 + 2x + x^2 = (1+x)^2 + let poly = Polynomial::from_coeffs(&[1, 2, 1]); + assert_eq!(poly.eval(0), 1); + assert_eq!(poly.eval(1), 4); + assert_eq!(poly.eval(2), 9); + assert_eq!(poly.eval(3), 16); + } + + #[test] + fn test_poly_add() { + let a = Polynomial::from_coeffs(&[1, 2]); // 1 + 2x + let b = Polynomial::from_coeffs(&[3, 4]); // 3 + 4x + let c = a.add(&b); // 4 + 6x + + assert_eq!(c.eval(0), 4); + assert_eq!(c.eval(1), 10); + } + + #[test] + fn test_mod_pow() { + assert_eq!(mod_pow(2, 10), 1024); + assert_eq!(mod_pow(3, 0), 1); + assert_eq!(mod_pow(5, 1), 5); + } + + #[test] + fn test_lagrange_simple() { + // Interpolate through (0,1), (1,2), (2,5) + // This is p(x) = 1 + 0.5x + 0.5x^2, but we use integer math + let xs = [0, 1, 2]; + let ys = [1, 2, 5]; + + // Verify interpolation at known points + assert_eq!(lagrange_interpolate(&xs, &ys, 0), 1); + assert_eq!(lagrange_interpolate(&xs, &ys, 1), 2); + assert_eq!(lagrange_interpolate(&xs, &ys, 2), 5); + } + + #[test] + fn test_derivative() { + // p(x) = 1 + 2x + 3x^2, p'(x) = 2 + 6x + let poly = Polynomial::from_coeffs(&[1, 2, 3]); + let deriv = derivative(&poly); + + assert_eq!(deriv.coeffs[0], 2); + assert_eq!(deriv.coeffs[1], 6); + assert_eq!(deriv.degree, 1); + } + + #[test] + fn test_batch_eval() { + let polys = [ + Polynomial::from_coeffs(&[1, 1]), // 1 + x + Polynomial::from_coeffs(&[0, 0, 1]), // x^2 + ]; + let points = [0, 1, 2, 3]; + let mut results = [[0i64; MAX_POINTS]; 2]; + + batch_eval(&polys, &points, &mut results); + + // p1: 1+x at 0,1,2,3 = 1,2,3,4 + assert_eq!(results[0][0], 1); + assert_eq!(results[0][1], 2); + assert_eq!(results[0][2], 3); + assert_eq!(results[0][3], 4); + + // p2: x^2 at 0,1,2,3 = 0,1,4,9 + assert_eq!(results[1][0], 0); + assert_eq!(results[1][1], 1); + assert_eq!(results[1][2], 4); + assert_eq!(results[1][3], 9); + } +} diff --git a/examples/polynomial-eval/src/main.rs b/examples/polynomial-eval/src/main.rs new file mode 100644 index 0000000..10ed504 --- /dev/null +++ b/examples/polynomial-eval/src/main.rs @@ -0,0 +1,172 @@ +//! Polynomial Evaluation Example +//! +//! Demonstrates parallel polynomial evaluation for zkVM applications. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use polynomial_eval::{ + batch_eval, derivative, lagrange_interpolate, lagrange_interpolate_many, + mod_pow, poly_mul, rs_encode, Polynomial, MAX_POINTS, MODULUS, +}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +#[unsafe(no_mangle)] +fn main() -> ! 
{ + println!("=== Polynomial Evaluation Example ==="); + println!(" Using modulus: {}", MODULUS); + + // Test 1: Basic polynomial evaluation + println!("\nTest 1: Basic Evaluation"); + // p(x) = 5 + 3x + 2x^2 + let poly = Polynomial::from_coeffs(&[5, 3, 2]); + println!(" Polynomial: 5 + 3x + 2x^2"); + println!(" p(0) = {}", poly.eval(0)); + println!(" p(1) = {}", poly.eval(1)); + println!(" p(2) = {}", poly.eval(2)); + println!(" p(10) = {}", poly.eval(10)); + + // Verify: p(2) = 5 + 6 + 8 = 19 + let expected = 5 + 3 * 2 + 2 * 4; + println!(" Expected p(2) = {}: {}", expected, if poly.eval(2) == expected { "PASS" } else { "FAIL" }); + + // Test 2: Multi-point evaluation + println!("\nTest 2: Multi-Point Evaluation"); + let points = [0, 1, 2, 3, 4, 5, 6, 7]; + let mut results = [0i64; 8]; + poly.eval_many(&points, &mut results); + + println!(" Points: 0..7"); + println!(" Results: {} {} {} {} {} {} {} {}", + results[0], results[1], results[2], results[3], + results[4], results[5], results[6], results[7]); + + // Test 3: Polynomial arithmetic + println!("\nTest 3: Polynomial Arithmetic"); + let p1 = Polynomial::from_coeffs(&[1, 2]); // 1 + 2x + let p2 = Polynomial::from_coeffs(&[3, 4]); // 3 + 4x + + let sum = p1.add(&p2); // 4 + 6x + let product = poly_mul(&p1, &p2); // 3 + 10x + 8x^2 + + println!(" p1 = 1 + 2x, p2 = 3 + 4x"); + println!(" p1 + p2 at x=1: {} (expected 10)", sum.eval(1)); + println!(" p1 * p2 at x=1: {} (expected 21)", product.eval(1)); + + // Test 4: Lagrange interpolation + println!("\nTest 4: Lagrange Interpolation"); + // Interpolate through (0,1), (1,4), (2,9), (3,16) = x^2 + 2x + 1 = (x+1)^2 + let xs = [0, 1, 2, 3]; + let ys = [1, 4, 9, 16]; + + println!(" Points: (0,1), (1,4), (2,9), (3,16)"); + + // Verify at known points + let mut all_match = true; + for i in 0..4 { + let interp = lagrange_interpolate(&xs, &ys, xs[i]); + if interp != ys[i] { + all_match = false; + } + } + println!(" Interpolation at known points: {}", if all_match { "PASS" } else { "FAIL" }); + + // Evaluate at new points + let p5 = lagrange_interpolate(&xs, &ys, 5); + let expected_p5 = 36; // (5+1)^2 + println!(" p(5) = {} (expected {}): {}", + p5, expected_p5, if p5 == expected_p5 { "PASS" } else { "FAIL" }); + + // Test 5: Multi-point interpolation + println!("\nTest 5: Multi-Point Interpolation"); + let eval_points = [4, 5, 6, 7]; + let mut interp_results = [0i64; 4]; + lagrange_interpolate_many(&xs, &ys, &eval_points, &mut interp_results); + + println!(" Interpolated values:"); + for (i, &x) in eval_points.iter().enumerate() { + let expected = (x + 1) * (x + 1); + println!(" p({}) = {} (expected {})", x, interp_results[i], expected); + } + + // Test 6: Batch polynomial evaluation + println!("\nTest 6: Batch Evaluation (parallel-friendly)"); + let polys = [ + Polynomial::from_coeffs(&[1, 1]), // 1 + x + Polynomial::from_coeffs(&[0, 0, 1]), // x^2 + Polynomial::from_coeffs(&[1, 1, 1]), // 1 + x + x^2 + ]; + let batch_points = [0, 1, 2, 3]; + let mut batch_results = [[0i64; MAX_POINTS]; 4]; + + batch_eval(&polys, &batch_points, &mut batch_results); + + println!(" 3 polynomials evaluated at 4 points each:"); + for (i, poly_name) in ["1+x", "x^2", "1+x+x^2"].iter().enumerate() { + println!(" {}: {} {} {} {}", + poly_name, + batch_results[i][0], batch_results[i][1], + batch_results[i][2], batch_results[i][3]); + } + + // Test 7: Reed-Solomon encoding + println!("\nTest 7: Reed-Solomon Encoding"); + let message_poly = Polynomial::from_coeffs(&[1, 2, 3, 4]); // Message as polynomial + let generator = 
3; // Primitive root
+    let mut codeword = [0i64; 8];
+
+    rs_encode(&message_poly, generator, 8, &mut codeword);
+
+    println!(" Message polynomial: 1 + 2x + 3x^2 + 4x^3");
+    println!(" Generator: {}", generator);
+    println!(" Codeword (8 points): {} {} {} {} {} {} {} {}",
+        codeword[0], codeword[1], codeword[2], codeword[3],
+        codeword[4], codeword[5], codeword[6], codeword[7]);
+
+    // Test 8: Polynomial derivative
+    println!("\nTest 8: Polynomial Derivative");
+    let p = Polynomial::from_coeffs(&[1, 2, 3, 4]); // 1 + 2x + 3x^2 + 4x^3
+    let dp = derivative(&p); // 2 + 6x + 12x^2
+
+    println!(" p(x) = 1 + 2x + 3x^2 + 4x^3");
+    println!(" p'(x) = 2 + 6x + 12x^2");
+    println!(" p'(0) = {} (expected 2)", dp.eval(0));
+    println!(" p'(1) = {} (expected 20)", dp.eval(1));
+    println!(" p'(2) = {} (expected 62)", dp.eval(2));
+
+    // Test 9: Modular arithmetic
+    println!("\nTest 9: Modular Arithmetic");
+    println!(" 2^10 mod {} = {}", MODULUS, mod_pow(2, 10));
+    println!(" 3^100 mod {} = {}", MODULUS, mod_pow(3, 100));
+
+    // Test Fermat's little theorem: a^(p-1) = 1 mod p
+    let fermat = mod_pow(7, MODULUS - 1);
+    println!(" 7^(p-1) mod p = {} (Fermat: should be 1)", fermat);
+
+    // Test 10: Edge cases
+    println!("\nTest 10: Edge Cases");
+    let zero_poly = Polynomial::from_coeffs(&[0]);
+    let const_poly = Polynomial::from_coeffs(&[42]);
+
+    println!(" Zero polynomial at x=100: {}", zero_poly.eval(100));
+    println!(" Constant 42 at x=100: {}", const_poly.eval(100));
+
+    // High degree polynomial
+    let mut high_coeffs = [0i64; 32];
+    for i in 0..32 {
+        high_coeffs[i] = (i + 1) as i64;
+    }
+    let high_poly = Polynomial::from_coeffs(&high_coeffs);
+    println!(" Degree-31 polynomial at x=2: {}", high_poly.eval(2));
+
+    println!("\n=== Polynomial Evaluation Example Complete ===");
+
+    platform::exit(0)
+}
diff --git a/examples/prefix-sum/Cargo.toml b/examples/prefix-sum/Cargo.toml
new file mode 100644
index 0000000..df39093
--- /dev/null
+++ b/examples/prefix-sum/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "prefix-sum"
+publish = false
+version.workspace = true
+edition.workspace = true
+description = "Parallel prefix sum (scan) with stage-based synchronization"
+
+[dependencies]
+platform.workspace = true
+debug.workspace = true
+cfg-if.workspace = true
+
+[features]
+default = ["with-spike"]
+
+debug = ["platform/debug"]
+
+std = [
+    "platform/std",
+    "platform/vfs-device-console",
+    "platform/memory",
+    "platform/bounds-checks",
+]
+
+with-spike = ["platform/with-spike"]
+# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists
+
+[target.'cfg(target_os = "none")'.dependencies]
+platform = { workspace = true, features = ["memory"] }
diff --git a/examples/prefix-sum/src/lib.rs b/examples/prefix-sum/src/lib.rs
new file mode 100644
index 0000000..95e0765
--- /dev/null
+++ b/examples/prefix-sum/src/lib.rs
@@ -0,0 +1,138 @@
+//! Parallel prefix sum (scan) implementation.
+//!
+//! Uses a block-decomposition scan: local scans within each block, a
+//! sequential scan of the block totals, then per-block offset propagation.
+//! The per-block phases can be parallelized.
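+//!
+//! Worked sketch of the block scan on [1, 2, 3, 4, 5, 6] with 2 blocks
+//! (illustrative numbers, added for clarity):
+//! ```
+//! use prefix_sum::prefix_sum_blocked;
+//! let arr = [1u64, 2, 3, 4, 5, 6];
+//! let mut out = [0u64; 6];
+//! // Phase 1 gives local scans [1,3,6 | 4,9,15]; phase 2 finds the
+//! // block-0 total 6; phase 3 adds it to block 1.
+//! prefix_sum_blocked(&arr, &mut out, 2);
+//! assert_eq!(out, [1, 3, 6, 10, 15, 21]);
+//! ```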
+ +#![no_std] + +/// Compute sequential inclusive prefix sum: out[i] = sum(arr[0..=i]) +pub fn prefix_sum_sequential(arr: &[u64], out: &mut [u64]) { + if arr.is_empty() { + return; + } + + out[0] = arr[0]; + for i in 1..arr.len() { + out[i] = out[i - 1].wrapping_add(arr[i]); + } +} + +/// Compute exclusive prefix sum: out[i] = sum(arr[0..i]) +pub fn prefix_sum_exclusive(arr: &[u64], out: &mut [u64]) { + if arr.is_empty() { + return; + } + + out[0] = 0; + for i in 1..arr.len() { + out[i] = out[i - 1].wrapping_add(arr[i - 1]); + } +} + +/// Parallel-friendly prefix sum using block decomposition. +/// +/// Phase 1: Compute local prefix sums within each block (parallelizable) +/// Phase 2: Compute block totals prefix sum (sequential) +/// Phase 3: Add block offsets to each element (parallelizable) +pub fn prefix_sum_blocked(arr: &[u64], out: &mut [u64], num_blocks: usize) { + let n = arr.len(); + if n == 0 { + return; + } + + let block_size = (n + num_blocks - 1) / num_blocks; + + // Phase 1: Compute local prefix sums within each block + // This phase is embarrassingly parallel + for block in 0..num_blocks { + let start = block * block_size; + let end = core::cmp::min(start + block_size, n); + + if start < n { + // Local prefix sum for this block + out[start] = arr[start]; + for i in (start + 1)..end { + out[i] = out[i - 1].wrapping_add(arr[i]); + } + } + } + + // Phase 2: Compute prefix sum of block totals (sequential) + // block_totals[i] = sum of all elements in blocks 0..i + let mut block_offsets = [0u64; 32]; // Support up to 32 blocks + let mut running_total = 0u64; + for block in 0..num_blocks { + block_offsets[block] = running_total; + + let start = block * block_size; + let end = core::cmp::min(start + block_size, n); + if end > start { + running_total = running_total.wrapping_add(out[end - 1]); + } + } + + // Phase 3: Add block offsets to each element (parallelizable) + for block in 1..num_blocks { + let start = block * block_size; + let end = core::cmp::min(start + block_size, n); + let offset = block_offsets[block]; + + for i in start..end { + out[i] = out[i].wrapping_add(offset); + } + } +} + +/// Verify prefix sum correctness +pub fn verify_prefix_sum(arr: &[u64], prefix: &[u64]) -> bool { + if arr.is_empty() { + return prefix.is_empty(); + } + + let mut expected = arr[0]; + if prefix[0] != expected { + return false; + } + + for i in 1..arr.len() { + expected = expected.wrapping_add(arr[i]); + if prefix[i] != expected { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sequential_prefix_sum() { + let arr = [1, 2, 3, 4, 5]; + let mut out = [0u64; 5]; + prefix_sum_sequential(&arr, &mut out); + assert_eq!(out, [1, 3, 6, 10, 15]); + } + + #[test] + fn test_exclusive_prefix_sum() { + let arr = [1, 2, 3, 4, 5]; + let mut out = [0u64; 5]; + prefix_sum_exclusive(&arr, &mut out); + assert_eq!(out, [0, 1, 3, 6, 10]); + } + + #[test] + fn test_blocked_equals_sequential() { + let arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let mut out_seq = [0u64; 12]; + let mut out_blk = [0u64; 12]; + + prefix_sum_sequential(&arr, &mut out_seq); + prefix_sum_blocked(&arr, &mut out_blk, 4); + + assert_eq!(out_seq, out_blk); + } +} diff --git a/examples/prefix-sum/src/main.rs b/examples/prefix-sum/src/main.rs new file mode 100644 index 0000000..bf38650 --- /dev/null +++ b/examples/prefix-sum/src/main.rs @@ -0,0 +1,89 @@ +//! Parallel prefix sum demo. +//! +//! Demonstrates stage-based parallel prefix sum computation, +//! 
which can be parallelized across ZeroOS cooperative threads. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use prefix_sum::{prefix_sum_blocked, prefix_sum_sequential, verify_prefix_sum}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +/// Array size +const ARRAY_SIZE: usize = 64; + +/// Number of parallel blocks +const NUM_BLOCKS: usize = 4; + +/// Generate deterministic test data +fn generate_test_data(arr: &mut [u64]) { + for i in 0..arr.len() { + // Simple pattern: 1, 2, 3, ... + arr[i] = (i + 1) as u64; + } +} + +#[no_mangle] +fn main() -> ! { + debug::writeln!("[prefix-sum] Starting prefix sum demo"); + debug::writeln!("[prefix-sum] Array size: {}, Blocks: {}", ARRAY_SIZE, NUM_BLOCKS); + + // Allocate arrays + let mut arr = [0u64; ARRAY_SIZE]; + let mut out_seq = [0u64; ARRAY_SIZE]; + let mut out_blk = [0u64; ARRAY_SIZE]; + + // Generate test data + generate_test_data(&mut arr); + + println!("Input: [{}, {}, {}, {}, ...]", arr[0], arr[1], arr[2], arr[3]); + + // Compute sequential prefix sum + debug::writeln!("[prefix-sum] Computing sequential prefix sum..."); + prefix_sum_sequential(&arr, &mut out_seq); + + // Compute blocked prefix sum + debug::writeln!("[prefix-sum] Computing blocked prefix sum ({} blocks)...", NUM_BLOCKS); + prefix_sum_blocked(&arr, &mut out_blk, NUM_BLOCKS); + + println!("Sequential: [{}, {}, {}, {}, ...]", + out_seq[0], out_seq[1], out_seq[2], out_seq[3]); + println!("Blocked: [{}, {}, {}, {}, ...]", + out_blk[0], out_blk[1], out_blk[2], out_blk[3]); + + // Verify sequential result + let seq_valid = verify_prefix_sum(&arr, &out_seq); + + // Verify blocked matches sequential + let mut blk_matches = true; + for i in 0..ARRAY_SIZE { + if out_seq[i] != out_blk[i] { + blk_matches = false; + break; + } + } + + if seq_valid && blk_matches { + println!("Verification: PASSED"); + debug::writeln!("[prefix-sum] Verification PASSED"); + } else { + println!("Verification: FAILED (seq_valid={}, blk_matches={})", + seq_valid, blk_matches); + debug::writeln!("[prefix-sum] Verification FAILED"); + } + + // Print final sum (should be n*(n+1)/2 for input 1,2,3,...,n) + let expected_sum = (ARRAY_SIZE * (ARRAY_SIZE + 1) / 2) as u64; + let actual_sum = out_seq[ARRAY_SIZE - 1]; + println!("Final sum: {} (expected: {})", actual_sum, expected_sum); + + debug::writeln!("[prefix-sum] Demo complete!"); + platform::exit(0) +} diff --git a/examples/wavelet-transform/Cargo.toml b/examples/wavelet-transform/Cargo.toml new file mode 100644 index 0000000..f91d0b4 --- /dev/null +++ b/examples/wavelet-transform/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "wavelet-transform" +publish = false +version.workspace = true +edition.workspace = true +description = "Parallel wavelet/Haar transform with level partitioning" + +[dependencies] +platform.workspace = true +debug.workspace = true +cfg-if.workspace = true + +[features] +default = ["with-spike"] + +debug = ["platform/debug"] + +std = [ + "platform/std", + "platform/vfs-device-console", + "platform/memory", + "platform/bounds-checks", +] + +with-spike = ["platform/with-spike"] +# with-jolt = ["platform/with-jolt"] # Uncomment when jolt-platform exists + +[target.'cfg(target_os = "none")'.dependencies] +platform = { workspace = true, features = ["memory"] } diff --git a/examples/wavelet-transform/src/lib.rs b/examples/wavelet-transform/src/lib.rs new file mode 100644 index 0000000..27322b0 --- /dev/null +++ b/examples/wavelet-transform/src/lib.rs @@ -0,0 +1,271 
@@ +//! Parallel Wavelet Transform Implementation +//! +//! Haar wavelet transform with level-based partitioning. +//! Each level's coefficient pairs can be computed independently. + +#![no_std] + +/// Haar wavelet coefficients at a single level. +/// Average and detail coefficients. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct HaarCoeff { + pub average: i32, + pub detail: i32, +} + +impl HaarCoeff { + pub fn new(average: i32, detail: i32) -> Self { + Self { average, detail } + } +} + +/// Single Haar transform step on two adjacent values. +/// Returns (average, detail) = ((a+b)/2, (a-b)/2). Integer division +/// truncates, so the step is exactly invertible only when a and b have the same parity. +#[inline] +pub fn haar_step(a: i32, b: i32) -> HaarCoeff { + // Integer arithmetic only - no floating point on the target + HaarCoeff { + average: (a + b) / 2, + detail: (a - b) / 2, + } +} + +/// Inverse Haar step: reconstruct original values from coefficients. +#[inline] +pub fn haar_inverse_step(coeff: HaarCoeff) -> (i32, i32) { + let a = coeff.average + coeff.detail; + let b = coeff.average - coeff.detail; + (a, b) +} + +/// Single level of Haar transform. +/// Processes pairs of values independently (parallel-friendly). +pub fn haar_level(input: &[i32], averages: &mut [i32], details: &mut [i32]) { + assert_eq!(input.len(), averages.len() * 2); + assert_eq!(averages.len(), details.len()); + + // Each pair is independent - can be parallelized + for i in 0..averages.len() { + let coeff = haar_step(input[i * 2], input[i * 2 + 1]); + averages[i] = coeff.average; + details[i] = coeff.detail; + } +} + +/// Multi-level Haar wavelet transform. +/// Stores all detail coefficients at each level plus the final average. +pub struct HaarTransform<const N: usize> { + /// Detail coefficients at each level (level 0 = finest) + pub details: [[i32; N]; 8], // 8 levels support up to 256-element input + /// Working buffer of averages; index 0 holds the coarsest average after transform + pub averages: [i32; N], + /// Number of levels computed + pub num_levels: usize, +} + +impl<const N: usize> HaarTransform<N> { + pub fn new() -> Self { + Self { + details: [[0; N]; 8], + averages: [0; N], + num_levels: 0, + } + } + + /// Compute full Haar transform. + /// Each level halves the data size. + pub fn transform(&mut self, input: &[i32]) { + let n = input.len(); + assert!(n.is_power_of_two() && n <= N && n <= 256); + + // Copy input to working buffer + for (i, &val) in input.iter().enumerate() { + self.averages[i] = val; + } + + let mut current_len = n; + let mut level = 0; + + // Process each level until we reach single value + while current_len > 1 { + let half_len = current_len / 2; + + // Compute this level's coefficients + // TODO: With threading, each pair in this level is independent + for i in 0..half_len { + let coeff = haar_step(self.averages[i * 2], self.averages[i * 2 + 1]); + self.details[level][i] = coeff.detail; + // Store averages for next level (in-place) + self.averages[i] = coeff.average; + } + + current_len = half_len; + level += 1; + } + + self.num_levels = level; + } + + /// Reconstruct the original signal; original_len must equal the length passed to transform.
+ pub fn inverse(&self, output: &mut [i32], original_len: usize) { + assert!(original_len == 1usize << self.num_levels && output.len() >= original_len); + + // Start with final average(s) + output[0] = self.averages[0]; + let mut current_len = 1; + + // Reconstruct each level from coarsest to finest + for level in (0..self.num_levels).rev() { + // Each reconstruction step doubles the data + // TODO: With threading, each pair reconstruction is independent + for i in (0..current_len).rev() { + let coeff = HaarCoeff::new(output[i], self.details[level][i]); + let (a, b) = haar_inverse_step(coeff); + output[i * 2] = a; + output[i * 2 + 1] = b; + } + current_len *= 2; + } + } +} + +impl<const N: usize> Default for HaarTransform<N> { + fn default() -> Self { + Self::new() + } +} + +/// Batch transform multiple signals. +/// Each signal is completely independent (embarrassingly parallel). +pub fn batch_transform<const N: usize>( + inputs: &[[i32; N]], + transforms: &mut [HaarTransform<N>], +) { + assert_eq!(inputs.len(), transforms.len()); + + // Each transform is independent + for (input, transform) in inputs.iter().zip(transforms.iter_mut()) { + transform.transform(input); + } +} + +/// Compute energy at each decomposition level. +/// Useful for signal analysis. +pub fn level_energy(details: &[i32], len: usize) -> i64 { + let mut energy: i64 = 0; + for &d in details.iter().take(len) { + energy += (d as i64) * (d as i64); + } + energy +} + +/// Simple thresholding for denoising. +/// Zeros out detail coefficients below threshold. +pub fn threshold_details(details: &mut [i32], len: usize, threshold: i32) { + for d in details.iter_mut().take(len) { + if d.abs() < threshold { + *d = 0; + } + } +} + +/// 2D Haar transform for images (separable). +/// Applies 1D transform to rows, then columns. +pub fn haar_2d_level( + input: &[[i32; 8]; 8], + ll: &mut [[i32; 4]; 4], // Low-low (approximation) + lh: &mut [[i32; 4]; 4], // Low-high (horizontal detail) + hl: &mut [[i32; 4]; 4], // High-low (vertical detail) + hh: &mut [[i32; 4]; 4], // High-high (diagonal detail) +) { + // Temporary storage for row transform + let mut row_avg = [[0i32; 4]; 8]; + let mut row_det = [[0i32; 4]; 8]; + + // Step 1: Transform rows (can be parallelized) + for i in 0..8 { + for j in 0..4 { + let coeff = haar_step(input[i][j * 2], input[i][j * 2 + 1]); + row_avg[i][j] = coeff.average; + row_det[i][j] = coeff.detail; + } + } + + // Step 2: Transform columns (can be parallelized) + for j in 0..4 { + for i in 0..4 { + // Transform average columns -> LL, LH + let coeff_avg = haar_step(row_avg[i * 2][j], row_avg[i * 2 + 1][j]); + ll[i][j] = coeff_avg.average; + lh[i][j] = coeff_avg.detail; + + // Transform detail columns -> HL, HH + let coeff_det = haar_step(row_det[i * 2][j], row_det[i * 2 + 1][j]); + hl[i][j] = coeff_det.average; + hh[i][j] = coeff_det.detail; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_haar_step() { + let coeff = haar_step(10, 6); + assert_eq!(coeff.average, 8); + assert_eq!(coeff.detail, 2); + + let (a, b) = haar_inverse_step(coeff); + assert_eq!(a, 10); + assert_eq!(b, 6); + } + + #[test] + fn test_haar_level() { + let input = [4, 2, 8, 6]; + let mut avg = [0; 2]; + let mut det = [0; 2]; + + haar_level(&input, &mut avg, &mut det); + + assert_eq!(avg, [3, 7]); + assert_eq!(det, [1, 1]); + } + + #[test] + fn test_full_transform_inverse() { + let input = [1, 3, 5, 7, 9, 11, 13, 15]; + let mut transform = HaarTransform::<8>::new(); + + transform.transform(&input); + + let mut output = [0i32; 8]; + transform.inverse(&mut output, 8); + + 
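// Round trip is exact for this input: every pair has an even sum, + // so the truncating haar_step loses no information (see haar_step docs). + 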
assert_eq!(output, input); + } + + #[test] + fn test_dc_signal() { + // Constant signal should have zero details + let input = [4, 4, 4, 4]; + let mut transform = HaarTransform::<4>::new(); + + transform.transform(&input); + + assert_eq!(transform.averages[0], 4); + assert_eq!(transform.details[0][0], 0); + assert_eq!(transform.details[0][1], 0); + assert_eq!(transform.details[1][0], 0); + } + + #[test] + fn test_level_energy() { + let details = [3, 4, 0, 0]; + let energy = level_energy(&details, 2); + assert_eq!(energy, 9 + 16); // 3² + 4² + } +} diff --git a/examples/wavelet-transform/src/main.rs b/examples/wavelet-transform/src/main.rs new file mode 100644 index 0000000..58e8ff0 --- /dev/null +++ b/examples/wavelet-transform/src/main.rs @@ -0,0 +1,153 @@ +//! Wavelet Transform Example +//! +//! Demonstrates level-partitioned Haar wavelet transform. + +#![cfg_attr(target_os = "none", no_std)] +#![no_main] + +use wavelet_transform::{ + batch_transform, haar_2d_level, level_energy, threshold_details, HaarTransform, +}; + +cfg_if::cfg_if! { + if #[cfg(target_os = "none")] { + use platform::println; + } else { + use std::println; + } +} + +#[unsafe(no_mangle)] +fn main() -> ! { + println!("=== Wavelet Transform Example ==="); + + // Test 1: Simple 1D transform + println!("\nTest 1: 1D Haar Transform"); + let input = [1, 3, 5, 7, 9, 11, 13, 15]; + let mut transform = HaarTransform::<8>::new(); + + transform.transform(&input); + + println!(" Input: {:?}", input); + println!(" Levels computed: {}", transform.num_levels); + println!(" Final average: {}", transform.averages[0]); + println!(" Level 0 details: [{}, {}, {}, {}]", + transform.details[0][0], transform.details[0][1], + transform.details[0][2], transform.details[0][3]); + println!(" Level 1 details: [{}, {}]", + transform.details[1][0], transform.details[1][1]); + println!(" Level 2 details: [{}]", transform.details[2][0]); + + // Test 2: Round-trip reconstruction + println!("\nTest 2: Reconstruction"); + let mut output = [0i32; 8]; + transform.inverse(&mut output, 8); + + println!(" Reconstructed: {:?}", output); + let matches = output == input; + println!(" Perfect reconstruction: {}", if matches { "PASS" } else { "FAIL" }); + + // Test 3: Signal with edge + println!("\nTest 3: Edge Detection"); + let edge_signal = [0, 0, 0, 0, 10, 10, 10, 10]; + let mut edge_transform = HaarTransform::<8>::new(); + + edge_transform.transform(&edge_signal); + + println!(" Input (step edge): {:?}", edge_signal); + println!(" Level 0 details: [{}, {}, {}, {}]", + edge_transform.details[0][0], edge_transform.details[0][1], + edge_transform.details[0][2], edge_transform.details[0][3]); + // The midpoint step falls between level-0/level-1 pairs, so it only shows up at the coarsest level + println!(" Edge detected at level 2 (coarsest): {}", edge_transform.details[2][0] != 0); + + // Test 4: Energy computation + println!("\nTest 4: Level Energy"); + let energy_0 = level_energy(&edge_transform.details[0], 4); + let energy_1 = level_energy(&edge_transform.details[1], 2); + let energy_2 = level_energy(&edge_transform.details[2], 1); + + println!(" Level 0 energy: {}", energy_0); + println!(" Level 1 energy: {}", energy_1); + println!(" Level 2 energy: {}", energy_2); + + // Test 5: Denoising via thresholding + println!("\nTest 5: Denoising"); + let noisy = [10, 11, 9, 12, 50, 51, 49, 52]; + let mut noisy_transform = HaarTransform::<8>::new(); + noisy_transform.transform(&noisy); + + println!(" Noisy input: {:?}", noisy); + println!(" Before threshold, level 0: [{}, {}, {}, {}]", + noisy_transform.details[0][0], noisy_transform.details[0][1], + 
noisy_transform.details[0][2], noisy_transform.details[0][3]); + + // Threshold small details (noise) + threshold_details(&mut noisy_transform.details[0], 4, 2); + + println!(" After threshold (t=2): [{}, {}, {}, {}]", + noisy_transform.details[0][0], noisy_transform.details[0][1], + noisy_transform.details[0][2], noisy_transform.details[0][3]); + + let mut denoised = [0i32; 8]; + noisy_transform.inverse(&mut denoised, 8); + println!(" Denoised output: {:?}", denoised); + + // Test 6: Batch transform (parallel-friendly) + println!("\nTest 6: Batch Transform (4 signals)"); + let signals: [[i32; 8]; 4] = [ + [1, 2, 3, 4, 5, 6, 7, 8], + [8, 7, 6, 5, 4, 3, 2, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1], + ]; + + let mut transforms = [ + HaarTransform::<8>::new(), + HaarTransform::<8>::new(), + HaarTransform::<8>::new(), + HaarTransform::<8>::new(), + ]; + + batch_transform(&signals, &mut transforms); + + for (i, t) in transforms.iter().enumerate() { + println!(" Signal {}: avg={}, level0_energy={}", + i, t.averages[0], + level_energy(&t.details[0], 4)); + } + + // Test 7: 2D transform (for images) + println!("\nTest 7: 2D Haar Transform (8x8 block)"); + let image: [[i32; 8]; 8] = [ // quadrant edges between index 2 and 3, inside a coefficient pair, so this level sees them + [10, 10, 10, 20, 20, 20, 20, 20], + [10, 10, 10, 20, 20, 20, 20, 20], + [10, 10, 10, 20, 20, 20, 20, 20], + [30, 30, 30, 40, 40, 40, 40, 40], + [30, 30, 30, 40, 40, 40, 40, 40], + [30, 30, 30, 40, 40, 40, 40, 40], + [30, 30, 30, 40, 40, 40, 40, 40], + [30, 30, 30, 40, 40, 40, 40, 40], + ]; + + let mut ll = [[0i32; 4]; 4]; + let mut lh = [[0i32; 4]; 4]; + let mut hl = [[0i32; 4]; 4]; + let mut hh = [[0i32; 4]; 4]; + + haar_2d_level(&image, &mut ll, &mut lh, &mut hl, &mut hh); + + println!(" LL (approximation) corner: {}", ll[0][0]); + println!(" LH (horizontal) corner: {}", lh[0][0]); + println!(" HL (vertical) corner: {}", hl[0][0]); + println!(" HH (diagonal) corner: {}", hh[0][0]); + + // Check that edges are detected (an edge aligned to a pair boundary would yield zero details at this level) + let has_vertical_edge = hl.iter().any(|row| row.iter().any(|&x| x != 0)); + let has_horizontal_edge = lh.iter().any(|row| row.iter().any(|&x| x != 0)); + println!(" Vertical edge detected: {}", has_vertical_edge); + println!(" Horizontal edge detected: {}", has_horizontal_edge); + + println!("\n=== Wavelet Transform Example Complete ==="); + + platform::exit(0) +}
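The TODO comments in transform and inverse mark the per-pair parallelism this patch leaves single-threaded. Below is a minimal hosted-Rust sketch of how one level could be fanned out over workers; std::thread::scope and the helper name haar_level_parallel are stand-ins of mine, not APIs from this patch, since the ZeroOS cooperative-thread spawn/join interface is not shown here. Only the buffer-partitioning shape is meant to carry over.

// Hosted-Rust sketch only: std::thread::scope stands in for ZeroOS
// cooperative threads. Each worker owns disjoint chunks of the output
// buffers, so no synchronization is needed beyond the implicit join.
use std::thread;

fn haar_level_parallel(input: &[i32], averages: &mut [i32], details: &mut [i32], num_workers: usize) {
    assert_eq!(input.len(), averages.len() * 2);
    assert_eq!(averages.len(), details.len());

    // Pairs per worker, rounded up; at least 1 so chunks() never gets 0
    let chunk = averages.len().div_ceil(num_workers.max(1)).max(1);
    thread::scope(|s| {
        // Matching disjoint chunks: `chunk` output slots per `2 * chunk` inputs
        for ((avg_chunk, det_chunk), in_chunk) in averages
            .chunks_mut(chunk)
            .zip(details.chunks_mut(chunk))
            .zip(input.chunks(chunk * 2))
        {
            s.spawn(move || {
                for i in 0..avg_chunk.len() {
                    // Same truncating integer Haar step as the library
                    let (a, b) = (in_chunk[i * 2], in_chunk[i * 2 + 1]);
                    avg_chunk[i] = (a + b) / 2;
                    det_chunk[i] = (a - b) / 2;
                }
            });
        }
    });
}

fn main() {
    let input = [1, 3, 5, 7, 9, 11, 13, 15];
    let (mut avg, mut det) = ([0i32; 4], [0i32; 4]);
    haar_level_parallel(&input, &mut avg, &mut det, 2);
    assert_eq!(avg, [2, 6, 10, 14]);
    assert_eq!(det, [-1, -1, -1, -1]);
    println!("avg={avg:?} det={det:?}");
}

Because each haar_step reads only its own input pair and writes only its own output slot, chunks_mut-style partitioning needs no locks; the same shape applies to the inverse loop and, at a coarser grain, to batch_transform, where whole signals are the unit of work.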