diff --git a/.gitignore b/.gitignore index c6e6aa2049..4e1f821464 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,11 @@ profile.json.gz # test fixtures benchmarks/fixtures + + +crates/toolchain/tests/rv32im-test-vectors/tests/* +*.o +*.a +*.s +*.txt +riscv/* \ No newline at end of file diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index 3ffbfb74e0..17f7d228eb 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -123,6 +123,11 @@ pub const OPENVM_DEFAULT_INIT_FILE_NAME: &str = "openvm_init.rs"; const DEFAULT_U8_BLOCK_SIZE: usize = 4; const DEFAULT_NATIVE_BLOCK_SIZE: usize = 1; +/// The constant block size used for memory accesses when access adapters are disabled. +/// All memory accesses for address spaces 1-3 must use this block size. +/// This is also the block size used by the Boundary AIR for memory bus interactions. +pub const CONST_BLOCK_SIZE: usize = 4; + /// Trait for generating a init.rs file that contains a call to moduli_init!, /// complex_init!, sw_init! with the supported moduli and curves. /// Should be implemented by all VM config structs. @@ -183,6 +188,11 @@ pub struct MemoryConfig { pub decomp: usize, /// Maximum N AccessAdapter AIR to support. pub max_access_adapter_n: usize, + /// Whether access adapters are enabled. When disabled, all memory accesses must be of the + /// standard block size (e.g., 4 for address spaces 1-3). This removes the need for access + /// adapter AIRs and simplifies the memory system. 
+ #[new(value = "true")] + pub access_adapters_enabled: bool, } impl Default for MemoryConfig { @@ -194,7 +204,15 @@ impl Default for MemoryConfig { addr_spaces[RV32_MEMORY_AS as usize].num_cells = MAX_CELLS; addr_spaces[PUBLIC_VALUES_AS as usize].num_cells = DEFAULT_MAX_NUM_PUBLIC_VALUES; addr_spaces[NATIVE_AS as usize].num_cells = MAX_CELLS; - Self::new(3, addr_spaces, POINTER_MAX_BITS, 29, 17, 32) + Self { + addr_space_height: 3, + addr_spaces, + pointer_max_bits: POINTER_MAX_BITS, + timestamp_max_bits: 29, + decomp: 17, + max_access_adapter_n: 32, + access_adapters_enabled: true, + } } } @@ -245,6 +263,36 @@ impl MemoryConfig { .map(|addr_sp| log2_strict_usize(addr_sp.min_block_size) as u8) .collect() } + + /// Returns true if the Native address space (AS 4) is used. + /// Native AS is considered "used" if it has any allocated cells. + pub fn is_native_as_used(&self) -> bool { + self.addr_spaces + .get(NATIVE_AS as usize) + .is_some_and(|config| config.num_cells > 0) + } + + /// Disables access adapters. When disabled, all memory accesses for address spaces 1-3 + /// must use the constant block size (4). Access adapters will only be used for + /// address space 4 (Native) if it is enabled. + pub fn without_access_adapters(mut self) -> Self { + self.access_adapters_enabled = false; + self + } + + /// Enables access adapters. This is the default behavior. + pub fn with_access_adapters(mut self) -> Self { + self.access_adapters_enabled = true; + self + } + + /// Automatically sets `access_adapters_enabled` based on whether Native AS is used. + /// If Native AS is not used, access adapters are disabled since all other address spaces + /// use a fixed block size of 4. + pub fn with_auto_access_adapters(mut self) -> Self { + self.access_adapters_enabled = self.is_native_as_used(); + self + } } /// System-level configuration for the virtual machine. 
Contains all configuration parameters that @@ -375,6 +423,7 @@ impl SystemConfig { + num_memory_airs( self.continuation_enabled, self.memory_config.max_access_adapter_n, + self.memory_config.access_adapters_enabled, ) } @@ -384,6 +433,33 @@ impl SystemConfig { false => 1, } } + + /// Disables access adapters. When disabled, all memory accesses for address spaces 1-3 + /// must use the constant block size (4). This simplifies the memory system by removing + /// access adapter AIRs. + pub fn without_access_adapters(mut self) -> Self { + self.memory_config.access_adapters_enabled = false; + self + } + + /// Enables access adapters. This is the default behavior. + pub fn with_access_adapters(mut self) -> Self { + self.memory_config.access_adapters_enabled = true; + self + } + + /// Automatically sets `access_adapters_enabled` based on whether Native AS is used. + /// If Native AS is not used, access adapters are disabled since all other address spaces + /// use a fixed block size of 4. + pub fn with_auto_access_adapters(mut self) -> Self { + self.memory_config = self.memory_config.with_auto_access_adapters(); + self + } + + /// Returns true if access adapters are enabled. 
+ pub fn access_adapters_enabled(&self) -> bool { + self.memory_config.access_adapters_enabled + } } impl Default for SystemConfig { diff --git a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs index 3429177d11..d75dc2c46b 100644 --- a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs +++ b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs @@ -105,6 +105,7 @@ pub struct MemoryCtx { pub boundary_idx: usize, pub merkle_tree_index: Option, pub adapter_offset: usize, + access_adapters_enabled: bool, continuations_enabled: bool, chunk: u32, chunk_bits: u32, @@ -128,6 +129,7 @@ impl MemoryCtx { boundary_idx: config.memory_boundary_air_id(), merkle_tree_index: config.memory_merkle_air_id(), adapter_offset: config.access_adapter_air_id_offset(), + access_adapters_enabled: config.memory_config.access_adapters_enabled, chunk, chunk_bits, memory_dimensions, @@ -210,6 +212,11 @@ impl MemoryCtx { size_bits: u32, num: u32, ) { + // Skip if access adapters are disabled + if !self.access_adapters_enabled { + return; + } + debug_assert!((address_space as usize) < self.min_block_size_bits.len()); // SAFETY: address_space passed is usually a hardcoded constant or derived from an diff --git a/crates/vm/src/arch/execution_mode/metered_cost.rs b/crates/vm/src/arch/execution_mode/metered_cost.rs index 925bd25af2..c92965ad3f 100644 --- a/crates/vm/src/arch/execution_mode/metered_cost.rs +++ b/crates/vm/src/arch/execution_mode/metered_cost.rs @@ -18,6 +18,7 @@ pub const DEFAULT_MAX_COST: u64 = DEFAULT_MAX_SEGMENTS * DEFAULT_SEGMENT_MAX_CEL pub struct AccessAdapterCtx { min_block_size_bits: Vec, idx_offset: usize, + enabled: bool, } impl AccessAdapterCtx { @@ -25,6 +26,7 @@ impl AccessAdapterCtx { Self { min_block_size_bits: config.memory_config.min_block_size_bits(), idx_offset: config.access_adapter_air_id_offset(), + enabled: config.memory_config.access_adapters_enabled, } } @@ -36,6 +38,11 @@ impl 
AccessAdapterCtx { size_bits: u32, widths: &[usize], ) { + // Skip if access adapters are disabled + if !self.enabled { + return; + } + debug_assert!((address_space as usize) < self.min_block_size_bits.len()); // SAFETY: address_space passed is usually a hardcoded constant or derived from an diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 8b0797dcf6..a9c89fc2ea 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -58,21 +58,26 @@ impl AccessAdapterInventory { memory_bus: MemoryBus, memory_config: MemoryConfig, ) -> Self { - let rc = range_checker; - let mb = memory_bus; - let tmb = memory_config.timestamp_max_bits; - let maan = memory_config.max_access_adapter_n; - assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); - let chips: Vec<_> = [ - Self::create_access_adapter_chip::<2>(rc.clone(), mb, tmb, maan), - Self::create_access_adapter_chip::<4>(rc.clone(), mb, tmb, maan), - Self::create_access_adapter_chip::<8>(rc.clone(), mb, tmb, maan), - Self::create_access_adapter_chip::<16>(rc.clone(), mb, tmb, maan), - Self::create_access_adapter_chip::<32>(rc.clone(), mb, tmb, maan), - ] - .into_iter() - .flatten() - .collect(); + // Only create adapter chips if access adapters are enabled + let chips: Vec<_> = if memory_config.access_adapters_enabled { + let rc = range_checker; + let mb = memory_bus; + let tmb = memory_config.timestamp_max_bits; + let maan = memory_config.max_access_adapter_n; + assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); + [ + Self::create_access_adapter_chip::<2>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<4>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<8>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<16>(rc.clone(), mb, tmb, maan), + Self::create_access_adapter_chip::<32>(rc.clone(), mb, tmb, maan), + ] + .into_iter() + .flatten() + .collect() + } else { + Vec::new() + }; Self { 
memory_config, chips, diff --git a/crates/vm/src/system/memory/mod.rs b/crates/vm/src/system/memory/mod.rs index 411e7a5473..8c3f48c7f0 100644 --- a/crates/vm/src/system/memory/mod.rs +++ b/crates/vm/src/system/memory/mod.rs @@ -118,20 +118,24 @@ impl MemoryAirInventory { ); MemoryInterfaceAirs::Volatile { boundary } }; - // Memory access adapters - let lt_air = IsLtSubAir::new(range_bus, mem_config.timestamp_max_bits); - let maan = mem_config.max_access_adapter_n; - assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); - let access_adapters: Vec> = [ - Arc::new(AccessAdapterAir::<2> { memory_bus, lt_air }) as AirRef, - Arc::new(AccessAdapterAir::<4> { memory_bus, lt_air }) as AirRef, - Arc::new(AccessAdapterAir::<8> { memory_bus, lt_air }) as AirRef, - Arc::new(AccessAdapterAir::<16> { memory_bus, lt_air }) as AirRef, - Arc::new(AccessAdapterAir::<32> { memory_bus, lt_air }) as AirRef, - ] - .into_iter() - .take(log2_strict_usize(maan)) - .collect(); + // Memory access adapters - only create if enabled + let access_adapters: Vec> = if mem_config.access_adapters_enabled { + let lt_air = IsLtSubAir::new(range_bus, mem_config.timestamp_max_bits); + let maan = mem_config.max_access_adapter_n; + assert!(matches!(maan, 2 | 4 | 8 | 16 | 32)); + [ + Arc::new(AccessAdapterAir::<2> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<4> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<8> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<16> { memory_bus, lt_air }) as AirRef, + Arc::new(AccessAdapterAir::<32> { memory_bus, lt_air }) as AirRef, + ] + .into_iter() + .take(log2_strict_usize(maan)) + .collect() + } else { + Vec::new() + }; Self { bridge, @@ -159,7 +163,16 @@ impl MemoryAirInventory { /// This is O(1) and returns the length of /// [`MemoryAirInventory::into_airs`]. 
-pub fn num_memory_airs(is_persistent: bool, max_access_adapter_n: usize) -> usize { - // boundary + { merkle if is_persistent } + access_adapters - 1 + usize::from(is_persistent) + log2_strict_usize(max_access_adapter_n) +pub fn num_memory_airs( + is_persistent: bool, + max_access_adapter_n: usize, + access_adapters_enabled: bool, +) -> usize { + // boundary + { merkle if is_persistent } + access_adapters (if enabled) + let num_adapters = if access_adapters_enabled { + log2_strict_usize(max_access_adapter_n) + } else { + 0 + }; + 1 + usize::from(is_persistent) + num_adapters } diff --git a/crates/vm/src/system/memory/persistent.rs b/crates/vm/src/system/memory/persistent.rs index eeb22cbfd6..6d7d5dedf1 100644 --- a/crates/vm/src/system/memory/persistent.rs +++ b/crates/vm/src/system/memory/persistent.rs @@ -22,7 +22,7 @@ use tracing::instrument; use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues}; use crate::{ - arch::{hasher::Hasher, ADDR_SPACE_OFFSET}, + arch::{hasher::Hasher, ADDR_SPACE_OFFSET, CONST_BLOCK_SIZE}, system::memory::{ dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage, TimestampedEquipartition, @@ -109,16 +109,27 @@ impl Air for PersistentBoundaryA local.expand_direction * local.expand_direction, ); - self.memory_bus - .send( - MemoryAddress::new( - local.address_space, - local.leaf_label * AB::F::from_canonical_usize(CHUNK), - ), - local.values.to_vec(), - local.timestamp, - ) - .eval(builder, local.expand_direction); + // Send memory bus interactions in CONST_BLOCK_SIZE chunks. + // This sends CHUNK/CONST_BLOCK_SIZE messages, each with CONST_BLOCK_SIZE values. + // For CHUNK=8 and CONST_BLOCK_SIZE=4, this sends 2 messages of 4 values each. 
+ let base_pointer: AB::Expr = local.leaf_label.into() * AB::F::from_canonical_usize(CHUNK); + for block_idx in 0..(CHUNK / CONST_BLOCK_SIZE) { + let block_start = block_idx * CONST_BLOCK_SIZE; + let block_values: Vec = local.values[block_start..block_start + CONST_BLOCK_SIZE] + .iter() + .map(|&v| v.into()) + .collect(); + self.memory_bus + .send( + MemoryAddress::new( + local.address_space, + base_pointer.clone() + AB::F::from_canonical_usize(block_start), + ), + block_values, + local.timestamp, + ) + .eval(builder, local.expand_direction); + } } } diff --git a/extensions/bigint/circuit/cuda/src/bigint.cu b/extensions/bigint/circuit/cuda/src/bigint.cu index a116d251a5..cbb80eb2c7 100644 --- a/extensions/bigint/circuit/cuda/src/bigint.cu +++ b/extensions/bigint/circuit/cuda/src/bigint.cu @@ -14,6 +14,8 @@ using namespace riscv; constexpr size_t INT256_NUM_LIMBS = 32; +// Number of 4-byte blocks for 256-bit operations (must match CONST_BLOCK_SIZE in Rust) +constexpr size_t INT256_READ_BLOCKS = 8; // 32 / 4 = 8 using BaseAlu256CoreRecord = BaseAluCoreRecord<32>; using BaseAlu256Core = BaseAluCore<32>; @@ -41,16 +43,16 @@ template using BranchLessThan256CoreCols = BranchLessThanCoreCols; +// READ_BLOCKS = WRITE_BLOCKS = 8 (8 blocks of 4 bytes each) +using Rv32HeapAdapterExecutor256 = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS>; template struct BaseAlu256Cols { - Rv32HeapAdapterCols adapter; + Rv32HeapAdapterCols adapter; BaseAlu256CoreCols core; }; struct BaseAlu256Record { - Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter; + Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter; BaseAlu256CoreRecord core; }; @@ -117,16 +119,16 @@ extern "C" int _alu256_tracegen( } // Heap branch adapter instantiation for 256-bit operations -// NUM_READS = 2, READ_SIZE = INT256_NUM_LIMBS (32 bytes) -using Rv32HeapBranchAdapter256 = 
Rv32HeapBranchAdapter<2, INT256_NUM_LIMBS>; +// NUM_READS = 2, READ_SIZE = INT256_NUM_LIMBS (32 bytes), READ_BLOCKS = 8 +using Rv32HeapBranchAdapter256 = Rv32HeapBranchAdapter<2, INT256_NUM_LIMBS, INT256_READ_BLOCKS>; template struct BranchEqual256Cols { - Rv32HeapBranchAdapterCols adapter; + Rv32HeapBranchAdapterCols adapter; BranchEqual256CoreCols core; }; struct BranchEqual256Record { - Rv32HeapBranchAdapterRecord<2> adapter; + Rv32HeapBranchAdapterRecord<2, INT256_READ_BLOCKS> adapter; BranchEqual256CoreRecord core; }; @@ -193,12 +195,12 @@ extern "C" int _branch_equal256_tracegen( } template struct LessThan256Cols { - Rv32HeapAdapterCols adapter; + Rv32HeapAdapterCols adapter; LessThan256CoreCols core; }; struct LessThan256Record { - Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter; + Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter; LessThan256CoreRecord core; }; @@ -218,7 +220,7 @@ __global__ void less_than256_tracegen( if (idx < d_records.len()) { auto const &rec = d_records[idx]; - Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter( + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter( pointer_max_bits, VariableRangeChecker(d_range_checker_ptr, range_checker_bins), BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits), @@ -265,12 +267,12 @@ extern "C" int _less_than256_tracegen( } template struct BranchLessThan256Cols { - Rv32HeapBranchAdapterCols adapter; + Rv32HeapBranchAdapterCols adapter; BranchLessThan256CoreCols core; }; struct BranchLessThan256Record { - Rv32HeapBranchAdapterRecord<2> adapter; + Rv32HeapBranchAdapterRecord<2, INT256_READ_BLOCKS> adapter; BranchLessThan256CoreRecord core; }; @@ -337,12 +339,12 @@ extern "C" int _branch_less_than256_tracegen( } template struct Shift256Cols { - Rv32HeapAdapterCols adapter; + Rv32HeapAdapterCols adapter; Shift256CoreCols core; }; struct 
Shift256Record { - Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter; + Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter; Shift256CoreRecord core; }; @@ -362,7 +364,7 @@ __global__ void shift256_tracegen( if (idx < d_records.len()) { auto const &rec = d_records[idx]; - Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter( + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter( pointer_max_bits, VariableRangeChecker(d_range_checker_ptr, range_checker_bins), BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits), @@ -412,12 +414,12 @@ extern "C" int _shift256_tracegen( } template struct Multiplication256Cols { - Rv32HeapAdapterCols adapter; + Rv32HeapAdapterCols adapter; Multiplication256CoreCols core; }; struct Multiplication256Record { - Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter; + Rv32HeapAdapterRecord<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter; Multiplication256CoreRecord core; }; @@ -439,7 +441,7 @@ __global__ void multiplication256_tracegen( if (idx < d_records.len()) { auto const &rec = d_records[idx]; - Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS> adapter( + Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS, INT256_READ_BLOCKS, INT256_READ_BLOCKS> adapter( pointer_max_bits, VariableRangeChecker(d_range_checker_ptr, range_checker_bins), BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits), diff --git a/extensions/bigint/circuit/src/base_alu.rs b/extensions/bigint/circuit/src/base_alu.rs index 6a8e49a239..04990c9016 100644 --- a/extensions/bigint/circuit/src/base_alu.rs +++ b/extensions/bigint/circuit/src/base_alu.rs @@ -18,7 +18,7 @@ use openvm_rv32im_transpiler::BaseAluOpcode; use openvm_stark_backend::p3_field::PrimeField32; use crate::{ - common::{bytes_to_u64_array, u64_array_to_bytes}, 
+ common::{bytes_to_u64_array, u64_array_to_bytes, vm_read_256, vm_write_256}, Rv32BaseAlu256Executor, INT256_NUM_LIMBS, }; @@ -142,12 +142,12 @@ unsafe fn execute_e12_impl( let rs1_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); let rd_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); - let rs1 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // Read 32 bytes using 8×4-byte reads to avoid access adapters + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let rd = ::compute(rs1, rs2); - exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + // Write 32 bytes using 8×4-byte writes to avoid access adapters + vm_write_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); } diff --git a/extensions/bigint/circuit/src/branch_eq.rs b/extensions/bigint/circuit/src/branch_eq.rs index 4732f6f9a7..6ad7ddc65f 100644 --- a/extensions/bigint/circuit/src/branch_eq.rs +++ b/extensions/bigint/circuit/src/branch_eq.rs @@ -14,7 +14,10 @@ use openvm_rv32im_circuit::BranchEqualExecutor; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::p3_field::PrimeField32; -use crate::{common::bytes_to_u64_array, Rv32BranchEqual256Executor, INT256_NUM_LIMBS}; +use crate::{ + common::{bytes_to_u64_array, vm_read_256}, + Rv32BranchEqual256Executor, INT256_NUM_LIMBS, +}; type AdapterExecutor = Rv32HeapBranchAdapterExecutor<2, INT256_NUM_LIMBS>; @@ -131,10 +134,9 @@ unsafe fn execute_e12_impl(RV32_REGISTER_AS, pre_compute.a as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); - let rs1 = - 
exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // Use chunked reads: 8×4-byte reads instead of 1×32-byte read + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let cmp_result = u256_eq(rs1, rs2); if cmp_result ^ IS_NE { pc = (pc as isize + pre_compute.imm) as u32; diff --git a/extensions/bigint/circuit/src/branch_lt.rs b/extensions/bigint/circuit/src/branch_lt.rs index 8dc294d70d..40103fb738 100644 --- a/extensions/bigint/circuit/src/branch_lt.rs +++ b/extensions/bigint/circuit/src/branch_lt.rs @@ -18,7 +18,7 @@ use openvm_rv32im_transpiler::BranchLessThanOpcode; use openvm_stark_backend::p3_field::PrimeField32; use crate::{ - common::{i256_lt, u256_lt}, + common::{i256_lt, u256_lt, vm_read_256}, Rv32BranchLessThan256Executor, INT256_NUM_LIMBS, }; @@ -139,10 +139,9 @@ unsafe fn execute_e12_impl(RV32_REGISTER_AS, pre_compute.a as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); - let rs1 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // Use chunked reads: 8×4-byte reads instead of 1×32-byte read + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let cmp_result = OP::compute(rs1, rs2); if cmp_result { pc = (pc as isize + pre_compute.imm) as u32; diff --git a/extensions/bigint/circuit/src/common.rs b/extensions/bigint/circuit/src/common.rs index 329cf1d479..9a245c5353 100644 --- a/extensions/bigint/circuit/src/common.rs +++ b/extensions/bigint/circuit/src/common.rs @@ -1,5 +1,51 @@ +use openvm_circuit::{ + arch::{ExecutionCtxTrait, VmExecState, CONST_BLOCK_SIZE}, + system::memory::online::GuestMemory, +}; + 
use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; +/// Number of blocks per 256-bit (32-byte) read/write. +pub const NUM_BLOCKS: usize = INT256_NUM_LIMBS / CONST_BLOCK_SIZE; + +/// Read INT256_NUM_LIMBS bytes from memory using NUM_BLOCKS reads of CONST_BLOCK_SIZE bytes each. +/// This ensures all memory bus interactions use the constant 4-byte block size. +#[inline(always)] +pub fn vm_read_256( + exec_state: &mut VmExecState, + addr_space: u32, + ptr: u32, +) -> [u8; INT256_NUM_LIMBS] { + let mut result = [0u8; INT256_NUM_LIMBS]; + for i in 0..NUM_BLOCKS { + let chunk: [u8; CONST_BLOCK_SIZE] = exec_state + .vm_read::(addr_space, ptr + (i * CONST_BLOCK_SIZE) as u32); + result[i * CONST_BLOCK_SIZE..(i + 1) * CONST_BLOCK_SIZE].copy_from_slice(&chunk); + } + result +} + +/// Write INT256_NUM_LIMBS bytes to memory using NUM_BLOCKS writes of CONST_BLOCK_SIZE bytes each. +/// This ensures all memory bus interactions use the constant 4-byte block size. +#[inline(always)] +pub fn vm_write_256( + exec_state: &mut VmExecState, + addr_space: u32, + ptr: u32, + data: &[u8; INT256_NUM_LIMBS], +) { + for i in 0..NUM_BLOCKS { + let chunk: [u8; CONST_BLOCK_SIZE] = data[i * CONST_BLOCK_SIZE..(i + 1) * CONST_BLOCK_SIZE] + .try_into() + .unwrap(); + exec_state.vm_write::( + addr_space, + ptr + (i * CONST_BLOCK_SIZE) as u32, + &chunk, + ); + } +} + #[inline(always)] pub fn bytes_to_u64_array(bytes: [u8; INT256_NUM_LIMBS]) -> [u64; 4] { // SAFETY: [u8; 32] to [u64; 4] transmute is safe - same size and compatible alignment diff --git a/extensions/bigint/circuit/src/cuda/mod.rs b/extensions/bigint/circuit/src/cuda/mod.rs index 6dc6f535a5..71d88de190 100644 --- a/extensions/bigint/circuit/src/cuda/mod.rs +++ b/extensions/bigint/circuit/src/cuda/mod.rs @@ -14,8 +14,8 @@ use openvm_cuda_backend::{ }; use openvm_cuda_common::copy::MemCopyH2D; use openvm_rv32_adapters::{ - Rv32HeapBranchAdapterCols, Rv32HeapBranchAdapterRecord, Rv32VecHeapAdapterCols, - Rv32VecHeapAdapterRecord, + 
Rv32HeapAdapterCols, Rv32HeapAdapterRecord, Rv32HeapBranchAdapterCols, + Rv32HeapBranchAdapterRecord, }; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS}, @@ -25,13 +25,24 @@ use openvm_rv32im_circuit::{ }; use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; +use openvm_circuit::arch::CONST_BLOCK_SIZE; + mod cuda_abi; +/// Number of memory blocks for 256-bit (32-byte) operations. +/// READ_BLOCKS = INT256_NUM_LIMBS / CONST_BLOCK_SIZE = 32 / 4 = 8 +const INT256_READ_BLOCKS: usize = INT256_NUM_LIMBS / CONST_BLOCK_SIZE; + ////////////////////////////////////////////////////////////////////////////////////// /// ALU ////////////////////////////////////////////////////////////////////////////////////// -pub type BaseAlu256AdapterRecord = - Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; +pub type BaseAlu256AdapterRecord = Rv32HeapAdapterRecord< + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, +>; pub type BaseAlu256CoreRecord = BaseAluCoreRecord; #[derive(new)] @@ -52,7 +63,14 @@ impl Chip for BaseAlu256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); let trace_width = BaseAluCoreCols::::width() - + Rv32VecHeapAdapterCols::::width(); + + Rv32HeapAdapterCols::< + F, + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, + >::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); @@ -79,7 +97,7 @@ impl Chip for BaseAlu256ChipGpu { ////////////////////////////////////////////////////////////////////////////////////// /// Branch Equal ////////////////////////////////////////////////////////////////////////////////////// -pub type BranchEqual256AdapterRecord = Rv32HeapBranchAdapterRecord<2>; +pub type BranchEqual256AdapterRecord = Rv32HeapBranchAdapterRecord<2, INT256_READ_BLOCKS>; pub type BranchEqual256CoreRecord = BranchEqualCoreRecord; #[derive(new)] 
@@ -101,7 +119,7 @@ impl Chip for BranchEqual256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); let trace_width = BranchEqualCoreCols::::width() - + Rv32HeapBranchAdapterCols::::width(); + + Rv32HeapBranchAdapterCols::::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); @@ -128,8 +146,13 @@ impl Chip for BranchEqual256ChipGpu { ////////////////////////////////////////////////////////////////////////////////////// /// Less Than ////////////////////////////////////////////////////////////////////////////////////// -pub type LessThan256AdapterRecord = - Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; +pub type LessThan256AdapterRecord = Rv32HeapAdapterRecord< + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, +>; pub type LessThan256CoreRecord = LessThanCoreRecord; #[derive(new)] @@ -150,7 +173,14 @@ impl Chip for LessThan256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); let trace_width = LessThanCoreCols::::width() - + Rv32VecHeapAdapterCols::::width(); + + Rv32HeapAdapterCols::< + F, + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, + >::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); @@ -177,7 +207,7 @@ impl Chip for LessThan256ChipGpu { ////////////////////////////////////////////////////////////////////////////////////// /// Branch Less Than ////////////////////////////////////////////////////////////////////////////////////// -pub type BranchLessThan256AdapterRecord = Rv32HeapBranchAdapterRecord<2>; +pub type BranchLessThan256AdapterRecord = Rv32HeapBranchAdapterRecord<2, INT256_READ_BLOCKS>; pub type BranchLessThan256CoreRecord = BranchLessThanCoreRecord; #[derive(new)] @@ -199,7 +229,7 @@ impl Chip for BranchLessThan256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); 
let trace_width = BranchLessThanCoreCols::::width() - + Rv32HeapBranchAdapterCols::::width(); + + Rv32HeapBranchAdapterCols::::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); @@ -226,8 +256,13 @@ impl Chip for BranchLessThan256ChipGpu { ////////////////////////////////////////////////////////////////////////////////////// /// Shift ////////////////////////////////////////////////////////////////////////////////////// -pub type Shift256AdapterRecord = - Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; +pub type Shift256AdapterRecord = Rv32HeapAdapterRecord< + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, +>; pub type Shift256CoreRecord = ShiftCoreRecord; #[derive(new)] @@ -248,7 +283,14 @@ impl Chip for Shift256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); let trace_width = ShiftCoreCols::::width() - + Rv32VecHeapAdapterCols::::width(); + + Rv32HeapAdapterCols::< + F, + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, + >::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); @@ -275,8 +317,13 @@ impl Chip for Shift256ChipGpu { ////////////////////////////////////////////////////////////////////////////////////// /// Multiplication ////////////////////////////////////////////////////////////////////////////////////// -pub type Multiplication256AdapterRecord = - Rv32VecHeapAdapterRecord<2, 1, 1, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; +pub type Multiplication256AdapterRecord = Rv32HeapAdapterRecord< + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, +>; pub type Multiplication256CoreRecord = MultiplicationCoreRecord; #[derive(new)] @@ -299,7 +346,14 @@ impl Chip for Multiplication256ChipGpu { debug_assert_eq!(records.len() % RECORD_SIZE, 0); let trace_width = 
MultiplicationCoreCols::::width() - + Rv32VecHeapAdapterCols::::width(); + + Rv32HeapAdapterCols::< + F, + 2, + INT256_NUM_LIMBS, + INT256_NUM_LIMBS, + INT256_READ_BLOCKS, + INT256_READ_BLOCKS, + >::width(); let trace_height = next_power_of_two_or_zero(records.len() / RECORD_SIZE); let d_records = records.to_device().unwrap(); diff --git a/extensions/bigint/circuit/src/extension/mod.rs b/extensions/bigint/circuit/src/extension/mod.rs index 0adb7fc595..466301af8d 100644 --- a/extensions/bigint/circuit/src/extension/mod.rs +++ b/extensions/bigint/circuit/src/extension/mod.rs @@ -32,6 +32,10 @@ use openvm_stark_backend::{ p3_field::PrimeField32, prover::cpu::{CpuBackend, CpuDevice}, }; +use openvm_rv32_adapters::{ + Rv32HeapAdapterAir, Rv32HeapAdapterExecutor, Rv32HeapAdapterFiller, Rv32HeapBranchAdapterAir, + Rv32HeapBranchAdapterExecutor, Rv32HeapBranchAdapterFiller, +}; use serde::{Deserialize, Serialize}; use crate::*; diff --git a/extensions/bigint/circuit/src/less_than.rs b/extensions/bigint/circuit/src/less_than.rs index 68861d8ba0..ce8b8a5488 100644 --- a/extensions/bigint/circuit/src/less_than.rs +++ b/extensions/bigint/circuit/src/less_than.rs @@ -17,7 +17,10 @@ use openvm_rv32im_circuit::LessThanExecutor; use openvm_rv32im_transpiler::LessThanOpcode; use openvm_stark_backend::p3_field::PrimeField32; -use crate::{common, Rv32LessThan256Executor, INT256_NUM_LIMBS}; +use crate::{ + common::{self, vm_read_256, vm_write_256}, + Rv32LessThan256Executor, INT256_NUM_LIMBS, +}; type AdapterExecutor = Rv32HeapAdapterExecutor<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; @@ -134,10 +137,9 @@ unsafe fn execute_e12_impl(RV32_REGISTER_AS, pre_compute.b as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); let rd_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); - let rs1 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // 
Read 32 bytes using 8×4-byte reads to avoid access adapters + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let cmp_result = if IS_U256 { common::u256_lt(rs1, rs2) } else { @@ -145,7 +147,8 @@ unsafe fn execute_e12_impl, BaseAluCoreAir, @@ -63,7 +63,7 @@ pub type Rv32BaseAlu256Chip = VmChipWrapper< >, >; -/// LessThan256 +/// LessThan256 - uses 8×4-byte memory bus interactions pub type Rv32LessThan256Air = VmAirWrapper< Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, LessThanCoreAir, @@ -85,7 +85,7 @@ pub type Rv32LessThan256Chip = VmChipWrapper< >, >; -/// Multiplication256 +/// Multiplication256 - uses 8×4-byte memory bus interactions pub type Rv32Multiplication256Air = VmAirWrapper< Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, MultiplicationCoreAir, @@ -107,7 +107,7 @@ pub type Rv32Multiplication256Chip = VmChipWrapper< >, >; -/// Shift256 +/// Shift256 - uses 8×4-byte memory bus interactions pub type Rv32Shift256Air = VmAirWrapper< Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, ShiftCoreAir, @@ -129,7 +129,7 @@ pub type Rv32Shift256Chip = VmChipWrapper< >, >; -/// BranchEqual256 +/// BranchEqual256 - uses 8×4-byte memory bus interactions pub type Rv32BranchEqual256Air = VmAirWrapper< Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, BranchEqualCoreAir, @@ -143,7 +143,7 @@ pub type Rv32BranchEqual256Chip = VmChipWrapper< BranchEqualFiller, INT256_NUM_LIMBS>, >; -/// BranchLessThan256 +/// BranchLessThan256 - uses 8×4-byte memory bus interactions pub type Rv32BranchLessThan256Air = VmAirWrapper< Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, BranchLessThanCoreAir, diff --git a/extensions/bigint/circuit/src/mult.rs b/extensions/bigint/circuit/src/mult.rs index 2eff4b9096..0417eabea5 100644 --- a/extensions/bigint/circuit/src/mult.rs +++ b/extensions/bigint/circuit/src/mult.rs @@ -15,7 +15,7 @@ use 
openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::p3_field::PrimeField32; use crate::{ - common::{bytes_to_u32_array, u32_array_to_bytes}, + common::{bytes_to_u32_array, u32_array_to_bytes, vm_read_256, vm_write_256}, Rv32Multiplication256Executor, INT256_NUM_LIMBS, }; @@ -125,12 +125,12 @@ unsafe fn execute_e12_impl( let rs1_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); let rd_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); - let rs1 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // Read 32 bytes using 8×4-byte reads to avoid access adapters + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let rd = u256_mul(rs1, rs2); - exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + // Write 32 bytes using 8×4-byte writes to avoid access adapters + vm_write_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); diff --git a/extensions/bigint/circuit/src/shift.rs b/extensions/bigint/circuit/src/shift.rs index c08afc26e0..aa05bb1920 100644 --- a/extensions/bigint/circuit/src/shift.rs +++ b/extensions/bigint/circuit/src/shift.rs @@ -18,7 +18,7 @@ use openvm_rv32im_transpiler::ShiftOpcode; use openvm_stark_backend::p3_field::PrimeField32; use crate::{ - common::{bytes_to_u64_array, u64_array_to_bytes}, + common::{bytes_to_u64_array, u64_array_to_bytes, vm_read_256, vm_write_256}, Rv32Shift256Executor, INT256_NUM_LIMBS, }; @@ -138,12 +138,12 @@ unsafe fn execute_e12_impl let rs1_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); let rs2_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); let 
rd_ptr = exec_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); - let rs1 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); - let rs2 = - exec_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + // Read 32 bytes using 8×4-byte reads to avoid access adapters + let rs1 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_read_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); let rd = OP::compute(rs1, rs2); - exec_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + // Write 32 bytes using 8×4-byte writes to avoid access adapters + vm_write_256(exec_state, RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); } diff --git a/extensions/keccak256/circuit/src/columns.rs b/extensions/keccak256/circuit/src/columns.rs index a14ba4dcea..a5a0f19c3a 100644 --- a/extensions/keccak256/circuit/src/columns.rs +++ b/extensions/keccak256/circuit/src/columns.rs @@ -1,6 +1,9 @@ use core::mem::size_of; -use openvm_circuit::system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}; +use openvm_circuit::{ + arch::CONST_BLOCK_SIZE, + system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, +}; use openvm_circuit_primitives::utils::assert_array_eq; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; @@ -96,7 +99,7 @@ pub struct KeccakSpongeCols { pub struct KeccakMemoryCols { pub register_aux: [MemoryReadAuxCols; KECCAK_REGISTER_READS], pub absorb_reads: [MemoryReadAuxCols; KECCAK_ABSORB_READS], - pub digest_writes: [MemoryWriteAuxCols; KECCAK_DIGEST_WRITES], + pub digest_writes: [MemoryWriteAuxCols; KECCAK_DIGEST_WRITES], /// The input bytes are batch read in blocks of private constant KECCAK_WORD_SIZE bytes. 
/// However if the input length is not a multiple of KECCAK_WORD_SIZE, we read into /// `partial_block` more bytes than we need. On the other hand `block_bytes` expects diff --git a/extensions/keccak256/circuit/src/execution.rs b/extensions/keccak256/circuit/src/execution.rs index 025786c685..38e872804d 100644 --- a/extensions/keccak256/circuit/src/execution.rs +++ b/extensions/keccak256/circuit/src/execution.rs @@ -15,7 +15,7 @@ use openvm_keccak256_transpiler::Rv32KeccakOpcode; use openvm_stark_backend::p3_field::PrimeField32; use p3_keccak_air::NUM_ROUNDS; -use super::{KeccakVmExecutor, KECCAK_WORD_SIZE}; +use super::{KeccakVmExecutor, KECCAK_DIGEST_WRITES, KECCAK_WORD_SIZE}; use crate::utils::{keccak256, num_keccak_f}; #[derive(AlignedBytesBorrow, Clone)] @@ -149,26 +149,34 @@ unsafe fn execute_e12_impl = (0..num_reads) + .flat_map(|i| { + exec_state.vm_read::( + RV32_MEMORY_AS, + src_u32 + (i * KECCAK_WORD_SIZE) as u32, + ) + }) + .collect(); + let (output, height) = if IS_E1 { - // SAFETY: RV32_MEMORY_AS is memory address space of type u8 - let message = exec_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); - let output = keccak256(message); + let output = keccak256(&message[..len_u32 as usize]); (output, 0) } else { - let num_reads = (len_u32 as usize).div_ceil(KECCAK_WORD_SIZE); - let message: Vec<_> = (0..num_reads) - .flat_map(|i| { - exec_state.vm_read::( - RV32_MEMORY_AS, - src_u32 + (i * KECCAK_WORD_SIZE) as u32, - ) - }) - .collect(); let output = keccak256(&message[..len_u32 as usize]); let height = (num_keccak_f(len_u32 as usize) * NUM_ROUNDS) as u32; (output, height) }; - exec_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + + for i in 0..KECCAK_DIGEST_WRITES { + exec_state.vm_write::( + RV32_MEMORY_AS, + dst_u32 + (i * KECCAK_WORD_SIZE) as u32, + &output[i * KECCAK_WORD_SIZE..(i + 1) * KECCAK_WORD_SIZE] + .try_into() + .unwrap(), + ); + } let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); diff --git 
a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index 00e57033e5..3aeba5d079 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -5,6 +5,7 @@ //! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on //! variable length inputs read from VM memory. +use openvm_circuit::arch::CONST_BLOCK_SIZE; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; pub mod air; @@ -29,7 +30,7 @@ use openvm_circuit::arch::*; /// Register reads to get dst, src, len const KECCAK_REGISTER_READS: usize = 3; /// Number of cells to read/write in a single memory access -const KECCAK_WORD_SIZE: usize = 4; +const KECCAK_WORD_SIZE: usize = CONST_BLOCK_SIZE; /// Memory reads for absorb per row const KECCAK_ABSORB_READS: usize = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE; /// Memory writes for digest per row diff --git a/extensions/native/circuit/src/fri/execution.rs b/extensions/native/circuit/src/fri/execution.rs index 675caa1143..88db650593 100644 --- a/extensions/native/circuit/src/fri/execution.rs +++ b/extensions/native/circuit/src/fri/execution.rs @@ -9,7 +9,9 @@ use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; use openvm_stark_backend::p3_field::PrimeField32; -use super::{elem_to_ext, FriReducedOpeningExecutor}; +use super::{ + elem_to_ext, FriReducedOpeningExecutor, FRI_READ_SUBBLOCKS, FRI_WRITE_SUBBLOCKS, +}; use crate::field_extension::{FieldExtension, EXT_DEG}; #[derive(AlignedBytesBorrow, Clone)] @@ -225,7 +227,14 @@ unsafe fn execute_e12_impl( exec_state.vm_read(AS::Native as u32, a_ptr_i) }; let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32(); - let b = exec_state.vm_read(AS::Native as u32, b_ptr_i); + let mut b = [F::ZERO; EXT_DEG]; + for sub in 0..FRI_READ_SUBBLOCKS { + let chunk: [F; CONST_BLOCK_SIZE] = exec_state.vm_read( + AS::Native as 
u32, + b_ptr_i + (sub * CONST_BLOCK_SIZE) as u32, + ); + b[sub * CONST_BLOCK_SIZE..][..CONST_BLOCK_SIZE].copy_from_slice(&chunk); + } as_and_bs.push((a, b)); } @@ -239,7 +248,15 @@ unsafe fn execute_e12_impl( ); } - exec_state.vm_write(AS::Native as u32, pre_compute.result_ptr, &result); + for sub in 0..FRI_WRITE_SUBBLOCKS { + let chunk: [F; CONST_BLOCK_SIZE] = + core::array::from_fn(|i| result[sub * CONST_BLOCK_SIZE + i]); + exec_state.vm_write( + AS::Native as u32, + pre_compute.result_ptr + (sub * CONST_BLOCK_SIZE) as u32, + &chunk, + ); + } let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 1e1ec65cb8..cb47753a5b 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -6,7 +6,7 @@ use std::{ use itertools::zip_eq; use openvm_circuit::{ - arch::*, + arch::{self, *}, system::{ memory::{ offline_checker::{ @@ -38,6 +38,10 @@ use crate::{ utils::const_max, }; +const_assert_eq!(EXT_DEG % arch::CONST_BLOCK_SIZE, 0); +const FRI_READ_SUBBLOCKS: usize = EXT_DEG / arch::CONST_BLOCK_SIZE; +const FRI_WRITE_SUBBLOCKS: usize = EXT_DEG / arch::CONST_BLOCK_SIZE; + mod execution; #[cfg(feature = "cuda")] @@ -56,10 +60,9 @@ struct WorkloadCols { a_aux: MemoryWriteAuxCols, /// The value of `b` read. 
b: [T; EXT_DEG], - b_aux: MemoryReadAuxCols, + b_aux: [MemoryReadAuxCols; FRI_READ_SUBBLOCKS], } const WL_WIDTH: usize = WorkloadCols::::width(); -const_assert_eq!(WL_WIDTH, 27); #[repr(C)] #[derive(Debug, AlignedBorrow)] @@ -99,7 +102,7 @@ struct Instruction2Cols { alpha_aux: MemoryReadAuxCols, result_ptr: T, - result_aux: MemoryWriteAuxCols, + result_aux: [MemoryWriteAuxCols; FRI_WRITE_SUBBLOCKS], hint_id_ptr: T, @@ -111,7 +114,6 @@ struct Instruction2Cols { write_a_x_is_first: T, } const INS_2_WIDTH: usize = Instruction2Cols::::width(); -const_assert_eq!(INS_2_WIDTH, 26); const_assert_eq!( offset_of!(WorkloadCols, prefix) + offset_of!(PrefixCols, general), offset_of!(Instruction2Cols, general) @@ -126,7 +128,6 @@ const_assert_eq!( ); pub const OVERALL_WIDTH: usize = const_max(const_max(WL_WIDTH, INS_1_WIDTH), INS_2_WIDTH); -const_assert_eq!(OVERALL_WIDTH, 27); /// Every row starts with these columns. #[repr(C)] @@ -312,14 +313,35 @@ impl FriReducedOpeningAir { ) .eval(builder, local_data.write_a * multiplicity); // read b + let first_b_chunk: [AB::Var; arch::CONST_BLOCK_SIZE] = + core::array::from_fn(|i| local.b[i]); self.memory_bridge .read( MemoryAddress::new(native_as.clone(), next.data.b_ptr), - local.b, + first_b_chunk, start_timestamp + ptr_reads + AB::Expr::ONE, - &local.b_aux, + &local.b_aux[0], ) - .eval(builder, multiplicity); + .eval(builder, multiplicity.clone()); + for sub in 1..FRI_READ_SUBBLOCKS { + let chunk: [AB::Var; arch::CONST_BLOCK_SIZE] = core::array::from_fn(|i| { + local.b[sub * arch::CONST_BLOCK_SIZE + i] + }); + self.memory_bridge + .read( + MemoryAddress::new( + native_as.clone(), + next.data.b_ptr + + AB::Expr::from_canonical_usize(sub * arch::CONST_BLOCK_SIZE), + ), + chunk, + start_timestamp + + ptr_reads + + AB::Expr::from_canonical_usize(sub + 1), + &local.b_aux[sub], + ) + .eval(builder, multiplicity.clone()); + } { let mut when_transition = builder.when_transition(); let mut builder = 
when_transition.when(local.prefix.general.is_workload_row); @@ -489,11 +511,29 @@ impl FriReducedOpeningAir { self.memory_bridge .write( MemoryAddress::new(native_as.clone(), next.result_ptr), - local_data.result, - write_timestamp, - &next.result_aux, + core::array::from_fn(|i| local_data.result[i]), + write_timestamp.clone(), + &next.result_aux[0], ) .eval(builder, multiplicity.clone()); + for sub in 1..FRI_WRITE_SUBBLOCKS { + let chunk: [AB::Var; arch::CONST_BLOCK_SIZE] = core::array::from_fn(|i| { + local_data.result[sub * arch::CONST_BLOCK_SIZE + i] + }); + self.memory_bridge + .write( + MemoryAddress::new( + native_as.clone(), + next.result_ptr + + AB::Expr::from_canonical_usize(sub * arch::CONST_BLOCK_SIZE), + ), + chunk, + write_timestamp.clone() + + AB::Expr::from_canonical_usize(sub), + &next.result_aux[sub], + ) + .eval(builder, multiplicity.clone()); + } } fn eval_instruction2_row( @@ -600,7 +640,7 @@ pub struct FriReducedOpeningCommonRecord { pub alpha_aux: MemoryReadAuxRecord, pub result_ptr: F, - pub result_aux: MemoryWriteAuxRecord, + pub result_aux: [MemoryWriteAuxRecord; FRI_WRITE_SUBBLOCKS], pub hint_id_ptr: F, @@ -619,7 +659,7 @@ pub struct FriReducedOpeningWorkloadRowRecord { // b can be computed from a, alpha, result, and previous result: // b = result + a - prev_result * alpha pub result: [F; EXT_DEG], - pub b_aux: MemoryReadAuxRecord, + pub b_aux: [MemoryReadAuxRecord; FRI_READ_SUBBLOCKS], } // NOTE: Order for fields is important here to prevent overwriting. 
@@ -855,11 +895,16 @@ where ) }; let b_ptr_i = record.common.b_ptr + (EXT_DEG * i) as u32; - let b = tracing_read_native::( - state.memory, - b_ptr_i, - &mut workload_row.b_aux.prev_timestamp, - ); + let mut b = [F::ZERO; EXT_DEG]; + for sub in 0..FRI_READ_SUBBLOCKS { + let chunk = tracing_read_native::( + state.memory, + b_ptr_i + (sub * arch::CONST_BLOCK_SIZE) as u32, + &mut workload_row.b_aux[sub].prev_timestamp, + ); + b[sub * arch::CONST_BLOCK_SIZE..][..arch::CONST_BLOCK_SIZE] + .copy_from_slice(&chunk); + } as_and_bs.push((a, b)); } @@ -878,13 +923,17 @@ where } let result_ptr = e.as_canonical_u32(); - tracing_write_native( - state.memory, - result_ptr, - result, - &mut record.common.result_aux.prev_timestamp, - &mut record.common.result_aux.prev_data, - ); + for sub in 0..FRI_WRITE_SUBBLOCKS { + let chunk: [F; arch::CONST_BLOCK_SIZE] = + core::array::from_fn(|i| result[sub * arch::CONST_BLOCK_SIZE + i]); + tracing_write_native( + state.memory, + result_ptr + (sub * arch::CONST_BLOCK_SIZE) as u32, + chunk, + &mut record.common.result_aux[sub].prev_timestamp, + &mut record.common.result_aux[sub].prev_data, + ); + } record.common.result_ptr = e; *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -962,13 +1011,15 @@ impl TraceFiller for FriReducedOpeningFiller { cols.hint_id_ptr = record.common.hint_id_ptr; - cols.result_aux - .set_prev_data(record.common.result_aux.prev_data); - mem_helper.fill( - record.common.result_aux.prev_timestamp, - timestamp + 5 + 2 * length as u32, - cols.result_aux.as_mut(), - ); + for sub in 0..FRI_WRITE_SUBBLOCKS { + cols.result_aux[sub] + .set_prev_data(record.common.result_aux[sub].prev_data); + mem_helper.fill( + record.common.result_aux[sub].prev_timestamp, + timestamp + 5 + 2 * length as u32 + sub as u32, + cols.result_aux[sub].as_mut(), + ); + } cols.result_ptr = record.common.result_ptr; mem_helper.fill( @@ -1058,11 +1109,13 @@ impl TraceFiller for FriReducedOpeningFiller { let timestamp = timestamp + ((length - i) * 2) 
as u32; // fill in reverse order - mem_helper.fill( - workload_row.b_aux.prev_timestamp, - timestamp + 4, - cols.b_aux.as_mut(), - ); + for sub in 0..FRI_READ_SUBBLOCKS { + mem_helper.fill( + workload_row.b_aux[sub].prev_timestamp, + timestamp + 4 + sub as u32, + cols.b_aux[sub].as_mut(), + ); + } // We temporarily store the result here // the correct value of b is computed during the serial pass below diff --git a/extensions/rv32-adapters/cuda/include/rv32-adapters/heap.cuh b/extensions/rv32-adapters/cuda/include/rv32-adapters/heap.cuh index 98ed20857e..6d8d5ea4fc 100644 --- a/extensions/rv32-adapters/cuda/include/rv32-adapters/heap.cuh +++ b/extensions/rv32-adapters/cuda/include/rv32-adapters/heap.cuh @@ -4,12 +4,149 @@ using namespace riscv; -// Simple heap adapter - just type aliases for vec_heap with BLOCKS_PER_READ=1, BLOCKS_PER_WRITE=1 -template -using Rv32HeapAdapterCols = Rv32VecHeapAdapterCols; +// Heap adapter with explicit READ_BLOCKS and WRITE_BLOCKS parameters. +// The adapter reads READ_SIZE bytes total using READ_BLOCKS blocks of (READ_SIZE/READ_BLOCKS) bytes each. +// Similarly for writes. 
+// For 256-bit (32-byte) operations with 4-byte block size: READ_SIZE=32, READ_BLOCKS=8 -template -using Rv32HeapAdapterRecord = Rv32VecHeapAdapterRecord; +// Block size used for memory bus interactions (must match CONST_BLOCK_SIZE in Rust) +constexpr size_t HEAP_ADAPTER_BLOCK_SIZE = 4; -template -using Rv32HeapAdapterExecutor = Rv32VecHeapAdapter; \ No newline at end of file +template +struct Rv32HeapAdapterCols { + ExecutionState from_state; + + T rs_ptr[NUM_READS]; + T rd_ptr; + + T rs_val[NUM_READS][RV32_REGISTER_NUM_LIMBS]; + T rd_val[RV32_REGISTER_NUM_LIMBS]; + + MemoryReadAuxCols rs_read_aux[NUM_READS]; + MemoryReadAuxCols rd_read_aux; + + // READ_BLOCKS reads per pointer (each of HEAP_ADAPTER_BLOCK_SIZE bytes) + MemoryReadAuxCols reads_aux[NUM_READS][READ_BLOCKS]; + // WRITE_BLOCKS writes (each of HEAP_ADAPTER_BLOCK_SIZE bytes) + MemoryWriteAuxCols writes_aux[WRITE_BLOCKS]; +}; + +template +struct Rv32HeapAdapterRecord { + uint32_t from_pc; + uint32_t from_timestamp; + + uint32_t rs_ptrs[NUM_READS]; + uint32_t rd_ptr; + + uint32_t rs_vals[NUM_READS]; + uint32_t rd_val; + + MemoryReadAuxRecord rs_read_aux[NUM_READS]; + MemoryReadAuxRecord rd_read_aux; + + MemoryReadAuxRecord reads_aux[NUM_READS][READ_BLOCKS]; + MemoryWriteAuxRecord writes_aux[WRITE_BLOCKS]; +}; + +template +struct Rv32HeapAdapterExecutor { + size_t pointer_max_bits; + BitwiseOperationLookup bitwise_lookup; + MemoryAuxColsFactory mem_helper; + + static constexpr size_t RV32_REGISTER_TOTAL_BITS = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS; + static constexpr size_t MSL_SHIFT = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + + __device__ Rv32HeapAdapterExecutor( + size_t pointer_max_bits, + VariableRangeChecker range_checker, + BitwiseOperationLookup bitwise_lookup, + uint32_t timestamp_max_bits + ) + : pointer_max_bits(pointer_max_bits), bitwise_lookup(bitwise_lookup), + mem_helper(range_checker, timestamp_max_bits) {} + + template + using Cols = Rv32HeapAdapterCols; + + __device__ void 
fill_trace_row( + RowSlice row, + Rv32HeapAdapterRecord record + ) { + const size_t limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits; + + if (NUM_READS == 2) { + bitwise_lookup.add_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits + ); + bitwise_lookup.add_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits + ); + } else if (NUM_READS == 1) { + bitwise_lookup.add_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits + ); + } else { + assert(false); + } + + uint32_t timestamp = + record.from_timestamp + NUM_READS + 1 + NUM_READS * READ_BLOCKS + WRITE_BLOCKS; + + for (int i = WRITE_BLOCKS - 1; i >= 0; i--) { + timestamp--; + COL_WRITE_ARRAY(row, Cols, writes_aux[i].prev_data, record.writes_aux[i].prev_data); + mem_helper.fill( + row.slice_from(COL_INDEX(Cols, writes_aux[i])), + record.writes_aux[i].prev_timestamp, + timestamp + ); + } + + for (int i = NUM_READS - 1; i >= 0; i--) { + for (int j = READ_BLOCKS - 1; j >= 0; j--) { + timestamp--; + mem_helper.fill( + row.slice_from(COL_INDEX(Cols, reads_aux[i][j])), + record.reads_aux[i][j].prev_timestamp, + timestamp + ); + } + } + + timestamp--; + mem_helper.fill( + row.slice_from(COL_INDEX(Cols, rd_read_aux)), + record.rd_read_aux.prev_timestamp, + timestamp + ); + + for (int i = NUM_READS - 1; i >= 0; i--) { + timestamp--; + mem_helper.fill( + row.slice_from(COL_INDEX(Cols, rs_read_aux[i])), + record.rs_read_aux[i].prev_timestamp, + timestamp + ); + } + + COL_WRITE_ARRAY(row, Cols, rd_val, (uint8_t *)&record.rd_val); + + for (int i = NUM_READS - 1; i >= 0; i--) { + COL_WRITE_ARRAY(row, Cols, rs_val[i], (uint8_t *)&record.rs_vals[i]); + } + + COL_WRITE_VALUE(row, Cols, rd_ptr, record.rd_ptr); + + for (int i = NUM_READS - 1; i >= 0; i--) { + COL_WRITE_VALUE(row, Cols, rs_ptr[i], record.rs_ptrs[i]); + } + + 
COL_WRITE_VALUE(row, Cols, from_state.timestamp, record.from_timestamp); + COL_WRITE_VALUE(row, Cols, from_state.pc, record.from_pc); + } +}; \ No newline at end of file diff --git a/extensions/rv32-adapters/cuda/include/rv32-adapters/heap_branch.cuh b/extensions/rv32-adapters/cuda/include/rv32-adapters/heap_branch.cuh index 845eac43e0..f7b3a59292 100644 --- a/extensions/rv32-adapters/cuda/include/rv32-adapters/heap_branch.cuh +++ b/extensions/rv32-adapters/cuda/include/rv32-adapters/heap_branch.cuh @@ -7,17 +7,20 @@ using namespace riscv; -template struct Rv32HeapBranchAdapterCols { +// Heap branch adapter with explicit READ_BLOCKS parameter. +// For 256-bit (32-byte) operations with 4-byte block size: READ_SIZE=32, READ_BLOCKS=8 +template struct Rv32HeapBranchAdapterCols { ExecutionState from_state; T rs_ptr[NUM_READS]; T rs_val[NUM_READS][RV32_REGISTER_NUM_LIMBS]; MemoryReadAuxCols rs_read_aux[NUM_READS]; - MemoryReadAuxCols heap_read_aux[NUM_READS]; + // READ_BLOCKS reads per pointer + MemoryReadAuxCols heap_read_aux[NUM_READS][READ_BLOCKS]; }; -template struct Rv32HeapBranchAdapterRecord { +template struct Rv32HeapBranchAdapterRecord { uint32_t from_pc; uint32_t from_timestamp; @@ -25,10 +28,10 @@ template struct Rv32HeapBranchAdapterRecord { uint32_t rs_vals[NUM_READS]; MemoryReadAuxRecord rs_read_aux[NUM_READS]; - MemoryReadAuxRecord heap_read_aux[NUM_READS]; + MemoryReadAuxRecord heap_read_aux[NUM_READS][READ_BLOCKS]; }; -template struct Rv32HeapBranchAdapter { +template struct Rv32HeapBranchAdapter { size_t pointer_max_bits; BitwiseOperationLookup bitwise_lookup; MemoryAuxColsFactory mem_helper; @@ -45,9 +48,9 @@ template struct Rv32HeapBranchAdapter { : pointer_max_bits(pointer_max_bits), bitwise_lookup(bitwise_lookup), mem_helper(range_checker, timestamp_max_bits) {} - template using Cols = Rv32HeapBranchAdapterCols; + template using Cols = Rv32HeapBranchAdapterCols; - __device__ void fill_trace_row(RowSlice row, Rv32HeapBranchAdapterRecord record) { + 
__device__ void fill_trace_row(RowSlice row, Rv32HeapBranchAdapterRecord record) { const size_t limb_shift_bits = RV32_REGISTER_TOTAL_BITS - pointer_max_bits; bitwise_lookup.add_range( @@ -55,12 +58,15 @@ template struct Rv32HeapBranchAdapter { NUM_READS > 1 ? (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits : 0 ); + // Fill heap read aux in reverse order for (int i = NUM_READS - 1; i >= 0; i--) { - mem_helper.fill( - row.slice_from(COL_INDEX(Cols, heap_read_aux[i])), - record.heap_read_aux[i].prev_timestamp, - record.from_timestamp + (i + NUM_READS) - ); + for (int j = READ_BLOCKS - 1; j >= 0; j--) { + mem_helper.fill( + row.slice_from(COL_INDEX(Cols, heap_read_aux[i][j])), + record.heap_read_aux[i][j].prev_timestamp, + record.from_timestamp + NUM_READS + i * READ_BLOCKS + j + ); + } } for (int i = NUM_READS - 1; i >= 0; i--) { diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index 10409d97e9..58a980882f 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -1,41 +1,90 @@ -use std::borrow::Borrow; +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, + iter::once, +}; +use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, BasicAdapterInterface, - ExecutionBridge, MinimalInstruction, VmAdapterAir, + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + }, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - system::memory::{offline_checker::MemoryBridge, online::TracingMemory, MemoryAuxColsFactory}, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use 
openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; +use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, - riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, }; +use openvm_rv32im_circuit::adapters::{abstract_compose, tracing_read, tracing_write}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, - p3_field::{Field, PrimeField32}, + p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use crate::{ - Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterExecutor, - Rv32VecHeapAdapterFiller, Rv32VecHeapAdapterRecord, -}; +use openvm_circuit::arch::CONST_BLOCK_SIZE; + +/// Fixed memory block size for all heap adapter memory bus interactions. +/// All reads and writes are sent in CONST_BLOCK_SIZE-byte chunks to avoid access adapters. +pub const HEAP_ADAPTER_BLOCK_SIZE: usize = CONST_BLOCK_SIZE; /// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers /// (address space 1). /// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes are to the address in `rd`. +/// +/// Memory bus interactions are sent in 4-byte blocks to avoid needing access adapters. +/// READ_SIZE and WRITE_SIZE must be multiples of 4. +/// READ_BLOCKS must equal READ_SIZE / 4, WRITE_BLOCKS must equal WRITE_SIZE / 4. 
+#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct Rv32HeapAdapterCols< + T, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, +> { + pub from_state: ExecutionState, + + pub rs_ptr: [T; NUM_READS], + pub rd_ptr: T, + + pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS], + pub rd_val: [T; RV32_REGISTER_NUM_LIMBS], + + pub rs_read_aux: [MemoryReadAuxCols; NUM_READS], + pub rd_read_aux: MemoryReadAuxCols, + + /// Aux columns for heap reads: READ_BLOCKS aux cols per read + pub reads_aux: [[MemoryReadAuxCols; READ_BLOCKS]; NUM_READS], + /// Aux columns for heap writes: WRITE_BLOCKS aux cols + pub writes_aux: [MemoryWriteAuxCols; WRITE_BLOCKS], +} #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32HeapAdapterAir< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, + const READ_BLOCKS: usize = { 8 }, + const WRITE_BLOCKS: usize = { 8 }, > { pub(super) execution_bridge: ExecutionBridge, pub(super) memory_bridge: MemoryBridge, @@ -44,11 +93,19 @@ pub struct Rv32HeapAdapterAir< address_bits: usize, } -impl BaseAir - for Rv32HeapAdapterAir +impl< + F: Field, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > BaseAir + for Rv32HeapAdapterAir { fn width(&self) -> usize { - Rv32VecHeapAdapterCols::::width() + Rv32HeapAdapterCols::::width( + ) } } @@ -57,7 +114,10 @@ impl< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterAir for Rv32HeapAdapterAir + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > VmAdapterAir + for Rv32HeapAdapterAir { type Interface = BasicAdapterInterface< AB::Expr, @@ -74,40 +134,195 @@ impl< local: &[AB::Var], ctx: AdapterAirContext, ) { - let vec_heap_air: Rv32VecHeapAdapterAir = - Rv32VecHeapAdapterAir::new( - self.execution_bridge, - self.memory_bridge, - self.bus, - self.address_bits, - ); - 
vec_heap_air.eval(builder, local, ctx.into()); + debug_assert_eq!(READ_BLOCKS, READ_SIZE / HEAP_ADAPTER_BLOCK_SIZE); + debug_assert_eq!(WRITE_BLOCKS, WRITE_SIZE / HEAP_ADAPTER_BLOCK_SIZE); + + let cols: &Rv32HeapAdapterCols< + _, + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + > = local.borrow(); + let timestamp = cols.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) + }; + + // Read register values for rs, rd + for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once(( + cols.rd_ptr, + cols.rd_val, + &cols.rd_read_aux, + ))) { + self.memory_bridge + .read( + MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr), + val, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + // Range check the highest limbs of heap pointers + let need_range_check: Vec = cols + .rs_val + .iter() + .chain(std::iter::repeat_n(&cols.rd_val, 2)) + .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]) + .collect(); + + let limb_shift = AB::F::from_canonical_usize( + 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits), + ); + + for pair in need_range_check.chunks_exact(2) { + self.bus + .send_range(pair[0] * limb_shift, pair[1] * limb_shift) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + // Compose the u32 register values into field elements + let rd_val_f: AB::Expr = abstract_compose(cols.rd_val); + let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose); + + let e = AB::F::from_canonical_u32(RV32_MEMORY_AS); + + // Reads from heap - send READ_BLOCKS reads of 4 bytes each + for (read_idx, (address, reads_aux)) in izip!(rs_val_f.clone(), &cols.reads_aux).enumerate() + { + let read_data = &ctx.reads[read_idx]; + for (block_idx, aux) in reads_aux.iter().enumerate() { + let block_start = block_idx * HEAP_ADAPTER_BLOCK_SIZE; + let 
block_data: [AB::Expr; HEAP_ADAPTER_BLOCK_SIZE] = + from_fn(|i| read_data[block_start + i].clone()); + self.memory_bridge + .read( + MemoryAddress::new( + e, + address.clone() + AB::Expr::from_canonical_usize(block_start), + ), + block_data, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + } + + // Writes to heap - send WRITE_BLOCKS writes of 4 bytes each + let write_data = &ctx.writes[0]; + for (block_idx, aux) in cols.writes_aux.iter().enumerate() { + let block_start = block_idx * HEAP_ADAPTER_BLOCK_SIZE; + let block_data: [AB::Expr; HEAP_ADAPTER_BLOCK_SIZE] = + from_fn(|i| write_data[block_start + i].clone()); + self.memory_bridge + .write( + MemoryAddress::new( + e, + rd_val_f.clone() + AB::Expr::from_canonical_usize(block_start), + ), + block_data, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + self.execution_bridge + .execute_and_increment_or_set_pc( + ctx.instruction.opcode, + [ + cols.rd_ptr.into(), + cols.rs_ptr + .first() + .map(|&x| x.into()) + .unwrap_or(AB::Expr::ZERO), + cols.rs_ptr + .get(1) + .map(|&x| x.into()) + .unwrap_or(AB::Expr::ZERO), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + e.into(), + ], + cols.from_state, + AB::F::from_canonical_usize(timestamp_delta), + (DEFAULT_PC_STEP, ctx.to_pc), + ) + .eval(builder, ctx.instruction.is_valid.clone()); } fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &Rv32VecHeapAdapterCols<_, NUM_READS, 1, 1, READ_SIZE, WRITE_SIZE> = - local.borrow(); + let cols: &Rv32HeapAdapterCols< + _, + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + > = local.borrow(); cols.from_state.pc } } +/// Record for heap adapter trace generation. 
+#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HeapAdapterRecord< + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, +> { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptrs: [u32; NUM_READS], + pub rd_ptr: u32, + + pub rs_vals: [u32; NUM_READS], + pub rd_val: u32, + + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub rd_read_aux: MemoryReadAuxRecord, + + pub reads_aux: [[MemoryReadAuxRecord; READ_BLOCKS]; NUM_READS], + pub writes_aux: [MemoryWriteBytesAuxRecord; WRITE_BLOCKS], +} + #[derive(Clone, Copy)] pub struct Rv32HeapAdapterExecutor< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, ->(Rv32VecHeapAdapterExecutor); + const READ_BLOCKS: usize = { 8 }, + const WRITE_BLOCKS: usize = { 8 }, +> { + pointer_max_bits: usize, +} -impl - Rv32HeapAdapterExecutor +impl< + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > Rv32HeapAdapterExecutor { pub fn new(pointer_max_bits: usize) -> Self { assert!(NUM_READS <= 2); + assert_eq!(READ_BLOCKS, READ_SIZE / HEAP_ADAPTER_BLOCK_SIZE); + assert_eq!(WRITE_BLOCKS, WRITE_SIZE / HEAP_ADAPTER_BLOCK_SIZE); assert!( RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Rv32HeapAdapterExecutor(Rv32VecHeapAdapterExecutor::new(pointer_max_bits)) + Self { pointer_max_bits } } } @@ -115,38 +330,63 @@ pub struct Rv32HeapAdapterFiller< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, ->(Rv32VecHeapAdapterFiller); + const READ_BLOCKS: usize = { 8 }, + const WRITE_BLOCKS: usize = { 8 }, +> { + pointer_max_bits: usize, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} -impl - Rv32HeapAdapterFiller +impl< + const NUM_READS: usize, + const READ_SIZE: usize, + 
const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > Rv32HeapAdapterFiller { pub fn new( pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); + assert_eq!(READ_BLOCKS, READ_SIZE / HEAP_ADAPTER_BLOCK_SIZE); + assert_eq!(WRITE_BLOCKS, WRITE_SIZE / HEAP_ADAPTER_BLOCK_SIZE); assert!( RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Rv32HeapAdapterFiller(Rv32VecHeapAdapterFiller::new( + Self { pointer_max_bits, bitwise_lookup_chip, - )) + } } } -impl - AdapterTraceExecutor for Rv32HeapAdapterExecutor -where - F: PrimeField32, +impl< + F: PrimeField32, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > AdapterTraceExecutor + for Rv32HeapAdapterExecutor { - const WIDTH: usize = - Rv32VecHeapAdapterCols::::width(); + const WIDTH: usize = Rv32HeapAdapterCols::< + F, + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + >::width(); type ReadData = [[u8; READ_SIZE]; NUM_READS]; type WriteData = [[u8; WRITE_SIZE]; 1]; - type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord; + type RecordMut<'a> = + &'a mut Rv32HeapAdapterRecord; + #[inline(always)] fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { record.from_pc = pc; record.from_timestamp = memory.timestamp; @@ -158,8 +398,48 @@ where instruction: &Instruction, record: &mut Self::RecordMut<'_>, ) -> Self::ReadData { - let read_data = AdapterTraceExecutor::::read(&self.0, memory, instruction, record); - read_data.map(|r| r[0]) + let &Instruction { a, b, c, d, e, .. 
} = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Read register values + record.rs_vals = from_fn(|i| { + record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptrs[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) + }); + + record.rd_ptr = a.as_canonical_u32(); + record.rd_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.rd_read_aux.prev_timestamp, + )); + + // Read memory values in 4-byte blocks + from_fn(|read_idx| { + debug_assert!( + (record.rs_vals[read_idx] as usize + READ_SIZE - 1) < (1 << self.pointer_max_bits) + ); + let mut result = [0u8; READ_SIZE]; + for block_idx in 0..READ_BLOCKS { + let block_start = block_idx * HEAP_ADAPTER_BLOCK_SIZE; + let block: [u8; HEAP_ADAPTER_BLOCK_SIZE] = tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[read_idx] + block_start as u32, + &mut record.reads_aux[read_idx][block_idx].prev_timestamp, + ); + result[block_start..block_start + HEAP_ADAPTER_BLOCK_SIZE].copy_from_slice(&block); + } + result + }) } fn write( @@ -169,17 +449,150 @@ where data: Self::WriteData, record: &mut Self::RecordMut<'_>, ) { - AdapterTraceExecutor::::write(&self.0, memory, instruction, data, record); + debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS); + debug_assert!(record.rd_val as usize + WRITE_SIZE - 1 < (1 << self.pointer_max_bits)); + + // Write in 4-byte blocks + for block_idx in 0..WRITE_BLOCKS { + let block_start = block_idx * HEAP_ADAPTER_BLOCK_SIZE; + let block: [u8; HEAP_ADAPTER_BLOCK_SIZE] = data[0] + [block_start..block_start + HEAP_ADAPTER_BLOCK_SIZE] + .try_into() + .unwrap(); + tracing_write( + memory, + RV32_MEMORY_AS, + record.rd_val + block_start as u32, + block, + &mut record.writes_aux[block_idx].prev_timestamp, + &mut 
record.writes_aux[block_idx].prev_data, + ); + } } } -impl - AdapterTraceFiller for Rv32HeapAdapterFiller +impl< + F: PrimeField32, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + const READ_BLOCKS: usize, + const WRITE_BLOCKS: usize, + > AdapterTraceFiller + for Rv32HeapAdapterFiller { - const WIDTH: usize = - Rv32VecHeapAdapterCols::::width(); + const WIDTH: usize = Rv32HeapAdapterCols::< + F, + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + >::width(); + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32HeapAdapterRecord< + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + + let cols: &mut Rv32HeapAdapterCols< + F, + NUM_READS, + READ_SIZE, + WRITE_SIZE, + READ_BLOCKS, + WRITE_BLOCKS, + > = adapter_row.borrow_mut(); + + // Range checks + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + + if NUM_READS > 1 { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); + } else { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); + } + + let timestamp_delta = NUM_READS + 1 + NUM_READS * READ_BLOCKS + WRITE_BLOCKS; + let mut timestamp = record.from_timestamp + timestamp_delta as u32; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; + + // Fill in reverse order to avoid overwriting records + for (write_aux, cols_write) 
in record + .writes_aux + .iter() + .rev() + .zip(cols.writes_aux.iter_mut().rev()) + { + cols_write.set_prev_data(write_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + write_aux.prev_timestamp, + timestamp_mm(), + cols_write.as_mut(), + ); + } + + for (reads, cols_reads) in record.reads_aux.iter().zip(cols.reads_aux.iter_mut()).rev() { + for (read, cols_read) in reads.iter().zip(cols_reads.iter_mut()).rev() { + mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); + } + } + + mem_helper.fill( + record.rd_read_aux.prev_timestamp, + timestamp_mm(), + cols.rd_read_aux.as_mut(), + ); + + for (aux, cols_aux) in record + .rs_read_aux + .iter() + .zip(cols.rs_read_aux.iter_mut()) + .rev() + { + mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut()); + } + + cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8); + for (cols_val, val) in cols + .rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + { + *cols_val = val.to_le_bytes().map(F::from_canonical_u8); + } + + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + for (cols_ptr, ptr) in cols + .rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptrs.iter().rev()) + { + *cols_ptr = F::from_canonical_u32(*ptr); + } - fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]) { - AdapterTraceFiller::::fill_trace_row(&self.0, mem_helper, adapter_row); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/heap_branch.rs b/extensions/rv32-adapters/src/heap_branch.rs index e87b4fd973..e29f555e29 100644 --- a/extensions/rv32-adapters/src/heap_branch.rs +++ b/extensions/rv32-adapters/src/heap_branch.rs @@ -32,40 +32,64 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -/// This adapter reads from NUM_READS <= 2 pointers. 
+use openvm_circuit::arch::CONST_BLOCK_SIZE; + +/// Fixed memory block size for all heap branch adapter memory bus interactions. +/// All reads are sent in CONST_BLOCK_SIZE-byte chunks to avoid access adapters. +pub const HEAP_BRANCH_ADAPTER_BLOCK_SIZE: usize = CONST_BLOCK_SIZE; + +/// This adapter reads from NUM_READS <= 2 pointers (no writes). /// * The data is read from the heap (address space 2), and the pointers are read from registers /// (address space 1). /// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). +/// +/// Memory bus interactions are sent in 4-byte blocks to avoid needing access adapters. +/// READ_SIZE must be a multiple of 4. +/// READ_BLOCKS must equal READ_SIZE / 4. #[repr(C)] #[derive(AlignedBorrow)] -pub struct Rv32HeapBranchAdapterCols { +pub struct Rv32HeapBranchAdapterCols< + T, + const NUM_READS: usize, + const READ_SIZE: usize, + const READ_BLOCKS: usize, +> { pub from_state: ExecutionState, pub rs_ptr: [T; NUM_READS], pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS], pub rs_read_aux: [MemoryReadAuxCols; NUM_READS], - pub heap_read_aux: [MemoryReadAuxCols; NUM_READS], + /// READ_BLOCKS aux columns per read (one per 4-byte block) + pub heap_read_aux: [[MemoryReadAuxCols; READ_BLOCKS]; NUM_READS], } #[derive(Clone, Copy, Debug, derive_new::new)] -pub struct Rv32HeapBranchAdapterAir { +pub struct Rv32HeapBranchAdapterAir< + const NUM_READS: usize, + const READ_SIZE: usize, + const READ_BLOCKS: usize = { 8 }, +> { pub(super) execution_bridge: ExecutionBridge, pub(super) memory_bridge: MemoryBridge, pub bus: BitwiseOperationLookupBus, address_bits: usize, } -impl BaseAir - for Rv32HeapBranchAdapterAir +impl BaseAir + for Rv32HeapBranchAdapterAir { fn width(&self) -> usize { - Rv32HeapBranchAdapterCols::::width() + Rv32HeapBranchAdapterCols::::width() } } -impl VmAdapterAir - for Rv32HeapBranchAdapterAir +impl< + AB: InteractionBuilder, + const NUM_READS: usize, + const READ_SIZE: usize, + const READ_BLOCKS: usize, + 
> VmAdapterAir for Rv32HeapBranchAdapterAir { type Interface = BasicAdapterInterface, NUM_READS, 0, READ_SIZE, 0>; @@ -76,7 +100,9 @@ impl VmA local: &[AB::Var], ctx: AdapterAirContext, ) { - let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow(); + debug_assert_eq!(READ_BLOCKS, READ_SIZE / HEAP_BRANCH_ADAPTER_BLOCK_SIZE); + + let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE, READ_BLOCKS> = local.borrow(); let timestamp = cols.from_state.timestamp; let mut timestamp_delta: usize = 0; let mut timestamp_pp = || { @@ -87,50 +113,59 @@ impl VmA let d = AB::F::from_canonical_u32(RV32_REGISTER_AS); let e = AB::F::from_canonical_u32(RV32_MEMORY_AS); + // Read register values (pointer addresses) for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) { self.memory_bridge .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux) .eval(builder, ctx.instruction.is_valid.clone()); } - // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits - - // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow - // occurs when computing memory pointers. Since the number of cells accessed with each - // address will be small enough, and combined with the memory argument, it ensures - // that all the cells accessed in the memory are less than 2^addr_bits. 
+ // Range check the highest pointer limbs so computed heap addresses stay below 2^address_bits let need_range_check: Vec = cols .rs_val .iter() .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]) .collect(); - // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain - // the correct amount of bits let limb_shift = AB::F::from_canonical_usize( 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits), ); - // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS - // thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that - // limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))) for pair in need_range_check.chunks(2) { self.bus .send_range( pair[0] * limb_shift, - pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, // in case NUM_READS is odd + pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, ) .eval(builder, ctx.instruction.is_valid.clone()); } + // Compute heap pointers let heap_ptr = cols.rs_val.map(|r| { r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| { acc * AB::F::from_canonical_u32(1 << RV32_CELL_BITS) + (*limb) }) }); - for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) { - self.memory_bridge - .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux) - .eval(builder, ctx.instruction.is_valid.clone()); + + // Send READ_BLOCKS reads of 4 bytes each per pointer + for (read_idx, (base_ptr, block_auxs)) in izip!(heap_ptr, &cols.heap_read_aux).enumerate() { + let read_data = &ctx.reads[read_idx]; + for (block_idx, aux) in block_auxs.iter().enumerate() { + let block_start = block_idx * HEAP_BRANCH_ADAPTER_BLOCK_SIZE; + let block_data: [AB::Expr; HEAP_BRANCH_ADAPTER_BLOCK_SIZE] = + from_fn(|i| read_data[block_start + i].clone()); + self.memory_bridge + .read( + MemoryAddress::new( + e, + base_ptr.clone() + AB::Expr::from_canonical_usize(block_start), + ), + block_data, + timestamp_pp(), + aux, + ) + .eval(builder, 
ctx.instruction.is_valid.clone()); + } } self.execution_bridge @@ -157,14 +192,14 @@ impl VmA } fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow(); + let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE, READ_BLOCKS> = local.borrow(); cols.from_state.pc } } #[repr(C)] #[derive(AlignedBytesBorrow, Debug)] -pub struct Rv32HeapBranchAdapterRecord { +pub struct Rv32HeapBranchAdapterRecord { pub from_pc: u32, pub from_timestamp: u32, @@ -172,25 +207,35 @@ pub struct Rv32HeapBranchAdapterRecord { pub rs_vals: [u32; NUM_READS], pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], - pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS], + /// READ_BLOCKS aux records per read + pub heap_read_aux: [[MemoryReadAuxRecord; READ_BLOCKS]; NUM_READS], } #[derive(Clone, Copy)] -pub struct Rv32HeapBranchAdapterExecutor { +pub struct Rv32HeapBranchAdapterExecutor< + const NUM_READS: usize, + const READ_SIZE: usize, + const READ_BLOCKS: usize = { 8 }, +> { pub pointer_max_bits: usize, } #[derive(derive_new::new)] -pub struct Rv32HeapBranchAdapterFiller { +pub struct Rv32HeapBranchAdapterFiller< + const NUM_READS: usize, + const READ_SIZE: usize, + const READ_BLOCKS: usize = { 8 }, +> { pub pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl - Rv32HeapBranchAdapterExecutor +impl + Rv32HeapBranchAdapterExecutor { pub fn new(pointer_max_bits: usize) -> Self { assert!(NUM_READS <= 2); + assert_eq!(READ_BLOCKS, READ_SIZE / HEAP_BRANCH_ADAPTER_BLOCK_SIZE); assert!( RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" @@ -199,13 +244,13 @@ impl } } -impl AdapterTraceExecutor - for Rv32HeapBranchAdapterExecutor +impl + AdapterTraceExecutor for Rv32HeapBranchAdapterExecutor { - const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + const WIDTH: usize 
= Rv32HeapBranchAdapterCols::::width(); type ReadData = [[u8; READ_SIZE]; NUM_READS]; type WriteData = (); - type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord; + type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord; fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) { adapter_record.from_pc = pc; @@ -234,17 +279,24 @@ impl AdapterTra )) }); - // Read memory values - from_fn(|i| { + // Read memory values in 4-byte blocks + from_fn(|read_idx| { debug_assert!( - record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits) + record.rs_vals[read_idx] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits) ); - tracing_read( - memory, - RV32_MEMORY_AS, - record.rs_vals[i], - &mut record.heap_read_aux[i].prev_timestamp, - ) + let mut result = [0u8; READ_SIZE]; + for block_idx in 0..READ_BLOCKS { + let block_start = block_idx * HEAP_BRANCH_ADAPTER_BLOCK_SIZE; + let block: [u8; HEAP_BRANCH_ADAPTER_BLOCK_SIZE] = tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[read_idx] + block_start as u32, + &mut record.heap_read_aux[read_idx][block_idx].prev_timestamp, + ); + result[block_start..block_start + HEAP_BRANCH_ADAPTER_BLOCK_SIZE] + .copy_from_slice(&block); + } + result }) } @@ -259,22 +311,18 @@ impl AdapterTra } } -impl AdapterTraceFiller - for Rv32HeapBranchAdapterFiller +impl + AdapterTraceFiller for Rv32HeapBranchAdapterFiller { - const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { - // SAFETY: - // - caller ensures `adapter_row` contains a valid record representation that was previously - // written by the executor - let record: &Rv32HeapBranchAdapterRecord = + let record: &Rv32HeapBranchAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) }; - let cols: &mut Rv32HeapBranchAdapterCols = + let cols: &mut Rv32HeapBranchAdapterCols = 
adapter_row.borrow_mut(); - // Range checks: - // **NOTE**: Must do the range checks before overwriting the records + // Range checks: must be issued before the aux records below are overwritten debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); @@ -287,13 +335,19 @@ impl AdapterTra }, ); - // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records - for i in (0..NUM_READS).rev() { - mem_helper.fill( - record.heap_read_aux[i].prev_timestamp, - record.from_timestamp + (i + NUM_READS) as u32, - cols.heap_read_aux[i].as_mut(), - ); + // Fill in reverse order to avoid overwriting records + // Timestamp layout: rs_read[0..NUM_READS], then heap_read[read_idx][block_idx] row-major + let ts_offset = NUM_READS as u32; // Start after register reads + for read_idx in (0..NUM_READS).rev() { + for block_idx in (0..READ_BLOCKS).rev() { + let block_ts = + record.from_timestamp + ts_offset + (read_idx * READ_BLOCKS + block_idx) as u32; + mem_helper.fill( + record.heap_read_aux[read_idx][block_idx].prev_timestamp, + block_ts, + cols.heap_read_aux[read_idx][block_idx].as_mut(), + ); + } } for i in (0..NUM_READS).rev() { diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs index 2fe1cb26c0..f29421a711 100644 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ b/extensions/sha256/circuit/src/sha256_chip/air.rs @@ -1,7 +1,7 @@ use std::{array, borrow::Borrow, cmp::min}; use openvm_circuit::{ - arch::ExecutionBridge, + arch::{ExecutionBridge, CONST_BLOCK_SIZE}, system::{ memory::{offline_checker::MemoryBridge, MemoryAddress}, SystemPort, @@ -29,7 +29,8 @@ use openvm_stark_backend::{ use super::{ Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH, 
+ SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE, + SHA256_READ_SUBBLOCKS, SHA256_WRITE_SUBBLOCKS, }; /// Sha256VmAir does all constraints related to message padding and @@ -461,8 +462,10 @@ impl Sha256VmAir { local_cols.control.read_ptr + read_ptr_delta, ); - // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise - let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; + // Timestamp should increment by [SHA256_READ_SUBBLOCKS] for the first 4 rows and stay the + // same otherwise + let timestamp_delta = local_cols.inner.flags.is_first_4_rows + * AB::Expr::from_canonical_usize(SHA256_READ_SUBBLOCKS); builder .when_transition() .when(not::(is_last_row.clone())) @@ -483,17 +486,22 @@ impl Sha256VmAir { [i % (SHA256_WORD_U16S * 2)] }); - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local_cols.control.read_ptr, - ), - message, - local_cols.control.cur_timestamp, - &local_cols.read_aux, - ) - .eval(builder, local_cols.inner.flags.is_first_4_rows); + for sub in 0..SHA256_READ_SUBBLOCKS { + let chunk: [AB::Var; CONST_BLOCK_SIZE] = + array::from_fn(|i| message[sub * CONST_BLOCK_SIZE + i]); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + local_cols.control.read_ptr + + AB::Expr::from_canonical_usize(sub * CONST_BLOCK_SIZE), + ), + chunk, + local_cols.control.cur_timestamp + AB::Expr::from_canonical_usize(sub), + &local_cols.read_aux[sub], + ) + .eval(builder, local_cols.inner.flags.is_first_4_rows); + } } /// Implement the constraints for the last row of a message fn eval_last_row(&self, builder: &mut AB) { @@ -563,11 +571,13 @@ impl Sha256VmAir { ) .eval(builder, is_last_row.clone()); - // the number of reads that happened to read the entire message: we do 4 reads per block + // the number of reads that happened to read the entire message: we do + // (SHA256_NUM_READ_ROWS * 
SHA256_READ_SUBBLOCKS) reads per block let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) - * AB::Expr::from_canonical_usize(4); - // Every time we read the message we increment the read pointer by SHA256_READ_SIZE - let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); + * AB::Expr::from_canonical_usize(SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS); + // Every time we read a row we increment the read pointer by SHA256_READ_SIZE + let read_ptr_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) + * AB::Expr::from_canonical_usize(SHA256_NUM_READ_ROWS * SHA256_READ_SIZE); let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| { // The limbs are written in big endian order to the memory so need to be reversed @@ -581,14 +591,22 @@ impl Sha256VmAir { // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block write of // 32 cells This could be beneficial as the output is often an input for // another hash - self.memory_bridge - .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val), - result, - timestamp_pp() + time_delta.clone(), - &local_cols.writes_aux, - ) - .eval(builder, is_last_row.clone()); + for chunk in 0..SHA256_WRITE_SUBBLOCKS { + let data: [AB::Var; CONST_BLOCK_SIZE] = + array::from_fn(|i| result[chunk * CONST_BLOCK_SIZE + i]); + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + + AB::Expr::from_canonical_usize(chunk * CONST_BLOCK_SIZE), + ), + data, + timestamp_pp() + time_delta.clone(), + &local_cols.writes_aux[chunk], + ) + .eval(builder, is_last_row.clone()); + } self.execution_bridge .execute_and_increment_pc( diff --git a/extensions/sha256/circuit/src/sha256_chip/columns.rs b/extensions/sha256/circuit/src/sha256_chip/columns.rs index 38c13a0f73..bbff358f78 100644 --- a/extensions/sha256/circuit/src/sha256_chip/columns.rs +++ 
b/extensions/sha256/circuit/src/sha256_chip/columns.rs @@ -1,14 +1,16 @@ //! WARNING: the order of fields in the structs is important, do not change it use openvm_circuit::{ - arch::ExecutionState, + arch::{ExecutionState, CONST_BLOCK_SIZE}, system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, }; use openvm_circuit_primitives::AlignedBorrow; use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; use openvm_sha256_air::{Sha256DigestCols, Sha256RoundCols}; -use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; +use super::{ + SHA256_READ_SUBBLOCKS, SHA256_REGISTER_READS, SHA256_WRITE_SIZE, SHA256_WRITE_SUBBLOCKS, +}; /// the first 16 rows of every SHA256 block will be of type Sha256VmRoundCols and the last row will /// be of type Sha256VmDigestCols @@ -17,7 +19,7 @@ use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; pub struct Sha256VmRoundCols { pub control: Sha256VmControlCols, pub inner: Sha256RoundCols, - pub read_aux: MemoryReadAuxCols, + pub read_aux: [MemoryReadAuxCols; SHA256_READ_SUBBLOCKS], } #[repr(C)] @@ -36,7 +38,7 @@ pub struct Sha256VmDigestCols { pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], pub len_data: [T; RV32_REGISTER_NUM_LIMBS], pub register_reads_aux: [MemoryReadAuxCols; SHA256_REGISTER_READS], - pub writes_aux: MemoryWriteAuxCols, + pub writes_aux: [MemoryWriteAuxCols; SHA256_WRITE_SUBBLOCKS], } /// These are the columns that are used on both round and digest rows diff --git a/extensions/sha256/circuit/src/sha256_chip/execution.rs b/extensions/sha256/circuit/src/sha256_chip/execution.rs index 8fdc91f9d0..c29808f0be 100644 --- a/extensions/sha256/circuit/src/sha256_chip/execution.rs +++ b/extensions/sha256/circuit/src/sha256_chip/execution.rs @@ -12,7 +12,10 @@ use openvm_sha256_air::{get_sha256_num_blocks, SHA256_ROWS_PER_BLOCK}; use openvm_sha256_transpiler::Rv32Sha256Opcode; use openvm_stark_backend::p3_field::PrimeField32; -use super::{sha256_solve, Sha256VmExecutor, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE}; +use 
super::{ + sha256_solve, Sha256VmExecutor, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE, SHA256_READ_SUBBLOCKS, + SHA256_WRITE_SIZE, +}; #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] @@ -128,10 +131,15 @@ unsafe fn execute_e12_impl( + RV32_MEMORY_AS, + dst_u32 + (i * CONST_BLOCK_SIZE) as u32, + &output[i * CONST_BLOCK_SIZE..(i + 1) * CONST_BLOCK_SIZE] + .try_into() + .unwrap(), + ); + } let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs index eb027bd6d1..aa85ce77f9 100644 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ b/extensions/sha256/circuit/src/sha256_chip/mod.rs @@ -33,6 +33,10 @@ const SHA256_REGISTER_READS: usize = 3; const SHA256_READ_SIZE: usize = 16; /// Number of cells to write in a single memory access const SHA256_WRITE_SIZE: usize = 32; +/// Number of subreads per row (split into CONST_BLOCK_SIZE chunks) +pub const SHA256_READ_SUBBLOCKS: usize = SHA256_READ_SIZE / CONST_BLOCK_SIZE; +/// Number of subwrites to emit for the digest (split into CONST_BLOCK_SIZE chunks) +pub const SHA256_WRITE_SUBBLOCKS: usize = SHA256_WRITE_SIZE / CONST_BLOCK_SIZE; /// Number of rv32 cells read in a SHA256 block pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; /// Number of rows we will do a read on for each SHA256 block diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs index 7fc5c7062c..18335f2995 100644 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ b/extensions/sha256/circuit/src/sha256_chip/trace.rs @@ -36,7 +36,10 @@ use super::{ SHA256VM_DIGEST_WIDTH, }; use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE}, + sha256_chip::{ + PaddingFlags, CONST_BLOCK_SIZE, SHA256_READ_SIZE, SHA256_READ_SUBBLOCKS, + SHA256_REGISTER_READS, SHA256_WRITE_SIZE, SHA256_WRITE_SUBBLOCKS, + 
}, sha256_solve, Sha256VmControlCols, Sha256VmFiller, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_BLOCK_CELLS, SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS, }; @@ -68,7 +71,7 @@ pub struct Sha256VmRecordHeader { pub len: u32, pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS], - pub write_aux: MemoryWriteBytesAuxRecord<SHA256_WRITE_SIZE>, + pub write_aux: [MemoryWriteBytesAuxRecord<CONST_BLOCK_SIZE>; SHA256_WRITE_SUBBLOCKS], } pub struct Sha256VmRecordMut<'a> { @@ -81,7 +84,8 @@ pub struct Sha256VmRecordMut<'a> { /// Custom borrowing that splits the buffer into a fixed `Sha256VmRecord` header /// followed by a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks` where `num_blocks` is /// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length -/// `SHA256_NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly +/// `SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS * num_blocks`. Uses `align_to_mut()` to make sure +/// the slice is properly /// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the /// slices.
impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] { @@ -104,15 +108,16 @@ impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] // SAFETY: // - rest is a valid mutable slice from the previous split - // - align_to_mut guarantees the middle slice is properly aligned for MemoryReadAuxRecord - // - The subslice operation [..num_blocks * SHA256_NUM_READ_ROWS] validates sufficient - // capacity + // - The subslice operation [..num_blocks * SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS] + // validates sufficient capacity // - Layout calculation ensures space for alignment padding plus required aux records let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::<MemoryReadAuxRecord>() }; Sha256VmRecordMut { inner: header_buf.borrow_mut(), input, - read_aux: &mut read_aux_buf - [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS], + read_aux: &mut read_aux_buf[..(layout.metadata.num_blocks as usize) + * SHA256_NUM_READ_ROWS + * SHA256_READ_SUBBLOCKS], } } @@ -134,6 +139,7 @@ impl SizedRecord<Sha256VmRecordLayout> for Sha256VmRecordMut<'_> { total_len = total_len.next_multiple_of(align_of::<MemoryReadAuxRecord>()); total_len += layout.metadata.num_blocks as usize * SHA256_NUM_READ_ROWS + * SHA256_READ_SUBBLOCKS * size_of::<MemoryReadAuxRecord>(); total_len } @@ -218,26 +224,36 @@ where // Reads happen on the first 4 rows of each block for row in 0..SHA256_NUM_READ_ROWS { let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = tracing_read( - state.memory, - RV32_MEMORY_AS, - record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32, - &mut record.read_aux[read_idx].prev_timestamp, - ); + let mut row_input = [0u8; SHA256_READ_SIZE]; + for sub in 0..SHA256_READ_SUBBLOCKS { + let chunk: [u8; CONST_BLOCK_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + + (read_idx * SHA256_READ_SIZE + sub * CONST_BLOCK_SIZE) as u32, + &mut record.read_aux[read_idx * SHA256_READ_SUBBLOCKS + sub].prev_timestamp, + ); + row_input[sub * 
CONST_BLOCK_SIZE..(sub + 1) * CONST_BLOCK_SIZE] + .copy_from_slice(&chunk); + } record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE] .copy_from_slice(&row_input); } } let output = sha256_solve(&record.input[..len as usize]); - tracing_write( - state.memory, - RV32_MEMORY_AS, - record.inner.dst_ptr, - output, - &mut record.inner.write_aux.prev_timestamp, - &mut record.inner.write_aux.prev_data, - ); + for i in 0..SHA256_WRITE_SUBBLOCKS { + tracing_write( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * CONST_BLOCK_SIZE) as u32, + output[i * CONST_BLOCK_SIZE..(i + 1) * CONST_BLOCK_SIZE] + .try_into() + .unwrap(), + &mut record.inner.write_aux[i].prev_timestamp, + &mut record.inner.write_aux[i].prev_data, + ); + } *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); @@ -356,7 +372,8 @@ impl TraceFiller for Sha256VmFiller { } // Copy the read aux records and input to another place to safely fill in the trace // matrix without overwriting the record - let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks); + let mut read_aux_records = + Vec::with_capacity(SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS * num_blocks); read_aux_records.extend_from_slice(record.read_aux); let vm_record = record.inner.clone(); @@ -381,8 +398,10 @@ impl TraceFiller for Sha256VmFiller { self.fill_block_trace::( block_slice, &vm_record, - &read_aux_records[block_idx * SHA256_NUM_READ_ROWS - ..(block_idx + 1) * SHA256_NUM_READ_ROWS], + &read_aux_records[block_idx + * SHA256_NUM_READ_ROWS + * SHA256_READ_SUBBLOCKS + ..(block_idx + 1) * SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS], &input[block_idx * SHA256_BLOCK_CELLS ..(block_idx + 1) * SHA256_BLOCK_CELLS], &padded_input[block_idx * SHA256_BLOCK_CELLS @@ -426,14 +445,18 @@ impl Sha256VmFiller { ) { debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS); debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS); + 
debug_assert_eq!( + read_aux_records.len(), + SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS + ); let padded_input = array::from_fn(|i| { u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap()) }); - let block_start_timestamp = record.timestamp - + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32; + let reads_per_block = SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS; + let block_start_timestamp = + record.timestamp + (SHA256_REGISTER_READS + reads_per_block * local_block_idx) as u32; let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32; let block_start_read_ptr = record.src_ptr + read_cells; @@ -487,15 +510,21 @@ impl Sha256VmFiller { }); digest_cols .writes_aux - .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); - // In the last block we do `SHA256_NUM_READ_ROWS` reads and then write the - // result thus the timestamp of the write is - // `block_start_timestamp + SHA256_NUM_READ_ROWS` - mem_helper.fill( - record.write_aux.prev_timestamp, - block_start_timestamp + SHA256_NUM_READ_ROWS as u32, - digest_cols.writes_aux.as_mut(), - ); + .iter_mut() + .zip(record.write_aux.iter()) + .enumerate() + .for_each(|(chunk, (cols_write, record_write))| { + cols_write.set_prev_data( + record_write.prev_data.map(F::from_canonical_u8), + ); + mem_helper.fill( + record_write.prev_timestamp, + block_start_timestamp + + (SHA256_NUM_READ_ROWS * SHA256_READ_SUBBLOCKS) as u32 + + chunk as u32, + cols_write.as_mut(), + ); + }); // Need to range check the destination and source pointers let msl_rshift: u32 = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; @@ -512,10 +541,10 @@ impl Sha256VmFiller { digest_cols.register_reads_aux.iter_mut().for_each(|aux| { mem_helper.fill_zero(aux.as_mut()); }); - digest_cols - .writes_aux - .set_prev_data([F::ZERO; SHA256_WRITE_SIZE]); - mem_helper.fill_zero(digest_cols.writes_aux.as_mut()); + digest_cols.writes_aux.iter_mut().for_each(|aux| { + aux.set_prev_data([F::ZERO; 
CONST_BLOCK_SIZE]); + mem_helper.fill_zero(aux.as_mut()); + }); } digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); digest_cols.inner.flags.is_digest_row = F::from_bool(true); @@ -538,13 +567,19 @@ impl Sha256VmFiller { .for_each(|(cell, data)| { *cell = F::from_canonical_u8(*data); }); - mem_helper.fill( - read_aux_records[row_idx].prev_timestamp, - block_start_timestamp + row_idx as u32, - round_cols.read_aux.as_mut(), - ); + for sub in 0..SHA256_READ_SUBBLOCKS { + mem_helper.fill( + read_aux_records[row_idx * SHA256_READ_SUBBLOCKS + sub] + .prev_timestamp, + block_start_timestamp + + (row_idx * SHA256_READ_SUBBLOCKS + sub) as u32, + round_cols.read_aux[sub].as_mut(), + ); + } } else { - mem_helper.fill_zero(round_cols.read_aux.as_mut()); + round_cols.read_aux.iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); } } // Fill in the control cols, doesn't matter if it is a round or digest row @@ -553,7 +588,8 @@ impl Sha256VmFiller { control_cols.len = F::from_canonical_u32(record.len); // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr control_cols.cur_timestamp = F::from_canonical_u32( - block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32, + block_start_timestamp + + (SHA256_READ_SUBBLOCKS * min(row_idx, SHA256_NUM_READ_ROWS)) as u32, ); control_cols.read_ptr = F::from_canonical_u32( block_start_read_ptr