diff --git a/.github/workflows/extension-tests.cuda.yml b/.github/workflows/extension-tests.cuda.yml index 2fc7420b37..f236af3457 100644 --- a/.github/workflows/extension-tests.cuda.yml +++ b/.github/workflows/extension-tests.cuda.yml @@ -35,7 +35,7 @@ jobs: matrix: extensions: # group extensions on the same runner based on test time - "rv32im native" - - "keccak256 sha256 bigint algebra ecc pairing" + - "keccak256 sha2 bigint algebra ecc pairing" runs-on: - runs-on=${{ github.run_id }}-extension-tests-cuda-${{ github.run_attempt }}-${{ strategy.job-index }}/runner=${{ github.event_name == 'workflow_dispatch' && github.event.inputs.machine_type || 'test-gpu-nvidia/cpu=8+32' }} @@ -57,8 +57,8 @@ jobs: - "extensions/native/**" keccak256: - "extensions/keccak256/**" - sha256: - - "extensions/sha256/**" + sha2: + - "extensions/sha2/**" bigint: - "extensions/bigint/**" algebra: diff --git a/.github/workflows/extension-tests.yml b/.github/workflows/extension-tests.yml index 4ebec326e0..90a22d6d55 100644 --- a/.github/workflows/extension-tests.yml +++ b/.github/workflows/extension-tests.yml @@ -31,7 +31,7 @@ jobs: - { name: "rv32im", path: "rv32im", aot: false } - { name: "native", path: "native", aot: false } - { name: "keccak256", path: "keccak256", aot: false } - - { name: "sha256", path: "sha256", aot: false } + - { name: "sha2", path: "sha2", aot: false } - { name: "bigint", path: "bigint", aot: false } - { name: "algebra", path: "algebra", aot: false } - { name: "ecc", path: "ecc", aot: false } @@ -39,7 +39,7 @@ jobs: - { name: "rv32im", path: "rv32im", aot: true } - { name: "native", path: "native", aot: true } - { name: "keccak256", path: "keccak256", aot: true } - - { name: "sha256", path: "sha256", aot: true } + - { name: "sha2", path: "sha2", aot: true } - { name: "bigint", path: "bigint", aot: true } - { name: "algebra", path: "algebra", aot: true } - { name: "ecc", path: "ecc", aot: true } diff --git a/.github/workflows/guest-lib-tests.cuda.yml b/.github/workflows/guest-lib-tests.cuda.yml index 57d5594054..d25befe00b 100644 --- a/.github/workflows/guest-lib-tests.cuda.yml +++ b/.github/workflows/guest-lib-tests.cuda.yml @@ -52,7 +52,7 @@ jobs: - ".github/workflows/guest-lib-tests.cuda.yml" - "crates/sdk/guest/fib/**" sha2: - - "extensions/sha256/**" + - "extensions/sha2/**" - "guest-libs/sha2/**" keccak256: - "extensions/keccak256/**" diff --git a/.github/workflows/primitives.yml b/.github/workflows/primitives.yml index 7e669719f3..8dacf66b9d 100644 --- a/.github/workflows/primitives.yml +++ b/.github/workflows/primitives.yml @@ -9,7 +9,7 @@ on: paths: - "crates/circuits/primitives/**" - "crates/circuits/poseidon2-air/**" - - "crates/circuits/sha256-air/**" + - "crates/circuits/sha2-air/**" - "crates/circuits/mod-builder/**" - "Cargo.toml" - ".github/workflows/primitives.yml" @@ -65,13 +65,6 @@ jobs: run: | cargo nextest run ${{ env.NEXTEST_ARGS }} - # No gpu tests in these crates - - name: Run tests for sha256-air - if: ${{ !contains(matrix.platform.runner, 'gpu') }} - working-directory: crates/circuits/sha256-air - run: | - cargo nextest run ${{ env.NEXTEST_ARGS }} - - name: Run tests for mod-builder if: ${{ !contains(matrix.platform.runner, 'gpu') }} working-directory: crates/circuits/mod-builder diff --git a/Cargo.lock b/Cargo.lock index 5509556cdf..0f0271cea3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5179,8 +5179,8 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -5469,6 +5469,16 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -5627,6 +5637,21 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -6186,8 +6211,8 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-rv32im-transpiler", "openvm-sdk", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-transpiler", "rand 0.8.5", @@ -6384,6 +6409,8 @@ name = "openvm-circuit-primitives-derive" version = "1.4.2" dependencies = [ "itertools 0.14.0", + "ndarray", + "proc-macro2", "quote", "syn 2.0.106", ] @@ -7097,8 +7124,8 @@ dependencies = [ "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", @@ -7122,12 +7149,13 @@ name = "openvm-sha2" version = "1.4.2" dependencies = [ "eyre", + "openvm", "openvm-circuit", "openvm-instructions", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-guest", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-guest", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", @@ -7135,57 +7163,60 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version = "1.4.2" dependencies = [ - "openvm-circuit", + "ndarray", + "num_enum", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-stark-backend", - "openvm-stark-sdk", "rand 0.8.5", "sha2 0.10.9", ] [[package]] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version = "1.4.2" dependencies = [ "cfg-if", "derive-new 0.6.0", "derive_more 1.0.0", "hex", + "itertools 0.14.0", + "ndarray", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-cuda-backend", "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", "openvm-rv32im-circuit", - "openvm-sha256-air", - "openvm-sha256-transpiler", + "openvm-sha2-air", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", "serde", "sha2 0.10.9", - "strum 0.26.3", ] [[package]] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version = "1.4.2" dependencies = [ "openvm-platform", ] [[package]] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version = "1.4.2" dependencies = [ "openvm-instructions", "openvm-instructions-derive", - "openvm-sha256-guest", + "openvm-sha2-guest", "openvm-stark-backend", "openvm-transpiler", "rrs-lib", @@ -7366,8 +7397,8 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -8538,6 +8569,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index 8ecd91572f..7dbf9c46db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,9 +51,9 @@ members = [ "extensions/keccak256/circuit", "extensions/keccak256/transpiler", "extensions/keccak256/guest", - "extensions/sha256/circuit", - "extensions/sha256/transpiler", - "extensions/sha256/guest", + "extensions/sha2/circuit", + "extensions/sha2/transpiler", + "extensions/sha2/guest", "extensions/ecc/circuit", "extensions/ecc/transpiler", "extensions/ecc/guest", @@ -123,7 +123,7 @@ openvm-cuda-common = { git = "https://github.com/openvm-org/stark-backend.git", openvm-sdk = { path = "crates/sdk", default-features = false } openvm-mod-circuit-builder = { path = "crates/circuits/mod-builder", default-features = false } openvm-poseidon2-air = { path = "crates/circuits/poseidon2-air", default-features = false } -openvm-sha256-air = { path = "crates/circuits/sha256-air", default-features = false } +openvm-sha2-air = { path = "crates/circuits/sha2-air", default-features = false } openvm-circuit-primitives = { path = "crates/circuits/primitives", default-features = false } openvm-circuit-primitives-derive = { path = "crates/circuits/primitives/derive", default-features = false } openvm = { path = "crates/toolchain/openvm", default-features = false } @@ -153,9 +153,9 @@ openvm-native-transpiler = { path = "extensions/native/transpiler", default-feat openvm-keccak256-circuit = { path = "extensions/keccak256/circuit", default-features = false } openvm-keccak256-transpiler = { path = "extensions/keccak256/transpiler", default-features = false } openvm-keccak256-guest = { path = "extensions/keccak256/guest", default-features = false } -openvm-sha256-circuit = { path = "extensions/sha256/circuit", default-features = false } -openvm-sha256-transpiler = { path = "extensions/sha256/transpiler", default-features = false } -openvm-sha256-guest = { path = "extensions/sha256/guest", default-features = false } +openvm-sha2-circuit = { path = "extensions/sha2/circuit", default-features = false } +openvm-sha2-transpiler = { path = "extensions/sha2/transpiler", default-features = false } +openvm-sha2-guest = { path = "extensions/sha2/guest", default-features = false } openvm-bigint-circuit = { path = "extensions/bigint/circuit", default-features = false } openvm-bigint-transpiler = { path = "extensions/bigint/transpiler", default-features = false } openvm-bigint-guest = { path = "extensions/bigint/guest", default-features = false } @@ -231,6 +231,8 @@ libloading = "0.8" tracing-subscriber = { version = "0.3.20", features = ["std", "env-filter"] } tokio = "1" # >=1.0.0 to allow downstream flexibility abi_stable = "0.11.3" +ndarray = { version = "0.16.1", default-features = false } +num_enum = { version = "0.7.4", default-features = false } # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index b9b4b96776..345b5211d4 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -26,8 +26,8 @@ openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-continuations = { workspace = true } openvm-native-recursion = { workspace = true } openvm-sdk = { workspace = true } diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index 199ad87943..1291d4b2a9 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -47,8 +47,8 @@ use openvm_sdk::{ commit::VmCommittedExe, config::{AggregationConfig, DEFAULT_NUM_CHILDREN_INTERNAL, DEFAULT_NUM_CHILDREN_LEAF}, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2CpuProverExt, Sha2Executor}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::{ config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, engine::{StarkEngine, StarkFriEngine}, @@ -144,7 +144,7 @@ pub struct ExecuteConfig { #[extension] pub keccak: Keccak256, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, #[extension] pub modular: ModularExtension, #[extension] @@ -165,7 +165,7 @@ impl Default for ExecuteConfig { io: Rv32Io, bigint: Int256::default(), keccak: Keccak256, - sha256: Sha256, + sha2: Sha2, modular: ModularExtension::new(vec![ bn_config.modulus.clone(), bn_config.scalar.clone(), @@ -225,7 +225,7 @@ where &config.keccak, inventory, )?; - VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha2, inventory)?; VmProverExtension::::extend_prover( &AlgebraCpuProverExt, &config.modular, @@ -253,7 +253,7 @@ fn create_default_transpiler() -> Transpiler { .with_extension(Rv32MTranspilerExtension) .with_extension(Int256TranspilerExtension) .with_extension(Keccak256TranspilerExtension) - .with_extension(Sha256TranspilerExtension) + .with_extension(Sha2TranspilerExtension) .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension) .with_extension(EccTranspilerExtension) diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..e6cafcf57f 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -2,7 +2,7 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.keccak] -[app_vm_config.sha256] +[app_vm_config.sha2] [app_vm_config.bigint] [app_vm_config.modular] diff --git a/benchmarks/guest/sha256/openvm.toml b/benchmarks/guest/sha256/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256/openvm.toml +++ b/benchmarks/guest/sha256/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256/src/main.rs b/benchmarks/guest/sha256/src/main.rs index fc0b3fab78..7f7067fd79 100644 --- a/benchmarks/guest/sha256/src/main.rs +++ b/benchmarks/guest/sha256/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; -use openvm as _; -use openvm_sha2::sha256; +use openvm as _; +use openvm_sha2::Sha256; const INPUT_LENGTH_BYTES: usize = 384 * 1024; @@ -16,5 +16,7 @@ pub fn main() { } // Prevent optimizer from optimizing away the computation - black_box(sha256(&black_box(input))); + let sha256 = Sha256::new(); + sha256.update(black_box(&input)); + black_box(sha256.finalize()); } diff --git a/benchmarks/guest/sha256_iter/openvm.toml b/benchmarks/guest/sha256_iter/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256_iter/openvm.toml +++ b/benchmarks/guest/sha256_iter/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256_iter/src/main.rs b/benchmarks/guest/sha256_iter/src/main.rs index aea8b723e9..7c71085de2 100644 --- a/benchmarks/guest/sha256_iter/src/main.rs +++ b/benchmarks/guest/sha256_iter/src/main.rs @@ -5,13 +5,19 @@ use openvm_sha2::sha256; const ITERATIONS: usize = 150_000; +fn sha256(input: &[u8]) -> [u8; 32] { + let mut sha256 = Sha256::new(); + sha256.update(black_box(input)); + sha256.finalize() +} + pub fn main() { // Initialize with hash of an empty vector let mut hash = black_box(sha256(&[])); // Iteratively apply sha256 for _ in 0..ITERATIONS { - hash = sha256(&hash); + hash = black_box(sha256(&hash)); } // Prevent optimizer from optimizing away the computation diff --git a/crates/circuits/primitives/cuda/include/primitives/constants.h b/crates/circuits/primitives/cuda/include/primitives/constants.h index dec26b5f41..16396d41ec 100644 --- a/crates/circuits/primitives/cuda/include/primitives/constants.h +++ b/crates/circuits/primitives/cuda/include/primitives/constants.h @@ -3,91 +3,97 @@ #include namespace riscv { -static const size_t RV32_REGISTER_NUM_LIMBS = 4; -static const size_t RV32_CELL_BITS = 8; -static const size_t RV_J_TYPE_IMM_BITS = 21; +inline constexpr size_t RV32_REGISTER_NUM_LIMBS = 4; +inline constexpr size_t RV32_CELL_BITS = 8; +inline constexpr size_t RV_J_TYPE_IMM_BITS = 21; -static const size_t RV32_IMM_AS = 0; +inline constexpr size_t RV32_IMM_AS = 0; } // namespace riscv namespace program { -static const size_t PC_BITS = 30; -static const size_t DEFAULT_PC_STEP = 4; +inline constexpr size_t PC_BITS = 30; +inline constexpr size_t DEFAULT_PC_STEP = 4; } // namespace program namespace native { -static const size_t AS_IMMEDIATE = 0; -static const size_t AS_NATIVE = 4; -static const size_t EXT_DEG = 4; -static const size_t BETA = 11; +inline constexpr size_t AS_IMMEDIATE = 0; +inline constexpr size_t AS_NATIVE = 4; +inline constexpr size_t EXT_DEG = 4; +inline constexpr size_t BETA = 11; } // namespace native namespace poseidon2 { -static const size_t CHUNK = 8; +inline constexpr size_t CHUNK = 8; } // namespace poseidon2 namespace p3_keccak_air { -static const size_t NUM_ROUNDS = 24; -static const size_t BITS_PER_LIMB = 16; -static const size_t U64_LIMBS = 64 / BITS_PER_LIMB; -static const size_t RATE_BITS = 1088; -static const size_t RATE_LIMBS = RATE_BITS / BITS_PER_LIMB; +inline constexpr size_t NUM_ROUNDS = 24; +inline constexpr size_t BITS_PER_LIMB = 16; +inline constexpr size_t U64_LIMBS = 64 / BITS_PER_LIMB; +inline constexpr size_t RATE_BITS = 1088; +inline constexpr size_t RATE_LIMBS = RATE_BITS / BITS_PER_LIMB; } // namespace p3_keccak_air namespace keccak256 { /// Total number of sponge bytes: number of rate bytes + number of capacity bytes. -static const size_t KECCAK_WIDTH_BYTES = 200; +inline constexpr size_t KECCAK_WIDTH_BYTES = 200; /// Total number of 16-bit limbs in the sponge. -static const size_t KECCAK_WIDTH_U16S = KECCAK_WIDTH_BYTES / 2; +inline constexpr size_t KECCAK_WIDTH_U16S = KECCAK_WIDTH_BYTES / 2; /// Number of rate bytes. -static const size_t KECCAK_RATE_BYTES = 136; +inline constexpr size_t KECCAK_RATE_BYTES = 136; /// Number of 16-bit rate limbs. -static const size_t KECCAK_RATE_U16S = KECCAK_RATE_BYTES / 2; +inline constexpr size_t KECCAK_RATE_U16S = KECCAK_RATE_BYTES / 2; /// Number of absorb rounds, equal to rate in u64s. -static const size_t NUM_ABSORB_ROUNDS = KECCAK_RATE_BYTES / 8; +inline constexpr size_t NUM_ABSORB_ROUNDS = KECCAK_RATE_BYTES / 8; /// Number of capacity bytes. -static const size_t KECCAK_CAPACITY_BYTES = 64; +inline constexpr size_t KECCAK_CAPACITY_BYTES = 64; /// Number of 16-bit capacity limbs. -static const size_t KECCAK_CAPACITY_U16S = KECCAK_CAPACITY_BYTES / 2; +inline constexpr size_t KECCAK_CAPACITY_U16S = KECCAK_CAPACITY_BYTES / 2; /// Number of output digest bytes used during the squeezing phase. -static const size_t KECCAK_DIGEST_BYTES = 32; +inline constexpr size_t KECCAK_DIGEST_BYTES = 32; /// Number of 64-bit digest limbs. -static const size_t KECCAK_DIGEST_U64S = KECCAK_DIGEST_BYTES / 8; +inline constexpr size_t KECCAK_DIGEST_U64S = KECCAK_DIGEST_BYTES / 8; // ==== Constants for register/memory adapter ==== /// Register reads to get dst, src, len -static const size_t KECCAK_REGISTER_READS = 3; +inline constexpr size_t KECCAK_REGISTER_READS = 3; /// Number of cells to read/write in a single memory access -static const size_t KECCAK_WORD_SIZE = 4; +inline constexpr size_t KECCAK_WORD_SIZE = 4; /// Memory reads for absorb per row -static const size_t KECCAK_ABSORB_READS = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE; +inline constexpr size_t KECCAK_ABSORB_READS = KECCAK_RATE_BYTES / KECCAK_WORD_SIZE; /// Memory writes for digest per row -static const size_t KECCAK_DIGEST_WRITES = KECCAK_DIGEST_BYTES / KECCAK_WORD_SIZE; +inline constexpr size_t KECCAK_DIGEST_WRITES = KECCAK_DIGEST_BYTES / KECCAK_WORD_SIZE; /// keccakf parameters -static const size_t KECCAK_ROUND = 24; -static const size_t KECCAK_STATE_SIZE = 25; -static const size_t KECCAK_Q_SIZE = 192; +inline constexpr size_t KECCAK_ROUND = 24; +inline constexpr size_t KECCAK_STATE_SIZE = 25; +inline constexpr size_t KECCAK_Q_SIZE = 192; /// From memory config -static const size_t KECCAK_POINTER_MAX_BITS = 29; +inline constexpr size_t KECCAK_POINTER_MAX_BITS = 29; } // namespace keccak256 namespace mod_builder { -static const size_t MAX_LIMBS = 97; +inline constexpr size_t MAX_LIMBS = 97; } // namespace mod_builder namespace sha256 { -static const size_t SHA256_BLOCK_BITS = 512; -static const size_t SHA256_BLOCK_U8S = 64; -static const size_t SHA256_BLOCK_WORDS = 16; -static const size_t SHA256_WORD_U8S = 4; -static const size_t SHA256_WORD_BITS = 32; -static const size_t SHA256_WORD_U16S = 2; -static const size_t SHA256_HASH_WORDS = 8; -static const size_t SHA256_NUM_READ_ROWS = 4; -static const size_t SHA256_ROWS_PER_BLOCK = 17; -static const size_t SHA256_ROUNDS_PER_ROW = 4; -static const size_t SHA256_ROW_VAR_CNT = 5; -static const size_t SHA256_REGISTER_READS = 3; -static const size_t SHA256_READ_SIZE = 16; -static const size_t SHA256_WRITE_SIZE = 32; -} // namespace sha256 \ No newline at end of file +inline constexpr size_t SHA256_BLOCK_BITS = 512; +inline constexpr size_t SHA256_BLOCK_U8S = 64; +inline constexpr size_t SHA256_BLOCK_WORDS = 16; +inline constexpr size_t SHA256_WORD_U8S = 4; +inline constexpr size_t SHA256_WORD_BITS = 32; +inline constexpr size_t SHA256_WORD_U16S = 2; +inline constexpr size_t SHA256_HASH_WORDS = 8; +inline constexpr size_t SHA256_NUM_READ_ROWS = 4; +inline constexpr size_t SHA256_ROWS_PER_BLOCK = 17; +inline constexpr size_t SHA256_ROUNDS_PER_ROW = 4; +inline constexpr size_t SHA256_ROW_VAR_CNT = 5; +inline constexpr size_t SHA256_REGISTER_READS = 3; +inline constexpr size_t SHA256_READ_SIZE = 16; +inline constexpr size_t SHA256_WRITE_SIZE = 32; +} // namespace sha256 + +namespace hintstore { +// Must match MAX_HINT_BUFFER_WORDS_BITS in openvm_rv32im_guest::lib.rs +inline constexpr size_t MAX_HINT_BUFFER_WORDS_BITS = 18; +inline constexpr size_t MAX_HINT_BUFFER_WORDS = (1 << MAX_HINT_BUFFER_WORDS_BITS) - 1; +} // namespace hintstore diff --git a/crates/circuits/primitives/derive/Cargo.toml b/crates/circuits/primitives/derive/Cargo.toml index 06d4c00aed..23ec5d559d 100644 --- a/crates/circuits/primitives/derive/Cargo.toml +++ b/crates/circuits/primitives/derive/Cargo.toml @@ -12,6 +12,13 @@ license.workspace = true proc-macro = true [dependencies] -syn = { version = "2.0", features = ["parsing"] } +syn = { version = "2.0", features = ["full", "parsing", "extra-traits"] } quote = "1.0" -itertools = { workspace = true } +itertools = { workspace = true, default-features = true } +proc-macro2 = "1.0" + +[dev-dependencies] +ndarray.workspace = true + +[package.metadata.cargo-shear] +ignored = ["ndarray"] diff --git a/crates/circuits/primitives/derive/src/cols_ref/README.md b/crates/circuits/primitives/derive/src/cols_ref/README.md new file mode 100644 index 0000000000..82812f7b90 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/README.md @@ -0,0 +1,113 @@ +# ColsRef macro + +The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes. + +Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). +See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits. + +## Overview + +As an illustrative example, consider the following columns struct: +```rust +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10. +We can define a trait that stores the config parameters. +```rust +pub trait ExampleConfig { + const N: usize; +} +``` +and then implement it for the two different configs. +```rust +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} +``` +Then we can use the `ColsRef` macro like this +```rust +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +which will generate a columns struct that uses references to the fields. +```rust +struct ExampleColsRef<'a, T, const N: usize> { + arr: ndarray::ArrayView1<'a, T>, // an n-dimensional view into the input slice (ArrayView2 for 2D arrays, etc.) + sum: &'a T, +} +``` +The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct. +The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct. + +So, the constraint generation code can be written as +```rust +impl Air for ExampleAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, _) = (main.row_slice(0), main.row_slice(1)); + let local_cols = ExampleColsRef::::from::(&local[..C::N + 1]); + let sum = local_cols.arr.iter().sum(); + builder.assert_eq(local_cols.sum, sum); + } +} +``` +Notes: +- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice. +- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait. + +The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation. + +The `ColsRef` macro supports more than just variable-length array fields. +The field types can also be: +- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]` +- any type that derives `ColsRef` via `#[derive(ColsRef)]` +- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow` + +Note that we currently do not support arrays of types that derive `ColsRef`. + +## Specification + +Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`. +- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait + +The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows: +- type `T` becomes `&T` +- type `[T; LEN]` becomes `&ArrayView1` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig` + - the `ExampleColsRef::from` method will correctly infer the length of the array from the config +- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively + - one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type +- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type + - if a field whose name ends in `Cols` is annotated with `#[aligned_borrow]`, then the aligned borrow takes precedence, and the field is not transformed into an `ArrayView` +- nested arrays of `U` become `&ArrayViewX` where `X` is the number of dimensions in the nested array type + - `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]` + - the `ArrayViewX` type provides a `X`-dimensional view into the row slice + +The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references. +- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields. +- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively. + +Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented: +```rust +// Takes a slice of the correct length and returns an instance of the columns struct. +pub const fn from(slice: &[T]) -> Self; +// Returns the number of cells in the struct +pub const fn width() -> usize; +``` +Note that the `width` method on both structs returns the same value. + +Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`. +This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`. + +See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types. \ No newline at end of file diff --git a/crates/circuits/primitives/derive/src/cols_ref/mod.rs b/crates/circuits/primitives/derive/src/cols_ref/mod.rs new file mode 100644 index 0000000000..2d3fad11ed --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/mod.rs @@ -0,0 +1,703 @@ +/* + * The `ColsRef` procedural macro is used in constraint generation to create column structs that + * have dynamic sizes. + * + * Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the + * same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). + * See the [SHA-2 VM extension](openvm/extensions/sha2/circuit/src/sha2_chip/air.rs) for an + * example of how to use the `ColsRef` macro to reuse constraint generation code over multiple + * circuits. + * + * This macro can also be used in other situations where we want to derive Borrow for &[u8], + * for some complicated struct T. + */ +mod utils; + +use utils::*; + +extern crate proc_macro; + +use itertools::Itertools; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput}; + +pub fn cols_ref_impl( + derive_input: DeriveInput, + config: proc_macro2::Ident, +) -> proc_macro2::TokenStream { + let DeriveInput { + ident, + generics, + data, + vis, + .. + } = derive_input; + + let generic_types = generics + .params + .iter() + .filter_map(|p| { + if let syn::GenericParam::Type(type_param) = p { + Some(type_param) + } else { + None + } + }) + .collect::>(); + + if generic_types.len() != 1 { + panic!("Struct must have exactly one generic type parameter"); + } + + let generic_type = generic_types[0]; + + let const_generics = generics.const_params().map(|p| &p.ident).collect_vec(); + + match data { + syn::Data::Struct(data_struct) => { + // Process the fields of the struct, transforming the types for use in ColsRef struct + let const_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_const_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRef struct is named by appending `Ref` to the struct name + let const_cols_ref_name = syn::Ident::new(&format!("{ident}Ref"), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a [#generic_type] }; + + // Package all the necessary information to generate the ColsRef struct + let struct_info = StructInfo { + name: const_cols_ref_name, + vis: vis.clone(), + generic_type: generic_type.clone(), + field_infos: const_field_infos, + fields: data_struct.fields.clone(), + from_args, + derive_clone: true, + }; + + // Generate the ColsRef struct + let const_cols_ref_struct = make_struct(struct_info.clone(), &config); + + // Generate the `from_mut` method for the ColsRef struct + let from_mut_impl = make_from_mut(struct_info, &config); + + // Process the fields of the struct, transforming the types for use in ColsRefMut struct + let mut_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_mut_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRefMut struct is named by appending `RefMut` to the struct name + let mut_cols_ref_name = syn::Ident::new(&format!("{ident}RefMut"), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a mut [#generic_type] }; + + // Package all the necessary information to generate the ColsRefMut struct + let struct_info = StructInfo { + name: mut_cols_ref_name, + vis, + generic_type: generic_type.clone(), + field_infos: mut_field_infos, + fields: data_struct.fields, + from_args, + derive_clone: false, + }; + + // Generate the ColsRefMut struct + let mut_cols_ref_struct = make_struct(struct_info, &config); + + quote! { + #const_cols_ref_struct + #from_mut_impl + #mut_cols_ref_struct + } + } + _ => panic!("ColsRef can only be derived for structs"), + } +} + +#[derive(Debug, Clone)] +struct StructInfo { + name: syn::Ident, + vis: syn::Visibility, + generic_type: syn::TypeParam, + field_infos: Vec, + fields: syn::Fields, + from_args: proc_macro2::TokenStream, + derive_clone: bool, +} + +// Generate the ColsRef and ColsRefMut structs, depending on the value of `struct_info` +// This function is meant to reduce code duplication between the code needed to generate the two +// structs Notable differences between the two structs are: +// - the types of the fields +// - ColsRef derives Clone, but ColsRefMut cannot (since it stores mutable references) +// - the `from` method parameter is a reference to a slice for ColsRef and a mutable reference to +// a slice for ColsRefMut +fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis, + generic_type, + field_infos, + fields, + from_args, + derive_clone, + } = struct_info; + + let field_types = field_infos.iter().map(|f| &f.ty).collect_vec(); + let length_exprs = field_infos.iter().map(|f| &f.length_expr).collect_vec(); + let prepare_subslices = field_infos + .iter() + .map(|f| &f.prepare_subslice) + .collect_vec(); + let initializers = field_infos.iter().map(|f| &f.initializer).collect_vec(); + + let idents = fields.iter().map(|f| &f.ident).collect_vec(); + + let clone_impl = if derive_clone { + quote! { + #[derive(Clone)] + } + } else { + quote! {} + }; + + quote! { + #clone_impl + #[derive(Debug)] + #vis struct #name <'a, #generic_type> { + #( pub #idents: #field_types ),* + } + + impl<'a, #generic_type> #name<'a, #generic_type> { + pub fn from(#from_args) -> Self { + #( #prepare_subslices )* + Self { + #( #idents: #initializers ),* + } + } + + // Returns number of cells in the struct (where each cell has type T). + // This method should only be called if the struct has no primitive types (i.e. for columns structs). + pub const fn width() -> usize { + 0 #( + #length_exprs )* + } + } + } +} + +// Generate the `from_mut` method for the ColsRef struct +fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis: _, + generic_type, + field_infos: _, + fields, + from_args: _, + derive_clone: _, + } = struct_info; + + let from_mut_impl = fields + .iter() + .map(|f| { + let ident = f.ident.clone().unwrap(); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + // calling view() on ArrayViewMut returns an ArrayView + quote! { + other.#ident.view() + } + } else if derives_aligned_borrow { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + other.#ident + } + } else if is_columns_struct(&f.ty) { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + let cols_ref_type = + get_const_cols_ref_type(&f.ty, &generic_type, parse_quote! { 'b }); + // Recursively call `from_mut` on the ColsRef field + quote! { + <#cols_ref_type>::from_mut::(&other.#ident) + } + } else if is_generic_type(&f.ty, &generic_type) { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + &other.#ident + } + } else { + panic!("Unsupported field type (in make_from_mut): {:?}", f.ty); + } + }) + .collect_vec(); + + let field_idents = fields + .iter() + .map(|f| f.ident.clone().unwrap()) + .collect_vec(); + + let mut_struct_ident = format_ident!("{}Mut", name.to_string()); + let mut_struct_type: syn::Type = parse_quote! { + #mut_struct_ident<'a, #generic_type> + }; + + parse_quote! { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + impl<'b, #generic_type> #name<'b, #generic_type> { + pub fn from_mut<'a, C: #config>(other: &'b #mut_struct_type) -> Self + { + Self { + #( #field_idents: #from_mut_impl ),* + } + } + } + } +} + +// Information about a field that is used to generate the ColsRef and ColsRefMut structs +// See the `make_struct` function to see how this information is used +#[derive(Debug, Clone)] +struct FieldInfo { + // type for struct definition + ty: syn::Type, + // an expr calculating the length of the field + length_expr: proc_macro2::TokenStream, + // prepare a subslice of the slice to be used in the 'from' method + prepare_subslice: proc_macro2::TokenStream, + // an expr used in the Self initializer in the 'from' method + // may refer to the subslice declared in prepare_subslice + initializer: proc_macro2::TokenStream, +} + +// Prepare the fields for the const ColsRef struct +fn get_const_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayView{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var: &[#elem_type] = unsafe { &*(#slice_var as *const [T] as *const [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_primitive_type(&elem_type) { + FieldInfo { + ty: parse_quote! { + &'a #elem_type + }, + // Columns structs won't ever have primitive types, but this macro can be used on + // other structs as well, to make it easy to borrow a struct from &[u8]. + // We just set length = 0 knowing that calling the width() method is undefined if + // the struct has a primitive type. + length_expr: quote! { + 0 + }, + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!( + "Unsupported field type (in get_const_cols_ref_fields): {:?}", + f.ty + ); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + { + use core::borrow::Borrow; + #slice_var.borrow() + } + }, + } + } else if is_columns_struct(&f.ty) { + let const_cols_ref_type = get_const_cols_ref_type(&f.ty, generic_type, parse_quote! { 'a }); + FieldInfo { + ty: parse_quote! { + #const_cols_ref_type + }, + length_expr: quote! { + <#const_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#const_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at(#length_var); + let #slice_var = <#const_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + &#slice_var[0] + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } +} + +// Prepare the fields for the mut ColsRef struct +fn get_mut_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayViewMut{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut (#length_expr); + let #slice_var: &mut [#elem_type] = unsafe { &mut *(#slice_var as *mut [T] as *mut [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_primitive_type(&elem_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #elem_type + }, + // Columns structs won't ever have primitive types, but this macro can be used on + // other structs as well, to make it easy to borrow a struct from &[u8]. + // We just set length = 0 knowing that calling the width() method is undefined if + // the struct has a primitive type. + length_expr: quote! { + 0 + }, + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a mut #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + { + use core::borrow::BorrowMut; + #slice_var.borrow_mut() + } + }, + } + } else if is_columns_struct(&f.ty) { + let mut_cols_ref_type = get_mut_cols_ref_type(&f.ty, generic_type); + FieldInfo { + ty: parse_quote! { + #mut_cols_ref_type + }, + length_expr: quote! { + <#mut_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#mut_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + let #slice_var = <#mut_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + &mut #slice_var[0] + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } +} + +// Helper functions + +fn is_columns_struct(ty: &syn::Type) -> bool { + if let syn::Type::Path(type_path) = ty { + type_path + .path + .segments + .iter() + .next_back() + .map(|s| s.ident.to_string().ends_with("Cols")) + .unwrap_or(false) + } else { + false + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRef struct type +// Otherwise, return None +fn get_const_cols_ref_type( + ty: &syn::Type, + generic_type: &syn::TypeParam, + lifetime: syn::Lifetime, +) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {ty:?}"); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().next_back().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let const_cols_ref_ident = format_ident!("{}Ref", s.ident); + let const_cols_ref_type = parse_quote! { + #const_cols_ref_ident<#lifetime, #generic_type> + }; + const_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {ty:?} but the last segment is not a columns struct"); + } + } else { + panic!("is_columns_struct returned true but the type {ty:?} is not a path",); + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRefMut struct type +// Otherwise, return None +fn get_mut_cols_ref_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {ty:?}"); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().next_back().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let mut_cols_ref_ident = format_ident!("{}RefMut", s.ident); + let mut_cols_ref_type = parse_quote! { + #mut_cols_ref_ident<'a, #generic_type> + }; + mut_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {ty:?} but the last segment is not a columns struct"); + } + } else { + panic!("is_columns_struct returned true but the type {ty:?} is not a path",); + } +} + +fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool { + if let syn::Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + type_path + .path + .segments + .iter() + .next_back() + .map(|s| s.ident == generic_type.ident) + .unwrap_or(false) + } else { + false + } + } else { + false + } +} diff --git a/crates/circuits/primitives/derive/src/cols_ref/utils.rs b/crates/circuits/primitives/derive/src/cols_ref/utils.rs new file mode 100644 index 0000000000..56e7dcd918 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/utils.rs @@ -0,0 +1,102 @@ +use syn::{Expr, ExprPath, Ident, Stmt, Type, TypePath}; + +pub fn is_primitive_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. }) if path.segments.len() == 1 => { + matches!( + path.segments[0].ident.to_string().as_str(), + "u8" | "u16" + | "u32" + | "u64" + | "u128" + | "usize" + | "i8" + | "i16" + | "i32" + | "i64" + | "i128" + | "isize" + | "f32" + | "f64" + | "bool" + | "char" + ) + } + _ => false, + } +} + +// Type of array dimension +pub enum Dimension { + ConstGeneric(Expr), + Other(Expr), +} + +// Describes a nested array +pub struct ArrayInfo { + pub dims: Vec, + pub elem_type: Type, +} + +pub fn get_array_info(ty: &Type, const_generics: &[&Ident]) -> ArrayInfo { + let dims = get_dims(ty, const_generics); + let elem_type = get_elem_type(ty); + ArrayInfo { dims, elem_type } +} + +fn get_elem_type(ty: &Type) -> Type { + match ty { + Type::Array(array) => get_elem_type(array.elem.as_ref()), + Type::Path(_) => ty.clone(), + _ => panic!("Unsupported type: {ty:?}"), + } +} + +// Get a vector of the dimensions of the array +// Each dimension is either a constant generic or a literal integer value +fn get_dims(ty: &Type, const_generics: &[&Ident]) -> Vec { + get_dims_impl(ty, const_generics) + .into_iter() + .rev() + .collect() +} + +fn get_dims_impl(ty: &Type, const_generics: &[&Ident]) -> Vec { + match ty { + Type::Array(array) => { + let mut dims = get_dims_impl(array.elem.as_ref(), const_generics); + match &array.len { + Expr::Block(syn::ExprBlock { block, .. }) => { + if block.stmts.len() != 1 { + panic!( + "Expected exactly one statement in block, got: {:?}", + block.stmts.len() + ); + } + if let Stmt::Expr(Expr::Path(expr_path), ..) = &block.stmts[0] { + if let Some(len_ident) = expr_path.path.get_ident() { + if const_generics.contains(&len_ident) { + dims.push(Dimension::ConstGeneric(expr_path.clone().into())); + } else { + dims.push(Dimension::Other(expr_path.clone().into())); + } + } + } + } + Expr::Path(ExprPath { path, .. }) => { + let len_ident = path.get_ident(); + if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) { + dims.push(Dimension::ConstGeneric(array.len.clone())); + } else { + dims.push(Dimension::Other(array.len.clone())); + } + } + Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())), + _ => panic!("Unsupported array length type: {:?}", array.len), + } + dims + } + Type::Path(_) => Vec::new(), + _ => panic!("Unsupported field type (in get_dims_impl)"), + } +} diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 35e5f8fd5b..70db5dc672 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -7,6 +7,9 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta}; +mod cols_ref; +use cols_ref::cols_ref_impl; + #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); @@ -426,3 +429,25 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream { _ => unimplemented!(), } } + +#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))] +pub fn cols_ref_derive(input: TokenStream) -> TokenStream { + let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput); + + let config = derive_input + .attrs + .iter() + .find(|attr| attr.path().is_ident("config")); + if config.is_none() { + return syn::Error::new(derive_input.ident.span(), "Config attribute is required") + .to_compile_error() + .into(); + } + let config: proc_macro2::Ident = config + .unwrap() + .parse_args() + .expect("Failed to parse config"); + + let res = cols_ref_impl(derive_input, config); + res.into() +} diff --git a/crates/circuits/primitives/derive/tests/example.rs b/crates/circuits/primitives/derive/tests/example.rs new file mode 100644 index 0000000000..58bac9e26c --- /dev/null +++ b/crates/circuits/primitives/derive/tests/example.rs @@ -0,0 +1,87 @@ +use openvm_circuit_primitives_derive::ColsRef; + +pub trait ExampleConfig { + const N: usize; +} +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} + +#[allow(dead_code)] +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} + +#[test] +fn example() { + let input = [1, 2, 3, 4, 5, 15]; + let test: ExampleColsRef = ExampleColsRef::from::(&input); + println!("{}, {}", test.arr, test.sum); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct ExampleColsRef<'a, T> { + pub arr: ndarray::ArrayView1<'a, T>, + pub sum: &'a T, +} + +impl<'a, T> ExampleColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let (arr_slice, slice) = slice.split_at(1 * C::N); + let arr_slice = ndarray::ArrayView1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at(sum_length); + Self { + arr: arr_slice, + sum: &sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} + +impl<'b, T> ExampleColsRef<'b, T> { + pub fn from_mut<'a, C: ExampleConfig>(other: &'b ExampleColsRefMut<'a, T>) -> Self { + Self { + arr: other.arr.view(), + sum: &other.sum, + } + } +} + +#[derive(Debug)] +struct ExampleColsRefMut<'a, T> { + pub arr: ndarray::ArrayViewMut1<'a, T>, + pub sum: &'a mut T, +} + +impl<'a, T> ExampleColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let (arr_slice, slice) = slice.split_at_mut(1 * C::N); + let arr_slice = ndarray::ArrayViewMut1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at_mut(sum_length); + Self { + arr: arr_slice, + sum: &mut sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} +*/ diff --git a/crates/circuits/primitives/derive/tests/test_cols_ref.rs b/crates/circuits/primitives/derive/tests/test_cols_ref.rs new file mode 100644 index 0000000000..caaa3bcec9 --- /dev/null +++ b/crates/circuits/primitives/derive/tests/test_cols_ref.rs @@ -0,0 +1,299 @@ +use openvm_circuit_primitives_derive::{AlignedBorrow, ColsRef}; + +pub trait TestConfig { + const N: usize; + const M: usize; +} +pub struct TestConfigImpl; +impl TestConfig for TestConfigImpl { + const N: usize = 5; + const M: usize = 2; +} + +#[allow(dead_code)] // TestCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef)] +#[config(TestConfig)] +struct TestCols { + single_field_element: T, + array_of_t: [T; N], + nested_array_of_t: [[T; N]; N], + cols_struct: TestSubCols, + #[aligned_borrow] + array_of_aligned_borrow: [TestAlignedBorrow; N], + #[aligned_borrow] + nested_array_of_aligned_borrow: [[TestAlignedBorrow; N]; N], +} + +#[allow(dead_code)] // TestSubCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef, Debug)] +#[config(TestConfig)] +struct TestSubCols { + // TestSubCols can have fields of any type that TestCols can have + a: T, + b: [T; M], + #[aligned_borrow] + c: TestAlignedBorrow, +} + +#[derive(AlignedBorrow, Debug)] +struct TestAlignedBorrow { + a: T, + b: [T; 5], +} + +#[test] +fn test_cols_ref() { + assert_eq!( + TestColsRef::::width::(), + TestColsRefMut::::width::() + ); + const WIDTH: usize = TestColsRef::::width::(); + let mut input = vec![0; WIDTH]; + let mut cols: TestColsRefMut = TestColsRefMut::from::(&mut input); + + *cols.single_field_element = 1; + cols.array_of_t[0] = 2; + cols.nested_array_of_t[[0, 0]] = 3; + *cols.cols_struct.a = 4; + cols.cols_struct.b[0] = 5; + cols.cols_struct.c.a = 6; + cols.cols_struct.c.b[0] = 7; + cols.array_of_aligned_borrow[0].a = 8; + cols.array_of_aligned_borrow[0].b[0] = 9; + cols.nested_array_of_aligned_borrow[[0, 0]].a = 10; + cols.nested_array_of_aligned_borrow[[0, 0]].b[0] = 11; + + let cols: TestColsRef = TestColsRef::from::(&input); + println!("{cols:?}"); + assert_eq!(*cols.single_field_element, 1); + assert_eq!(cols.array_of_t[0], 2); + assert_eq!(cols.nested_array_of_t[[0, 0]], 3); + assert_eq!(*cols.cols_struct.a, 4); + assert_eq!(cols.cols_struct.b[0], 5); + assert_eq!(cols.cols_struct.c.a, 6); + assert_eq!(cols.cols_struct.c.b[0], 7); + assert_eq!(cols.array_of_aligned_borrow[0].a, 8); + assert_eq!(cols.array_of_aligned_borrow[0].b[0], 9); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].a, 10); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].b[0], 11); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct TestColsRef<'a, T> { + pub single_field_element: &'a T, + pub array_of_t: ndarray::ArrayView1<'a, T>, + pub nested_array_of_t: ndarray::ArrayView2<'a, T>, + pub cols_struct: TestSubColsRef<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayView1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayView2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at(1 * C::N); + let array_of_t_slice = ndarray::ArrayView1::from_shape((C::N), array_of_t_slice) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N); + let array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayView1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(nested_array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +impl<'b, T> TestColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestColsRefMut<'a, T>) -> Self { + Self { + single_field_element: &other.single_field_element, + array_of_t: other.array_of_t.view(), + nested_array_of_t: other.nested_array_of_t.view(), + cols_struct: >::from_mut::(&other.cols_struct), + array_of_aligned_borrow: other.array_of_aligned_borrow.view(), + nested_array_of_aligned_borrow: other.nested_array_of_aligned_borrow.view(), + } + } +} + +#[derive(Debug)] +struct TestColsRefMut<'a, T> { + pub single_field_element: &'a mut T, + pub array_of_t: ndarray::ArrayViewMut1<'a, T>, + pub nested_array_of_t: ndarray::ArrayViewMut2<'a, T>, + pub cols_struct: TestSubColsRefMut<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayViewMut1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayViewMut2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at_mut(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at_mut(1 * C::N); + let array_of_t_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_t_slice, + ) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at_mut(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at_mut(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N); + let array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(nested_array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &mut single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +#[derive(Debug, Clone)] +struct TestSubColsRef<'a, T> { + pub a: &'a T, + pub b: ndarray::ArrayView1<'a, T>, + pub c: &'a TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at(a_length); + let (b_slice, slice) = slice.split_at(1 * C::M); + let b_slice = ndarray::ArrayView1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at(c_length); + Self { + a: &a_slice[0], + b: b_slice, + c: { + use core::borrow::Borrow; + c_slice.borrow() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} + +impl<'b, T> TestSubColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestSubColsRefMut<'a, T>) -> Self { + Self { + a: &other.a, + b: other.b.view(), + c: other.c, + } + } +} + +#[derive(Debug)] +struct TestSubColsRefMut<'a, T> { + pub a: &'a mut T, + pub b: ndarray::ArrayViewMut1<'a, T>, + pub c: &'a mut TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at_mut(a_length); + let (b_slice, slice) = slice.split_at_mut(1 * C::M); + let b_slice = ndarray::ArrayViewMut1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at_mut(c_length); + Self { + a: &mut a_slice[0], + b: b_slice, + c: { + use core::borrow::BorrowMut; + c_slice.borrow_mut() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} +*/ diff --git a/crates/circuits/sha256-air/Cargo.toml b/crates/circuits/sha2-air/Cargo.toml similarity index 50% rename from crates/circuits/sha256-air/Cargo.toml rename to crates/circuits/sha2-air/Cargo.toml index c376a1ffdd..ff8c1c9006 100644 --- a/crates/circuits/sha256-air/Cargo.toml +++ b/crates/circuits/sha2-air/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version.workspace = true authors.workspace = true edition.workspace = true @@ -7,13 +7,13 @@ edition.workspace = true [dependencies] openvm-circuit-primitives = { workspace = true } openvm-stark-backend = { workspace = true } -sha2 = { version = "0.10", features = ["compress"] } -rand.workspace = true +openvm-circuit-primitives-derive = { workspace = true } -[dev-dependencies] -openvm-stark-sdk = { workspace = true } -openvm-circuit = { workspace = true, features = ["test-utils"] } +sha2 = { workspace = true, features = ["compress"] } +rand.workspace = true +ndarray.workspace = true +num_enum = { workspace = true } [features] default = ["parallel"] -parallel = ["openvm-stark-backend/parallel"] +parallel = ["openvm-stark-backend/parallel"] \ No newline at end of file diff --git a/crates/circuits/sha2-air/SOUNDNESS.md b/crates/circuits/sha2-air/SOUNDNESS.md new file mode 100644 index 0000000000..d2bb318c1d --- /dev/null +++ b/crates/circuits/sha2-air/SOUNDNESS.md @@ -0,0 +1,209 @@ +# Justification of Soundness + +The soundness of `Sha2BlockHasherSubAir`'s constraints is not obvious. +This document aims to make it clearer. + +## Summary of Constraints + +The main constraints are summarized below. + +### Constraint 1: Digest Row Hash Computation + +In `eval_digest_row()` on lines 148-185, we constrain: + +``` +next.prev_hash + local.work_vars = next.final_hash +``` + +when `next` is a digest row. + +### Constraint 2: Dummy Row Consistency + +In `eval_transitions()` on lines 293-304, we constrain: + +``` +local.work_vars.a == next.work_vars.a +local.work_vars.e == next.work_vars.e +``` + +when `next` is a dummy row. + +This ensures that all dummy rows have the same values in `work_vars.a` and `work_vars.e`, and moreover that these values match the last digest row's `hash` field. (Since the `hash` field on digest rows is the same as `work_vars` on round rows.) + +### Constraint 3: Block Chaining + +In `eval_prev_hash()`, we constrain, via an interaction on digest rows, that: + +``` +curr_block.digest_row.hash == next_block.digest_row.prev_hash +``` + +That is, the next block's digest row's `prev_hash` field is equal to the current block's digest row's `hash` field. + +On the last block, this constraint wraps around, and constrains that: + +``` +last_block.digest_row.hash == first_block.digest_row.prev_hash +``` + +### Constraint 4: Work Variables Update + +In `eval_work_vars()`, we constrain: + +``` +constraint_word_addition(local, next) +``` + +on **all** rows. + +We constrain this on all rows because the constraint degree is already too high to narrow down the rows on which to enforce this constraint. On round rows, this constraint ensures the work vars are updated correctly. + +However, on other rows, even though the constraint doesn't constrain anything meaningful, we still need to ensure that the constraint passes. In order to do this, we fill in certain fields on certain rows with values that satisfy the constraint. + +In particular: +- When `next` is a digest row, we fill in `next.work_vars.carry_a` and `next.work_vars.carry_e` with slack values. +- When `next` is a dummy row, we also fill in `next.work_vars.carry_a` and `next.work_vars.carry_e` with slack values. However, in this case, all these values will be the same on all the dummy rows (due to constraint 2), so we compute them once (on the first dummy row) and copy them into all the other dummy rows. + + +## Soundness + +We will show that the four constraints above imply that the hash of each block is computed correctly. + +We will walk through the justification with an example trace consisting of three blocks and three dummy rows. The argument generalizes readily to traces with more blocks or dummy rows. + +### Example Trace + +Suppose our trace looks like this: + +``` +Block 1: + Round rows (64 rounds): + first round row: work_vars + ... + last round row: work_vars + + Digest row: + hash: [3a] + final_hash: (computed from prev_hash + work_vars) + prev_hash: [3c] + carry_a/e: (slack values) + +Block 2: + Round rows (64 rounds): + first round row: work_vars + ... + last round row: work_vars + Digest row: + hash: [3b] + final_hash: (computed from prev_hash + work_vars) + prev_hash: [3a] + carry_a/e: (slack values) + +Block 3: + Round rows (64 rounds): + first round row: work_vars + ... + last round row: work_vars + + Digest row: + hash: [3c][2a] + final_hash: (computed from prev_hash + work_vars) + prev_hash: [3b] + carry_a/e: (slack values) + +Dummy rows: + Dummy row 1: + work_vars: [2a][2b] + carry_a/e: (slack values) + + Dummy row 2: + work_vars: [2b][2c] + carry_a/e: (slack values) + + Dummy row 3: + work_vars: [2c] + carry_a/e: (slack values) +``` + +**Legend:** +- Square brackets `[X]` indicate fields affected by constraint X (e.g., `[3a]` = affected by constraint 3a) +- Multiple annotations on a field (e.g., `[3c][2a]`) indicate it's affected by multiple constraints + +### Constraint Applications + +**Constraint 1** gives: + +``` +block1.digest_row.prev_hash + block1.last_round_row.work_vars == block1.digest_row.final_hash +block2.digest_row.prev_hash + block2.last_round_row.work_vars == block2.digest_row.final_hash +block3.digest_row.prev_hash + block3.last_round_row.work_vars == block3.digest_row.final_hash +``` + +**Constraint 2** gives: + +``` +[2a] block3.digest_row.hash == dummy_row_1.work_vars +[2b] dummy_row_1.work_vars == dummy_row_2.work_vars +[2c] dummy_row_2.work_vars == dummy_row_3.work_vars +``` + +**Constraint 3** gives: + +``` +[3a] block1.digest_row.hash == block2.digest_row.prev_hash +[3b] block2.digest_row.hash == block3.digest_row.prev_hash +[3c] block3.digest_row.hash == block1.digest_row.prev_hash +``` + +### Constraining Rounds + +First, we claim that all 64 rounds for each block are constrained correctly. That is, we claim that the `work_vars` in each of these rounds are updated correctly. + +**Constraint 4** gives that rounds 5 to 64 inclusive of each block are constrained correctly, since these rounds occur when `local` and `next` are in the same block. + +For the first four rounds of each block, we must examine the case when `next` is the first row of a block. There are two subcases: either `local` is a digest row, or a dummy row. + +#### Case 1: `local` is a digest row + +If `local` is a digest row, for example when `local` is the digest row of block 1 (and so `next` is the first row of block 2), then by constraint 3b, we have: + +``` +local.hash == block3.digest_row.prev_hash +``` + +So, the `constraint_word_addition(local, next)` constrains that `next.work_vars` is updated from `local.work_vars` (i.e. `local.hash`) by consuming the first 4 words of input. Since `local.work_vars` stores the `prev_hash` of block 3, this constraint ensures that the first 4 rounds are constrained correctly. + +A similar argument works for when `local` is the digest row of block2. + +#### Case 2: `local` is a dummy row + +If `local` is a dummy row, then it is the last dummy row, and `next` is the first row of block 1. In this case, constraint 2 gives us that: + +``` +local.work_vars == dummy_row_3.work_vars + == dummy_row_2.work_vars + == dummy_row_1.work_vars + == block3.digest_row.hash +``` + +Then constraint 3c gives: + +``` +block3.digest_row.hash == block1.digest_row.prev_hash +``` + +which overall gives: + +``` +local.work_vars == block1.digest_row.prev_hash +``` + +So, the first 4 rounds of block 1 are constrained correctly. + +**Conclusion:** All rounds are constrained correctly. + +### Final Hash + +Now, we can show that each block's `final_hash` is correct. + +We already argued that, on each block, `last_round_row.work_vars` is correctly computed from the `prev_hash`. Now, **constraint 1** gives that each block's `final_hash` is constructed by adding the `last_round_row.work_vars` to the `prev_state`. This is exactly correct, by the SHA-2 specification. diff --git a/crates/circuits/sha2-air/src/air.rs b/crates/circuits/sha2-air/src/air.rs new file mode 100644 index 0000000000..db5270b8d2 --- /dev/null +++ b/crates/circuits/sha2-air/src/air.rs @@ -0,0 +1,617 @@ +use std::{iter::once, marker::PhantomData}; + +use ndarray::s; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::select, SubAir, +}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, +}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, + small_sig1_field, +}; +use crate::{ + constraint_word_addition, word_into_u16_limbs, Sha2BlockHasherSubairConfig, Sha2DigestColsRef, + Sha2RoundColsRef, +}; + +/// Expects the message to be padded to a multiple of C::BLOCK_WORDS * C::WORD_BITS bits +#[derive(Clone, Debug)] +pub struct Sha2BlockHasherSubAir { + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub row_idx_encoder: Encoder, + /// Internal bus for self-interactions in this AIR. + pub private_bus: PermutationCheckBus, + _phantom: PhantomData, +} + +impl Sha2BlockHasherSubAir { + pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, private_bus_idx: BusIndex) -> Self { + Self { + bitwise_lookup_bus, + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), /* + 1 for dummy + * (padding) rows */ + private_bus: PermutationCheckBus::new(private_bus_idx), + _phantom: PhantomData, + } + } +} + +impl BaseAir for Sha2BlockHasherSubAir { + fn width(&self) -> usize { + C::SUBAIR_WIDTH + } +} + +impl SubAir + for Sha2BlockHasherSubAir +{ + /// The start column for the sub-air to use + type AirContext<'a> + = usize + where + Self: 'a, + AB: 'a, + ::Var: 'a, + ::Expr: 'a; + + fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) + where + AB::Var: 'a, + AB::Expr: 'a, + { + self.eval_row(builder, start_col); + self.eval_transitions(builder, start_col); + } +} + +impl Sha2BlockHasherSubAir { + /// Implements the single row constraints (i.e. imposes constraints only on local) + /// Implements some sanity constraints on the row index, flags, and work variables + fn eval_row(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + + // Doesn't matter which column struct we use here as we are only interested in the common + // columns + let local_cols: Sha2DigestColsRef = + Sha2DigestColsRef::from::(&local[start_col..start_col + C::SUBAIR_DIGEST_WIDTH]); + let flags = &local_cols.flags; + builder.assert_bool(*flags.is_round_row); + builder.assert_bool(*flags.is_first_4_rows); + builder.assert_bool(*flags.is_digest_row); + builder.assert_bool(*flags.is_round_row + *flags.is_digest_row); + + self.row_idx_encoder + .eval(builder, local_cols.flags.row_idx.to_slice().unwrap()); + // assert all row indices are in [0, C::ROWS_PER_BLOCK] + builder.assert_one(self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROWS_PER_BLOCK, + )); + // assert that the row indices are [0, 3] for the first 4 rows + builder.assert_eq( + self.row_idx_encoder + .contains_flag_range::(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3), + *flags.is_first_4_rows, + ); + // round row indices are in [0, C::ROUND_ROWS - 1] + builder.assert_eq( + self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROUND_ROWS - 1, + ), + *flags.is_round_row, + ); + // digest row always has row index C::ROUND_ROWS + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROUND_ROWS], + ), + *flags.is_digest_row, + ); + // If padding row we want the row_idx to be C::ROWS_PER_BLOCK + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROWS_PER_BLOCK], + ), + flags.is_padding_row(), + ); + + // Constrain a, e, being composed of bits: we make sure a and e are always in the same place + // in the trace matrix Note: this has to be true for every row, even padding rows + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_BITS { + builder.assert_bool(local_cols.hash.a[[i, j]]); + builder.assert_bool(local_cols.hash.e[[i, j]]); + } + } + } + + /// Evaluates the final hash for this block + fn eval_digest_row( + &self, + builder: &mut AB, + local: Sha2RoundColsRef, + next: Sha2DigestColsRef, + ) { + // Assert that the previous hash + work vars == final hash. + // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` + // where addition is done modulo 2^32 + for i in 0..C::HASH_WORDS { + let mut carry = AB::Expr::ZERO; + for j in 0..C::WORD_U16S { + let work_var_limb = if i < C::ROUNDS_PER_ROW { + compose::( + local + .work_vars + .a + .slice(s![C::ROUNDS_PER_ROW - 1 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + } else { + compose::( + local + .work_vars + .e + .slice(s![C::ROUNDS_PER_ROW + 3 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + }; + let final_hash_limb = compose::( + next.final_hash + .slice(s![i, j * 2..(j + 1) * 2]) + .as_slice() + .unwrap(), + 8, + ); + + carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) + * (next.prev_hash[[i, j]] + work_var_limb + carry - final_hash_limb); + builder + .when(*next.flags.is_digest_row) + .assert_bool(carry.clone()); + } + // constrain the final hash limbs two at a time since we can do two checks per + // interaction + for chunk in next.final_hash.row(i).as_slice().unwrap().chunks(2) { + self.bitwise_lookup_bus + .send_range(chunk[0], chunk[1]) + .eval(builder, *next.flags.is_digest_row); + } + } + } + + fn eval_transitions(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // Doesn't matter what column structs we use here + let local_cols: Sha2RoundColsRef = + Sha2RoundColsRef::from::(&local[start_col..start_col + C::SUBAIR_ROUND_WIDTH]); + let next_cols: Sha2RoundColsRef = + Sha2RoundColsRef::from::(&next[start_col..start_col + C::SUBAIR_ROUND_WIDTH]); + + let local_is_padding_row = local_cols.flags.is_padding_row(); + // Note that there will always be a padding row in the trace since the unpadded height is a + // multiple of 17 (SHA-256) or 21 (SHA-512, SHA-384). So the next row is padding iff the + // current block is the last block in the trace. + let next_is_padding_row = next_cols.flags.is_padding_row(); + + // If we are in a round row, the next row cannot be a padding row + builder + .when(*local_cols.flags.is_round_row) + .assert_zero(next_is_padding_row.clone()); + // The first row must be a round row + builder + .when_first_row() + .assert_one(*local_cols.flags.is_round_row); + // If we are in a padding row, the next row must also be a padding row + builder + .when_transition() + .when(local_is_padding_row.clone()) + .assert_one(next_is_padding_row.clone()); + // If we are in a digest row, the next row cannot be a digest row + builder + .when(*local_cols.flags.is_digest_row) + .assert_zero(*next_cols.flags.is_digest_row); + // Constrain how much the row index changes by + // round->round: 1 + // round->digest: 1 + // digest->round: -C::ROUND_ROWS + // digest->padding: 1 + // padding->padding: 0 + // Other transitions are not allowed by the above constraints + let delta = *local_cols.flags.is_round_row * AB::Expr::ONE + + *local_cols.flags.is_digest_row + * *next_cols.flags.is_round_row + * AB::Expr::from_canonical_usize(C::ROUND_ROWS) + * AB::Expr::NEG_ONE + + *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; + + let local_row_idx = self.row_idx_encoder.flag_with_val::( + local_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + let next_row_idx = self.row_idx_encoder.flag_with_val::( + next_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + + builder + .when_transition() + .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); + builder.when_first_row().assert_zero(local_row_idx); + + // Constrain the global block index starting with 1 so it is not the same as the padding + // rows. + // We set the global block index to 0 for padding rows + + // Global block index is 1 on first row + builder + .when_first_row() + .assert_one(*local_cols.flags.global_block_idx); + + // Global block index is constant on all rows in a block + builder.when(*local_cols.flags.is_round_row).assert_eq( + *local_cols.flags.global_block_idx, + *next_cols.flags.global_block_idx, + ); + // Global block index increases by 1 between blocks + builder + .when_transition() + .when(*local_cols.flags.is_digest_row) + .when(*next_cols.flags.is_round_row) + .assert_eq( + *local_cols.flags.global_block_idx + AB::Expr::ONE, + *next_cols.flags.global_block_idx, + ); + // Global block index is 0 on padding rows + builder + .when(local_is_padding_row.clone()) + .assert_zero(*local_cols.flags.global_block_idx); + + // Constrain that all the padding rows have the same work vars as the last block's digest + // row. We constrain elsewhere that the last block's digest row is equal to the first + // block's prev_hash. Together, this ensures that all the padding rows have the same + // work vars as the first block's prev_hash. As a result, the + // constraint_word_addition constraints in eval_work_vars() on the first row of the first + // block (i.e. when next = first_row) will correctly constrain the first 4 rounds of + // the first block. + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_BITS { + builder.when(next_cols.flags.is_padding_row()).assert_eq( + local_cols.work_vars.a[[i, j]], + next_cols.work_vars.a[[i, j]], + ); + builder.when(next_cols.flags.is_padding_row()).assert_eq( + local_cols.work_vars.e[[i, j]], + next_cols.work_vars.e[[i, j]], + ); + } + } + + self.eval_message_schedule(builder, local_cols.clone(), next_cols.clone()); + self.eval_work_vars(builder, local_cols.clone(), next_cols); + let next: Sha2DigestColsRef = + Sha2DigestColsRef::from::(&next[start_col..start_col + C::SUBAIR_DIGEST_WIDTH]); + self.eval_digest_row(builder, local_cols, next); + let local_cols: Sha2DigestColsRef = + Sha2DigestColsRef::from::(&local[start_col..start_col + C::SUBAIR_DIGEST_WIDTH]); + self.eval_prev_hash(builder, local_cols, next_is_padding_row); + } + + /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` + /// Note: the constraining is done by interactions with the chip itself on every digest row + fn eval_prev_hash( + &self, + builder: &mut AB, + local: Sha2DigestColsRef, + is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, + * not the last block of the message */ + ) { + // Constrain that next block's `prev_hash` is equal to the current block's `hash` + let composed_hash = (0..C::HASH_WORDS) + .map(|i| { + let hash_bits = if i < C::ROUNDS_PER_ROW { + local + .hash + .a + .row(C::ROUNDS_PER_ROW - 1 - i) + .mapv(|x| x.into()) + .to_vec() + } else { + local + .hash + .e + .row(C::ROUNDS_PER_ROW + 3 - i) + .mapv(|x| x.into()) + .to_vec() + }; + (0..C::WORD_U16S) + .map(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) + .collect::>() + }) + .collect::>(); + // Need to handle the case if this is the very last block of the trace matrix + let next_global_block_idx = select( + is_last_block_of_trace, + AB::Expr::ONE, + *local.flags.global_block_idx + AB::Expr::ONE, + ); + // The following interactions constrain certain values from block to block + self.private_bus.send( + builder, + composed_hash + .into_iter() + .flatten() + .chain(once(next_global_block_idx)), + *local.flags.is_digest_row, + ); + + self.private_bus.receive( + builder, + local + .prev_hash + .flatten() + .mapv(|x| x.into()) + .into_iter() + .chain(once((*local.flags.global_block_idx).into())), + *local.flags.is_digest_row, + ); + } + + /// Constrain the message schedule additions for `next` row + /// Note: For every addition we need to constrain the following for each of [WORD_U16S] limbs + /// sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + carry_w[t][i-1] - + /// carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_message_schedule<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: Sha2RoundColsRef<'a, AB::Var>, + next: Sha2RoundColsRef<'a, AB::Var>, + ) { + // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx + let w = ndarray::concatenate( + ndarray::Axis(0), + &[local.message_schedule.w, next.message_schedule.w], + ) + .unwrap(); + + // Constrain `w_3` for `next` row + for i in 0..C::ROUNDS_PER_ROW - 1 { + // here we constrain the w_3 of the i_th word of the next row + // w_3 of next is w[i+4-3] = w[i+1] + let w_3 = w.row(i + 1).mapv(|x| x.into()).to_vec(); + let expected_w_3 = next.schedule_helper.w_3.row(i); + for j in 0..C::WORD_U16S { + let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); + builder + .when(*local.flags.is_round_row) + .assert_eq(w_3_limb, expected_w_3[j].into()); + } + } + + // Constrain intermed for `next` row + // We will only constrain intermed_12 for rows [3, C::ROUND_ROWS - 2], and let it + // unconstrained for other rows Other rows should put the needed value in + // intermed_12 to make the below summation constraint hold + let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 3..=C::ROUND_ROWS - 2, + ); + // We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 3], and let it + // unconstrained for other rows + let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 2..=C::ROUND_ROWS - 3, + ); + for i in 0..C::ROUNDS_PER_ROW { + // w_idx + let w_idx = w.row(i).mapv(|x| x.into()).to_vec(); + // sig_0(w_{idx+1}) + let sig_w = small_sig0_field::(w.row(i + 1).as_slice().unwrap()); + for j in 0..C::WORD_U16S { + let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); + let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); + + // We would like to constrain this only on round rows, but we can't do a conditional + // check because the degree is already 3. So we must fill in `intermed_4` with dummy + // values on the first round row and the digest row (rows 0 and 16 for SHA-256) to + // ensure the constraint holds on these rows. + builder.when_transition().assert_eq( + next.schedule_helper.intermed_4[[i, j]], + w_idx_limb + sig_w_limb, + ); + + builder.when(is_row_intermed_8.clone()).assert_eq( + next.schedule_helper.intermed_8[[i, j]], + local.schedule_helper.intermed_4[[i, j]], + ); + + builder.when(is_row_intermed_12.clone()).assert_eq( + next.schedule_helper.intermed_12[[i, j]], + local.schedule_helper.intermed_8[[i, j]], + ); + } + } + + // Constrain the message schedule additions for `next` row + for i in 0..C::ROUNDS_PER_ROW { + // Note, here by w_{t} we mean the i_th word of the `next` row + // w_{t-7} + let w_7 = if i < 3 { + local.schedule_helper.w_3.row(i).mapv(|x| x.into()).to_vec() + } else { + let w_3 = w.row(i - 3).mapv(|x| x.into()).to_vec(); + (0..C::WORD_U16S) + .map(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) + .collect::>() + }; + // sig_0(w_{t-15}) + w_{t-16} + let intermed_16 = local.schedule_helper.intermed_12.row(i).mapv(|x| x.into()); + + let carries = (0..C::WORD_U16S) + .map(|j| { + next.message_schedule.carry_or_buffer[[i, j * 2]] + + AB::Expr::TWO * next.message_schedule.carry_or_buffer[[i, j * 2 + 1]] + }) + .collect::>(); + + // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` + // We would like to constrain this only on rows 4..C::ROUND_ROWS, but we can't do a + // conditional check because the degree of sum is already 3 So we must fill + // in `intermed_12` with dummy values on rows 0..3 and C::ROUND_ROWS-1 and C::ROUND_ROWS + // to ensure the constraint holds on rows 0..4 and C::ROUND_ROWS. Note that + // the dummy value goes in the previous row to make the current row's constraint hold. + constraint_word_addition::<_, C>( + // Note: here we can't do a conditional check because the degree of sum is already + // 3 + &mut builder.when_transition(), + &[&small_sig1_field::( + w.row(i + 2).as_slice().unwrap(), + )], + &[&w_7, intermed_16.as_slice().unwrap()], + w.row(i + 4).as_slice().unwrap(), + &carries, + ); + + for j in 0..C::WORD_U16S { + // When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1 + let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows; + builder + .when(is_row_4_or_more.clone()) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]); + builder + .when(is_row_4_or_more) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]); + } + // Constrain w being composed of bits + for j in 0..C::WORD_BITS { + builder + .when(*next.flags.is_round_row) + .assert_bool(next.message_schedule.w[[i, j]]); + } + } + } + + /// Constrain the work vars on `next` row according to the sha documentation + /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_work_vars<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: Sha2RoundColsRef<'a, AB::Var>, + next: Sha2RoundColsRef<'a, AB::Var>, + ) { + let a = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.a, next.work_vars.a]).unwrap(); + let e = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.e, next.work_vars.e]).unwrap(); + + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_U16S { + // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in + // [0, 2^8) is enough to prevent overflow and ensure the soundness + // of the addition we want to check + self.bitwise_lookup_bus + .send_range( + local.work_vars.carry_a[[i, j]], + local.work_vars.carry_e[[i, j]], + ) + .eval(builder, *local.flags.is_round_row); + } + + let w_limbs = (0..C::WORD_U16S) + .map(|j| { + compose::( + next.message_schedule + .w + .slice(s![i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) * *next.flags.is_round_row + }) + .collect::>(); + + let k_limbs = (0..C::WORD_U16S) + .map(|j| { + self.row_idx_encoder.flag_with_val::( + next.flags.row_idx.to_slice().unwrap(), + &(0..C::ROUND_ROWS) + .map(|rw_idx| { + ( + rw_idx, + word_into_u16_limbs::( + C::get_k()[rw_idx * C::ROUNDS_PER_ROW + i], + )[j] as usize, + ) + }) + .collect::>(), + ) + }) + .collect::>(); + + // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_a` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of + previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + &big_sig0_field::(a.row(i + 3).as_slice().unwrap()), /* sig_0 of `a` */ + &maj_field::( + a.row(i + 3).as_slice().unwrap(), + a.row(i + 2).as_slice().unwrap(), + a.row(i + 1).as_slice().unwrap(), + ), /* Maj of previous a, b, c */ + ], + &[&w_limbs, &k_limbs], // K and W + a.row(i + 4).as_slice().unwrap(), // new `a` + next.work_vars.carry_a.row(i).as_slice().unwrap(), // carries of addition + ); + + // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_e` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d` + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of + previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + ], + &[&w_limbs, &k_limbs], // K and W + e.row(i + 4).as_slice().unwrap(), // new `e` + next.work_vars.carry_e.row(i).as_slice().unwrap(), // carries of addition + ); + } + } +} diff --git a/crates/circuits/sha2-air/src/columns.rs b/crates/circuits/sha2-air/src/columns.rs new file mode 100644 index 0000000000..da9936d95e --- /dev/null +++ b/crates/circuits/sha2-air/src/columns.rs @@ -0,0 +1,170 @@ +//! WARNING: the order of fields in the structs is important, do not change it + +use core::ops::Add; + +use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_stark_backend::p3_field::FieldAlgebra; + +use crate::Sha2BlockHasherSubairConfig; + +/// In each SHA block: +/// - First C::ROUND_ROWS rows use Sha2RoundCols +/// - Final row uses Sha2DigestCols +/// +/// Note that for soundness, we require that there is always a padding row after the last digest row +/// in the trace. Right now, this is true because the unpadded height is a multiple of 17 (SHA-256) +/// or 21 (SHA-512), and thus not a power of 2. +/// +/// Sha2RoundCols and Sha2DigestCols share the same first 3 fields: +/// - flags +/// - work_vars/hash (same type, different name) +/// - schedule_helper +/// +/// This design allows for: +/// 1. Common constraints to work on either struct type by accessing these shared fields +/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional +/// constraints +/// +/// Note that the `Sha2WorkVarsCols` field is used for different purposes in the two structs. +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2RoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + pub work_vars: Sha2WorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + pub message_schedule: Sha2MessageScheduleCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2DigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + /// Will serve as previous hash values for the next block + pub hash: Sha2WorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + /// The actual final hash values of the given block + /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block + pub final_hash: [[T; WORD_U8S]; HASH_WORDS], + /// The final hash of the previous block + /// Note: will be constrained using interactions with the chip itself + pub prev_hash: [[T; WORD_U16S]; HASH_WORDS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2MessageScheduleCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U8S: usize, +> { + /// The message schedule words as bits + /// The first 16 words will be the message data + pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be + /// used freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells + /// as individual bits + /// Note: carry_or_buffer is left unconstrained on rounds 0..3 + pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2WorkVarsCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U16S: usize, +> { + /// `a` and `e` after each iteration as 32-bits + pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW], + pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// The carry's used for addition during each iteration when computing `a` and `e` + pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +/// These are the columns that are used to help with the message schedule additions +/// Note: these need to be correctly assigned for every row even on padding rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2MessageHelperCols< + T, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, +> { + /// The following are used to move data forward to constrain the message schedule additions + /// The value of `w` from 3 rounds ago + pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE], + /// Here intermediate(i) = w_i + sig_0(w_{i+1}) + /// Intermed_t represents the intermediate t rounds ago + /// This is needed to constrain the message schedule, since we can only constrain on two rows + /// at a time + pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2FlagsCols { + pub is_round_row: T, + /// A flag that indicates if the current row is among the first 4 rows of a block (the message + /// rows) + pub is_first_4_rows: T, + pub is_digest_row: T, + /// We will encode the row index [0..C::ROWS_PER_BLOCK] using ROW_VAR_CNT cells + pub row_idx: [T; ROW_VAR_CNT], + /// The global index of the current block, starts at 1 for the first block + /// and increments by 1 for each block. Set to 0 for padding rows. + pub global_block_idx: T, +} + +impl Sha2FlagsColsRef<'_, T> +where + T: Add + Copy, +{ + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + *self.is_round_row + *self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } +} diff --git a/crates/circuits/sha2-air/src/config.rs b/crates/circuits/sha2-air/src/config.rs new file mode 100644 index 0000000000..74f087a2b7 --- /dev/null +++ b/crates/circuits/sha2-air/src/config.rs @@ -0,0 +1,323 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr}; + +use crate::{Sha2DigestColsRef, Sha2RoundColsRef}; + +#[repr(u32)] +#[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive, Copy, Clone, Debug)] +pub enum Sha2Variant { + Sha256, + Sha512, + Sha384, +} + +pub trait Sha2BlockHasherSubairConfig: Send + Sync + Clone { + // --- Required --- + + type Word: 'static + + Shr + + Shl + + BitAnd + + Not + + BitXor + + BitOr + + RotateRight + + WrappingAdd + + PartialEq + + From + + TryInto + + Copy + + Send + + Sync; + // Differentiate between the SHA-2 variants + const VARIANT: Sha2Variant; + /// Number of bits in a SHA word + const WORD_BITS: usize; + /// Number of words in a SHA block + const BLOCK_WORDS: usize; + /// Number of rows per block + const ROWS_PER_BLOCK: usize; + /// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK + const ROUNDS_PER_ROW: usize; + /// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW + const ROUNDS_PER_BLOCK: usize; + /// Number of words in a SHA hash + const HASH_WORDS: usize; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize; + + /// We also store the SHA constants K and H + fn get_k() -> &'static [Self::Word]; + fn get_h() -> &'static [Self::Word]; + + // --- Provided --- + + /// Number of 16-bit limbs in a SHA word + const WORD_U16S: usize = Self::WORD_BITS / 16; + /// Number of 8-bit limbs in a SHA word + const WORD_U8S: usize = Self::WORD_BITS / 8; + /// Number of cells in a SHA block + const BLOCK_U8S: usize = Self::BLOCK_WORDS * Self::WORD_U8S; + /// Number of bits in a SHA block + const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS; + /// Number of rows used for the sha rounds + const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW; + /// Number of rows used for the message + const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW; + /// Number of rounds per row minus one (needed for one of the column structs) + const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1; + /// Width of the Sha2RoundCols + const SUBAIR_ROUND_WIDTH: usize = Sha2RoundColsRef::::width::(); + /// Width of the Sha2DigestCols + const SUBAIR_DIGEST_WIDTH: usize = Sha2DigestColsRef::::width::(); + /// Width of the Sha2BlockHasherCols + const SUBAIR_WIDTH: usize = if Self::SUBAIR_ROUND_WIDTH > Self::SUBAIR_DIGEST_WIDTH { + Self::SUBAIR_ROUND_WIDTH + } else { + Self::SUBAIR_DIGEST_WIDTH + }; +} + +#[derive(Clone)] +pub struct Sha256Config; + +#[derive(Clone)] +pub struct Sha512Config; + +#[derive(Clone)] +pub struct Sha384Config; + +impl Sha2BlockHasherSubairConfig for Sha256Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha256; + type Word = u32; + /// Number of bits in a SHA256 word + const WORD_BITS: usize = 32; + /// Number of words in a SHA256 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 17; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 64; + /// Number of words in a SHA256 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 5; + + fn get_k() -> &'static [u32] { + &SHA256_K + } + fn get_h() -> &'static [u32] { + &SHA256_H + } +} + +/// SHA256 constant K's +pub const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +/// SHA256 initial hash values +pub const SHA256_H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +impl Sha2BlockHasherSubairConfig for Sha512Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha512; + type Word = u64; + /// Number of bits in a SHA512 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA512 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA512 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_k() -> &'static [u64] { + &SHA512_K + } + fn get_h() -> &'static [u64] { + &SHA512_H + } +} + +/// SHA512 constant K's +pub const SHA512_K: [u64; 80] = [ + 0x428a2f98d728ae22, + 0x7137449123ef65cd, + 0xb5c0fbcfec4d3b2f, + 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, + 0x59f111f1b605d019, + 0x923f82a4af194f9b, + 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, + 0x12835b0145706fbe, + 0x243185be4ee4b28c, + 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, + 0x80deb1fe3b1696b1, + 0x9bdc06a725c71235, + 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, + 0xefbe4786384f25e3, + 0x0fc19dc68b8cd5b5, + 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, + 0x4a7484aa6ea6e483, + 0x5cb0a9dcbd41fbd4, + 0x76f988da831153b5, + 0x983e5152ee66dfab, + 0xa831c66d2db43210, + 0xb00327c898fb213f, + 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, + 0xd5a79147930aa725, + 0x06ca6351e003826f, + 0x142929670a0e6e70, + 0x27b70a8546d22ffc, + 0x2e1b21385c26c926, + 0x4d2c6dfc5ac42aed, + 0x53380d139d95b3df, + 0x650a73548baf63de, + 0x766a0abb3c77b2a8, + 0x81c2c92e47edaee6, + 0x92722c851482353b, + 0xa2bfe8a14cf10364, + 0xa81a664bbc423001, + 0xc24b8b70d0f89791, + 0xc76c51a30654be30, + 0xd192e819d6ef5218, + 0xd69906245565a910, + 0xf40e35855771202a, + 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, + 0x1e376c085141ab53, + 0x2748774cdf8eeb99, + 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, + 0x4ed8aa4ae3418acb, + 0x5b9cca4f7763e373, + 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, + 0x78a5636f43172f60, + 0x84c87814a1f0ab72, + 0x8cc702081a6439ec, + 0x90befffa23631e28, + 0xa4506cebde82bde9, + 0xbef9a3f7b2c67915, + 0xc67178f2e372532b, + 0xca273eceea26619c, + 0xd186b8c721c0c207, + 0xeada7dd6cde0eb1e, + 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, + 0x0a637dc5a2c898a6, + 0x113f9804bef90dae, + 0x1b710b35131c471b, + 0x28db77f523047d84, + 0x32caab7b40c72493, + 0x3c9ebe0a15c9bebc, + 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, + 0x597f299cfc657e2a, + 0x5fcb6fab3ad6faec, + 0x6c44198c4a475817, +]; +/// SHA512 initial hash values +pub const SHA512_H: [u64; 8] = [ + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +]; + +impl Sha2BlockHasherSubairConfig for Sha384Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha384; + type Word = ::Word; + /// Number of bits in a SHA384 word + const WORD_BITS: usize = ::WORD_BITS; + /// Number of words in a SHA384 block + const BLOCK_WORDS: usize = ::BLOCK_WORDS; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = ::ROWS_PER_BLOCK; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = ::ROUNDS_PER_ROW; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = ::ROUNDS_PER_BLOCK; + /// Number of words in a SHA384 hash + const HASH_WORDS: usize = ::HASH_WORDS; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = ::ROW_VAR_CNT; + + fn get_k() -> &'static [u64] { + &SHA384_K + } + fn get_h() -> &'static [u64] { + &SHA384_H + } +} + +/// SHA384 constant K's +pub const SHA384_K: [u64; 80] = SHA512_K; + +/// SHA384 initial hash values +pub const SHA384_H: [u64; 8] = [ + 0xcbbb9d5dc1059ed8, + 0x629a292a367cd507, + 0x9159015a3070dd17, + 0x152fecd8f70e5939, + 0x67332667ffc00b31, + 0x8eb44a8768581511, + 0xdb0c2e0d64f98fa7, + 0x47b5481dbefa4fa4, +]; + +// Needed to avoid compile errors in utils.rs +// not sure why this doesn't inf loop +pub trait RotateRight { + fn rotate_right(self, n: u32) -> Self; +} +impl RotateRight for u32 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +impl RotateRight for u64 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +pub trait WrappingAdd { + fn wrapping_add(self, n: Self) -> Self; +} +impl WrappingAdd for u32 { + fn wrapping_add(self, n: u32) -> Self { + self.wrapping_add(n) + } +} +impl WrappingAdd for u64 { + fn wrapping_add(self, n: u64) -> Self { + self.wrapping_add(n) + } +} diff --git a/crates/circuits/sha2-air/src/lib.rs b/crates/circuits/sha2-air/src/lib.rs new file mode 100644 index 0000000000..8c86ae4f27 --- /dev/null +++ b/crates/circuits/sha2-air/src/lib.rs @@ -0,0 +1,11 @@ +mod air; +mod columns; +mod config; +mod trace; +mod utils; + +pub use air::*; +pub use columns::*; +pub use config::*; +pub use trace::*; +pub use utils::*; diff --git a/crates/circuits/sha2-air/src/trace.rs b/crates/circuits/sha2-air/src/trace.rs new file mode 100644 index 0000000000..8f03417c7c --- /dev/null +++ b/crates/circuits/sha2-air/src/trace.rs @@ -0,0 +1,663 @@ +use std::{marker::PhantomData, ops::Range}; + +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, utils::compose, +}; +use openvm_stark_backend::p3_field::PrimeField32; +use sha2::{compress256, compress512, digest::generic_array::GenericArray}; + +use crate::{ + big_sig0, big_sig0_field, big_sig1, big_sig1_field, ch, ch_field, get_flag_pt_array, + le_limbs_into_word, maj, maj_field, set_arrayview_from_u32_slice, small_sig0, small_sig0_field, + small_sig1, small_sig1_field, word_into_bits, word_into_u16_limbs, word_into_u8_limbs, + Sha2BlockHasherSubairConfig, Sha2DigestColsRefMut, Sha2RoundColsRef, Sha2RoundColsRefMut, + Sha2Variant, WrappingAdd, +}; + +/// A helper struct for the SHA-2 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha2BlockHasherFillerHelper { + pub row_idx_encoder: Encoder, + _phantom: PhantomData, +} + +impl Default for Sha2BlockHasherFillerHelper { + fn default() -> Self { + Self::new() + } +} + +/// The trace generation of SHA-2 should be done in two passes. +/// The first pass should do `get_block_trace` for every block and generate the invalid rows through +/// `get_default_row` The second pass should go through all the blocks and call +/// `generate_missing_cells` +impl Sha2BlockHasherFillerHelper { + pub fn new() -> Self { + Self { + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), + _phantom: PhantomData, + } + } + + /// This function takes the input_message (padding not handled), the previous hash, + /// and returns the new hash after processing the block input + pub fn get_block_hash(prev_hash: &[C::Word], input: Vec) -> Vec { + debug_assert!(prev_hash.len() == C::HASH_WORDS); + debug_assert!(input.len() == C::BLOCK_U8S); + let mut new_hash: [C::Word; 8] = prev_hash.try_into().unwrap(); + match C::VARIANT { + Sha2Variant::Sha256 => { + let input_array = [*GenericArray::::from_slice( + &input, + )]; + let hash_ptr: &mut [u32; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + compress256(hash_ptr, &input_array); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let hash_ptr: &mut [u64; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + let input_array = [*GenericArray::::from_slice( + &input, + )]; + compress512(hash_ptr, &input_array); + } + } + new_hash.to_vec() + } + + /// This function takes a C::BLOCK_BITS-bit chunk of the input message (padding not handled), + /// the previous hash, a flag indicating if it's the last block, the global block index, the + /// local block index, and the buffer values that will be put in rows 0..4. + /// Will populate the given `trace` with the trace of the block, where the width of the trace is + /// `trace_width` and the starting column for the `Sha2Air` is `trace_start_col`. + /// **Note**: this function only generates some of the required trace. Another pass is required, + /// refer to [`Self::generate_missing_cells`] for details. + #[allow(clippy::too_many_arguments)] + pub fn generate_block_trace( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + input: &[C::Word], + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + prev_hash: &[C::Word], + next_block_prev_hash: &[C::Word], + global_block_idx: u32, + ) { + #[cfg(debug_assertions)] + { + assert!(input.len() == C::BLOCK_WORDS); + assert!(prev_hash.len() == C::HASH_WORDS); + assert!(next_block_prev_hash.len() == C::HASH_WORDS); + assert!(trace_start_col + C::SUBAIR_WIDTH == trace_width); + assert!(trace.len() == trace_width * C::ROWS_PER_BLOCK); + } + + let get_range = |start: usize, len: usize| -> Range { start..start + len }; + let mut message_schedule = vec![C::Word::from(0); C::ROUNDS_PER_BLOCK]; + message_schedule[..input.len()].copy_from_slice(input); + let mut work_vars = prev_hash.to_vec(); + for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { + // do the rounds + if i < C::ROUND_ROWS { + let mut cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut row[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + *cols.flags.is_round_row = F::ONE; + *cols.flags.is_first_4_rows = if i < C::MESSAGE_ROWS { F::ONE } else { F::ZERO }; + *cols.flags.is_digest_row = F::ZERO; + set_arrayview_from_u32_slice( + &mut cols.flags.row_idx, + get_flag_pt_array(&self.row_idx_encoder, i), + ); + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + + // W_idx = M_idx + if i < C::MESSAGE_ROWS { + for j in 0..C::ROUNDS_PER_ROW { + set_arrayview_from_u32_slice( + &mut cols.message_schedule.w.row_mut(j), + word_into_bits::(input[i * C::ROUNDS_PER_ROW + j]), + ); + } + } + // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} + else { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let nums: [C::Word; 4] = [ + small_sig1::(message_schedule[idx - 2]), + message_schedule[idx - 7], + small_sig0::(message_schedule[idx - 15]), + message_schedule[idx - 16], + ]; + let w: C::Word = nums + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + set_arrayview_from_u32_slice( + &mut cols.message_schedule.w.row_mut(j), + word_into_bits::(w), + ); + let nums_limbs = nums + .iter() + .map(|x| word_into_u16_limbs::(*x)) + .collect::>(); + let w_limbs = word_into_u16_limbs::(w); + + // fill in the carrys + for k in 0..C::WORD_U16S { + let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); + if k > 0 { + sum += (cols.message_schedule.carry_or_buffer[[j, k * 2 - 2]] + + F::TWO + * cols.message_schedule.carry_or_buffer[[j, k * 2 - 1]]) + .as_canonical_u32(); + } + let carry = (sum - w_limbs[k]) >> 16; + cols.message_schedule.carry_or_buffer[[j, k * 2]] = + F::from_canonical_u32(carry & 1); + cols.message_schedule.carry_or_buffer[[j, k * 2 + 1]] = + F::from_canonical_u32(carry >> 1); + } + // update the message schedule + message_schedule[idx] = w; + } + } + // fill in the work variables + for j in 0..C::ROUNDS_PER_ROW { + // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx + let t1 = [ + work_vars[7], + big_sig1::(work_vars[4]), + ch::(work_vars[4], work_vars[5], work_vars[6]), + C::get_k()[i * C::ROUNDS_PER_ROW + j], + le_limbs_into_word::( + cols.message_schedule + .w + .row(j) + .map(|f| f.as_canonical_u32()) + .as_slice() + .unwrap(), + ), + ]; + let t1_sum: C::Word = t1 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // t2 = SIG0(a) + maj(a, b, c) + let t2 = [ + big_sig0::(work_vars[0]), + maj::(work_vars[0], work_vars[1], work_vars[2]), + ]; + + let t2_sum: C::Word = t2 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // e = d + t1 + let e = work_vars[3].wrapping_add(t1_sum); + set_arrayview_from_u32_slice( + &mut cols.work_vars.e.row_mut(j), + word_into_bits::(e), + ); + let e_limbs = word_into_u16_limbs::(e); + // a = t1 + t2 + let a = t1_sum.wrapping_add(t2_sum); + set_arrayview_from_u32_slice( + &mut cols.work_vars.a.row_mut(j), + word_into_bits::(a), + ); + let a_limbs = word_into_u16_limbs::(a); + // fill in the carrys + for k in 0..C::WORD_U16S { + let t1_limb = t1 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + let t2_limb = t2 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + + let mut e_limb = t1_limb + word_into_u16_limbs::(work_vars[3])[k]; + let mut a_limb = t1_limb + t2_limb; + if k > 0 { + a_limb += cols.work_vars.carry_a[[j, k - 1]].as_canonical_u32(); + e_limb += cols.work_vars.carry_e[[j, k - 1]].as_canonical_u32(); + } + let carry_a = (a_limb - a_limbs[k]) >> 16; + let carry_e = (e_limb - e_limbs[k]) >> 16; + cols.work_vars.carry_a[[j, k]] = F::from_canonical_u32(carry_a); + cols.work_vars.carry_e[[j, k]] = F::from_canonical_u32(carry_e); + bitwise_lookup_chip.request_range(carry_a, carry_e); + } + + // update working variables + work_vars[7] = work_vars[6]; + work_vars[6] = work_vars[5]; + work_vars[5] = work_vars[4]; + work_vars[4] = e; + work_vars[3] = work_vars[2]; + work_vars[2] = work_vars[1]; + work_vars[1] = work_vars[0]; + work_vars[0] = a; + } + + // filling w_3 and intermed_4 here and the rest later + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let w_4 = word_into_u16_limbs::(message_schedule[idx - 4]); + let sig_0_w_3 = + word_into_u16_limbs::(small_sig0::(message_schedule[idx - 3])); + set_arrayview_from_u32_slice( + &mut cols.schedule_helper.intermed_4.row_mut(j), + (0..C::WORD_U16S) + .map(|k| w_4[k] + sig_0_w_3[k]) + .collect::>(), + ); + if j < C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[idx - 3]; + set_arrayview_from_u32_slice( + &mut cols.schedule_helper.w_3.row_mut(j), + word_into_u16_limbs::(w_3), + ); + } + } + } + } + // generate the digest row + else { + let mut cols: Sha2DigestColsRefMut = Sha2DigestColsRefMut::from::( + &mut row[get_range(trace_start_col, C::SUBAIR_DIGEST_WIDTH)], + ); + for j in 0..C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[i * C::ROUNDS_PER_ROW + j - 3]; + set_arrayview_from_u32_slice( + &mut cols.schedule_helper.w_3.row_mut(j), + word_into_u16_limbs::(w_3), + ); + } + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ONE; + set_arrayview_from_u32_slice( + &mut cols.flags.row_idx, + get_flag_pt_array(&self.row_idx_encoder, C::ROUND_ROWS), + ); + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + + let final_hash: Vec = (0..C::HASH_WORDS) + .map(|i| work_vars[i].wrapping_add(prev_hash[i])) + .collect(); + let final_hash_limbs: Vec> = final_hash + .iter() + .map(|word| word_into_u8_limbs::(*word)) + .collect(); + // need to ensure final hash limbs are bytes, in order for + // prev_hash[i] + work_vars[i] == final_hash[i] + // to be constrained correctly + for word in final_hash_limbs.iter() { + for chunk in word.chunks(2) { + bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + } + } + set_arrayview_from_u32_slice( + &mut cols.final_hash, + final_hash + .iter() + .flat_map(|word| word_into_u8_limbs::(*word)), + ); + set_arrayview_from_u32_slice( + &mut cols.prev_hash, + prev_hash + .iter() + .flat_map(|word| word_into_u16_limbs::(*word)), + ); + let next_block_prev_hash_bits = next_block_prev_hash + .iter() + .map(|x| word_into_bits::(*x)) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + set_arrayview_from_u32_slice( + &mut cols.hash.a.row_mut(i), + next_block_prev_hash_bits[C::ROUNDS_PER_ROW - i - 1].clone(), + ); + set_arrayview_from_u32_slice( + &mut cols.hash.e.row_mut(i), + next_block_prev_hash_bits[C::ROUNDS_PER_ROW - i + 3].clone(), + ); + } + } + } + + for i in 0..C::ROWS_PER_BLOCK - 1 { + let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; + let (local, next) = rows.split_at_mut(trace_width); + let mut local_cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut local[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + let mut next_cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut next[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + next_cols + .schedule_helper + .intermed_8 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_4.row(j)); + if (2..C::ROWS_PER_BLOCK - 3).contains(&i) { + next_cols + .schedule_helper + .intermed_12 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_8.row(j)); + } + } + } + if i == C::ROWS_PER_BLOCK - 2 { + // `next` is a digest row. + // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and + // `e` hold. + let const_local_cols = Sha2RoundColsRef::::from_mut::(&local_cols); + Self::generate_carry_ae(const_local_cols.clone(), &mut next_cols); + + // Fill in row 16's `intermed_4` with dummy values so the message schedule + // constraints holds on that row + Self::generate_intermed_4(const_local_cols, &mut next_cols); + } + if i < C::MESSAGE_ROWS - 1 { + // i is in 0..3. + // Fill in `local.intermed_12` with dummy values so the message schedule constraints + // hold on rows 1..4. + Self::generate_intermed_12( + &mut local_cols, + Sha2RoundColsRef::::from_mut::(&next_cols), + ); + } + } + } + + /// This function will fill in the cells that we couldn't do during the first pass. + /// This function should be called only after `generate_block_trace` was called for all blocks + /// And [`Self::generate_default_row`] is called for all invalid rows + /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` + /// Note: `trace` needs to be the rows 1..C::ROWS_PER_BLOCK of a block and the first row of the + /// next block + pub fn generate_missing_cells( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + ) { + let rows = &mut trace[(C::ROUND_ROWS - 2) * trace_width..(C::ROUND_ROWS + 1) * trace_width]; + let (last_round_row, rows) = rows.split_at_mut(trace_width); + let (digest_row, next_block_first_row) = rows.split_at_mut(trace_width); + let mut cols_last_round_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut last_round_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + let mut cols_digest_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut digest_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + let mut cols_next_block_first_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut next_block_first_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + // Fill in the last round row's `intermed_12` with dummy values so the message schedule + // constraints holds on the last round row + Self::generate_intermed_12( + &mut cols_last_round_row, + Sha2RoundColsRef::from_mut::(&cols_digest_row), + ); + // Fill in the digest row's `intermed_12` with dummy values so the message schedule + // constraints holds on the next block's row 0 + Self::generate_intermed_12( + &mut cols_digest_row, + Sha2RoundColsRef::from_mut::(&cols_next_block_first_row), + ); + // Fill in the next block's first row's `intermed_4` with dummy values so the message + // schedule constraints holds on that row + Self::generate_intermed_4( + Sha2RoundColsRef::from_mut::(&cols_digest_row), + &mut cols_next_block_first_row, + ); + } + + /// Fills the `cols` as a padding row + /// Note: we still need to correctly fill in the hash values, carries and intermeds + pub fn generate_default_row( + &self, + cols: &mut Sha2RoundColsRefMut, + first_block_prev_hash: &[C::Word], + carry_a: Option<&[F]>, + carry_e: Option<&[F]>, + ) { + debug_assert!(first_block_prev_hash.len() == C::HASH_WORDS); + debug_assert!(carry_a.is_some() == carry_e.is_some()); + debug_assert!( + carry_a.is_none() || carry_a.unwrap().len() == C::ROUNDS_PER_ROW * C::WORD_U16S + ); + debug_assert!( + carry_e.is_none() || carry_e.unwrap().len() == C::ROUNDS_PER_ROW * C::WORD_U16S + ); + + set_arrayview_from_u32_slice( + &mut cols.flags.row_idx, + get_flag_pt_array(&self.row_idx_encoder, C::ROWS_PER_BLOCK), + ); + + for i in 0..C::ROUNDS_PER_ROW { + // The padding rows need to have the first block's prev_hash here, to satisfy the air + // constraints + set_arrayview_from_u32_slice( + &mut cols.work_vars.a.row_mut(i), + word_into_bits::(first_block_prev_hash[C::ROUNDS_PER_ROW - i - 1]).into_iter(), + ); + set_arrayview_from_u32_slice( + &mut cols.work_vars.e.row_mut(i), + word_into_bits::(first_block_prev_hash[C::ROUNDS_PER_ROW - i + 3]).into_iter(), + ); + + // The invalid carries are not constants anymore, so we need to fill them in here + if let Some(carry_a) = carry_a { + cols.work_vars + .carry_a + .iter_mut() + .zip(carry_a.iter()) + .for_each(|(x, y)| *x = *y); + } + if let Some(carry_e) = carry_e { + cols.work_vars + .carry_e + .iter_mut() + .zip(carry_e.iter()) + .for_each(|(x, y)| *x = *y); + } + } + } + + /// The following functions do the calculations in native field since they will be called on + /// padding rows which can overflow and we need to make sure it matches the AIR constraints + /// Puts the correct carries in the `next_row`, the resulting carries can be out of bounds. + /// Assumes next.w and next.k are zero, which is the case when constraint_word_addition is + /// constrained on digest rows or padding rows. + /// It only looks at local.a, next.a, local.e, next.e. + pub fn generate_carry_ae( + local_cols: Sha2RoundColsRef, + next_cols: &mut Sha2RoundColsRefMut, + ) { + let a = [ + local_cols + .work_vars + .a + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.a.rows().into_iter().collect::>(), + ] + .concat(); + let e = [ + local_cols + .work_vars + .e + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.e.rows().into_iter().collect::>(), + ] + .concat(); + for i in 0..C::ROUNDS_PER_ROW { + let cur_a = a[i + 4]; + let sig_a = big_sig0_field::(a[i + 3].as_slice().unwrap()); + let maj_abc = maj_field::( + a[i + 3].as_slice().unwrap(), + a[i + 2].as_slice().unwrap(), + a[i + 1].as_slice().unwrap(), + ); + let d = a[i]; + let cur_e = e[i + 4]; + let sig_e = big_sig1_field::(e[i + 3].as_slice().unwrap()); + let ch_efg = ch_field::( + e[i + 3].as_slice().unwrap(), + e[i + 2].as_slice().unwrap(), + e[i + 1].as_slice().unwrap(), + ); + let h = e[i]; + + let t1 = [h.to_vec(), sig_e, ch_efg.to_vec()]; + let t2 = [sig_a, maj_abc]; + for j in 0..C::WORD_U16S { + let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let d_limb = compose::(&d.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_a_limb = compose::(&cur_a.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_e_limb = compose::(&cur_e.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let sum = d_limb + + t1_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_e[[i, j - 1]] + } + - cur_e_limb; + let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); + + let sum = t1_limb_sum + + t2_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_a[[i, j - 1]] + } + - cur_a_limb; + let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); + next_cols.work_vars.carry_e[[i, j]] = carry_e; + next_cols.work_vars.carry_a[[i, j]] = carry_a; + } + } + } + + /// Puts the correct intermed_4 in the `next_row` + pub fn generate_intermed_4( + local_cols: Sha2RoundColsRef, + next_cols: &mut Sha2RoundColsRefMut, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + let sig_w = small_sig0_field::(w[i + 1].as_slice().unwrap()); + let sig_w_limbs: Vec = (0..C::WORD_U16S) + .map(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)) + .collect(); + for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { + next_cols.schedule_helper.intermed_4[[i, j]] = w_limbs[i][j] + *sig_w_limb; + } + } + } + + /// Puts the needed intermed_12 in the `local_row` + pub fn generate_intermed_12( + local_cols: &mut Sha2RoundColsRefMut, + next_cols: Sha2RoundColsRef, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + // sig_1(w_{t-2}) + let sig_w_2: Vec = (0..C::WORD_U16S) + .map(|j| { + compose::( + &small_sig1_field::(w[i + 2].as_slice().unwrap()) + [j * 16..(j + 1) * 16], + 1, + ) + }) + .collect(); + // w_{t-7} + let w_7 = if i < 3 { + local_cols.schedule_helper.w_3.row(i).to_slice().unwrap() + } else { + w_limbs[i - 3].as_slice() + }; + // w_t + let w_cur = w_limbs[i + 4].as_slice(); + for j in 0..C::WORD_U16S { + let carry = next_cols.message_schedule.carry_or_buffer[[i, j * 2]] + + F::TWO * next_cols.message_schedule.carry_or_buffer[[i, j * 2 + 1]]; + let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] + + if j > 0 { + next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 2]] + + F::from_canonical_u32(2) + * next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 1]] + } else { + F::ZERO + }; + local_cols.schedule_helper.intermed_12[[i, j]] = -sum; + } + } + } +} diff --git a/crates/circuits/sha2-air/src/utils.rs b/crates/circuits/sha2-air/src/utils.rs new file mode 100644 index 0000000000..ac8b33fb73 --- /dev/null +++ b/crates/circuits/sha2-air/src/utils.rs @@ -0,0 +1,313 @@ +use ndarray::ArrayViewMut; +pub use openvm_circuit_primitives::utils::compose; +use openvm_circuit_primitives::{ + encoder::Encoder, + utils::{not, select}, +}; +use openvm_stark_backend::{ + p3_air::AirBuilder, + p3_field::{FieldAlgebra, PrimeField32}, +}; +use rand::{rngs::StdRng, Rng}; + +use crate::{RotateRight, Sha2BlockHasherSubairConfig}; + +/// Convert a word into a list of 8-bit limbs in little endian +pub fn word_into_u8_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U8S) +} + +/// Convert a word into a list of 16-bit limbs in little endian +pub fn word_into_u16_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U16S) +} + +/// Convert a word into a list of 1-bit limbs in little endian +pub fn word_into_bits(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_BITS) +} + +/// Convert a word into a list of limbs in little endian +pub fn word_into_limbs(num: C::Word, num_limbs: usize) -> Vec { + let limb_bits = std::mem::size_of::() * 8 / num_limbs; + (0..num_limbs) + .map(|i| { + let shifted = num >> (limb_bits * i); + let mask: C::Word = ((1u32 << limb_bits) - 1).into(); + let masked = shifted & mask; + masked.try_into().unwrap() + }) + .collect() +} + +/// Convert a u32 into a list of 1-bit limbs in little endian +pub fn u32_into_bits(num: u32) -> Vec { + let limb_bits = 32 / C::WORD_BITS; + (0..C::WORD_BITS) + .map(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) + .collect() +} + +/// Convert a list of limbs in little endian into a Word +pub fn le_limbs_into_word(limbs: &[u32]) -> C::Word { + let mut limbs = limbs.to_vec(); + limbs.reverse(); + be_limbs_into_word::(&limbs) +} + +/// Convert a list of limbs in big endian into a Word +pub fn be_limbs_into_word(limbs: &[u32]) -> C::Word { + let limb_bits = C::WORD_BITS / limbs.len(); + limbs.iter().fold(C::Word::from(0), |acc, &limb| { + (acc << limb_bits) | limb.into() + }) +} + +/// Convert a list of limbs in little endian into a u32 +pub fn limbs_into_u32(limbs: &[u32]) -> u32 { + let limb_bits = 32 / limbs.len(); + limbs + .iter() + .rev() + .fold(0, |acc, &limb| (acc << limb_bits) | limb) +} + +/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn rotr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| bits[(i + n) % bits.len()].clone().into()) + .collect() +} + +/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn shr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| { + if i + n < bits.len() { + bits[i + n].clone().into() + } else { + F::ZERO + } + }) + .collect() +} + +/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean +#[inline] +pub(crate) fn xor_bit( + x: impl Into, + y: impl Into, + z: impl Into, +) -> F { + let (x, y, z) = (x.into(), y.into(), z.into()); + (x.clone() * y.clone() * z.clone()) + + (x.clone() * not::(y.clone()) * not::(z.clone())) + + (not::(x.clone()) * y.clone() * not::(z.clone())) + + (not::(x) * not::(y) * z) +} + +/// Computes x ^ y ^ z, where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn xor( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Choose function from the SHA spec +#[inline] +pub fn ch(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ ((!x) & z) +} + +/// Computes Ch(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn ch_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Majority function from the SHA spec +pub fn maj(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ (x & z) ^ (y & z) +} + +/// Computes Maj(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn maj_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| { + let (x, y, z) = ( + x[i].clone().into(), + y[i].clone().into(), + z[i].clone().into(), + ); + x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() + - F::TWO * x * y * z + }) + .collect() +} + +/// Big sigma_0 function from the SHA spec +pub fn big_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) + } else { + x.rotate_right(28) ^ x.rotate_right(34) ^ x.rotate_right(39) + } +} + +/// Computes BigSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) + } else { + xor(&rotr::(x, 28), &rotr::(x, 34), &rotr::(x, 39)) + } +} + +/// Big sigma_1 function from the SHA spec +pub fn big_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) + } else { + x.rotate_right(14) ^ x.rotate_right(18) ^ x.rotate_right(41) + } +} + +/// Computes BigSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) + } else { + xor(&rotr::(x, 14), &rotr::(x, 18), &rotr::(x, 41)) + } +} + +/// Small sigma_0 function from the SHA spec +pub fn small_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) + } else { + x.rotate_right(1) ^ x.rotate_right(8) ^ (x >> 7) + } +} + +/// Computes SmallSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) + } else { + xor(&rotr::(x, 1), &rotr::(x, 8), &shr::(x, 7)) + } +} + +/// Small sigma_1 function from the SHA spec +pub fn small_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) + } else { + x.rotate_right(19) ^ x.rotate_right(61) ^ (x >> 6) + } +} + +/// Computes SmallSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) + } else { + xor(&rotr::(x, 19), &rotr::(x, 61), &shr::(x, 6)) + } +} + +/// Generate a random message of a given length +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} + +/// Wrapper of `get_flag_pt` to get the flag pointer as an array +pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> Vec { + encoder.get_flag_pt(flag_idx) +} + +/// Constrain the addition of [C::WORD_BITS] bit words in 16-bit limbs +/// It takes in the terms some in bits some in 16-bit limbs, +/// the expected sum in bits and the carries +pub fn constraint_word_addition( + builder: &mut AB, + terms_bits: &[&[impl Into + Clone]], + terms_limb: &[&[impl Into + Clone]], + expected_sum: &[impl Into + Clone], + carries: &[impl Into + Clone], +) { + debug_assert!(terms_bits.iter().all(|x| x.len() == C::WORD_BITS)); + debug_assert!(terms_limb.iter().all(|x| x.len() == C::WORD_U16S)); + assert_eq!(expected_sum.len(), C::WORD_BITS); + assert_eq!(carries.len(), C::WORD_U16S); + + for i in 0..C::WORD_U16S { + let mut limb_sum = if i == 0 { + AB::Expr::ZERO + } else { + carries[i - 1].clone().into() + }; + for term in terms_bits { + limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); + } + for term in terms_limb { + limb_sum += term[i].clone().into(); + } + let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) + + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); + builder.assert_eq(limb_sum, expected_sum_limb); + } +} + +pub fn set_arrayview_from_u32_slice( + arrayview: &mut ArrayViewMut, + data: impl IntoIterator, +) { + arrayview + .iter_mut() + .zip(data.into_iter().map(|x| F::from_canonical_u32(x))) + .for_each(|(x, y)| *x = y); +} + +pub fn set_arrayview_from_u8_slice( + arrayview: &mut ArrayViewMut, + data: impl IntoIterator, +) { + arrayview + .iter_mut() + .zip(data.into_iter().map(|x| F::from_canonical_u8(x))) + .for_each(|(x, y)| *x = y); +} diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/columns.cuh b/crates/circuits/sha256-air/cuda/include/sha256-air/columns.cuh deleted file mode 100644 index 64c6a276eb..0000000000 --- a/crates/circuits/sha256-air/cuda/include/sha256-air/columns.cuh +++ /dev/null @@ -1,131 +0,0 @@ -#pragma once - -#include "primitives/constants.h" -#include "primitives/execution.h" -#include "system/memory/offline_checker.cuh" - -using namespace riscv; -using namespace sha256; - -template struct Sha256FlagsCols { - T is_round_row; - T is_first_4_rows; - T is_digest_row; - T is_last_block; - T row_idx[SHA256_ROW_VAR_CNT]; - T global_block_idx; - T local_block_idx; -}; - -template struct Sha256MessageHelperCols { - T w_3[SHA256_ROUNDS_PER_ROW - 1][SHA256_WORD_U16S]; - T intermed_4[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U16S]; - T intermed_8[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U16S]; - T intermed_12[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U16S]; -}; - -template struct Sha256MessageScheduleCols { - T w[SHA256_ROUNDS_PER_ROW][SHA256_WORD_BITS]; - T carry_or_buffer[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U8S]; -}; - -template struct Sha256WorkVarsCols { - T a[SHA256_ROUNDS_PER_ROW][SHA256_WORD_BITS]; - T e[SHA256_ROUNDS_PER_ROW][SHA256_WORD_BITS]; - T carry_a[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U16S]; - T carry_e[SHA256_ROUNDS_PER_ROW][SHA256_WORD_U16S]; -}; - -template struct Sha256RoundCols { - Sha256FlagsCols flags; - Sha256WorkVarsCols work_vars; - Sha256MessageHelperCols schedule_helper; - Sha256MessageScheduleCols message_schedule; -}; - -template struct Sha256DigestCols { - Sha256FlagsCols flags; - Sha256WorkVarsCols hash; - Sha256MessageHelperCols schedule_helper; - T final_hash[SHA256_HASH_WORDS][SHA256_WORD_U8S]; - T prev_hash[SHA256_HASH_WORDS][SHA256_WORD_U16S]; -}; - -template struct Sha256VmControlCols { - T len; - T cur_timestamp; - T read_ptr; - T pad_flags[6]; - T padding_occurred; -}; - -template struct Sha256VmRoundCols { - Sha256VmControlCols control; - Sha256RoundCols inner; - MemoryReadAuxCols read_aux; -}; - -template struct Sha256VmDigestCols { - Sha256VmControlCols control; - Sha256DigestCols inner; - ExecutionState from_state; - T rd_ptr; - T rs1_ptr; - T rs2_ptr; - T dst_ptr[RV32_REGISTER_NUM_LIMBS]; - T src_ptr[RV32_REGISTER_NUM_LIMBS]; - T len_data[RV32_REGISTER_NUM_LIMBS]; - MemoryReadAuxCols register_reads_aux[SHA256_REGISTER_READS]; - MemoryWriteAuxCols writes_aux; -}; - -struct Sha256InnerBlockRecord { - uint32_t global_block_idx; - uint32_t local_block_idx; - uint32_t is_last_block; - uint32_t prev_hash[SHA256_HASH_WORDS]; - uint32_t input_words[SHA256_BLOCK_WORDS]; -}; - -__global__ void sha256_fill_invalid_rows( - Fp *d_trace, - size_t trace_height, - size_t trace_width, - uint32_t rows_used -); - -// ===== MACROS AND CONSTANTS ===== -static constexpr size_t SHA256VM_CONTROL_WIDTH = sizeof(Sha256VmControlCols); -static constexpr size_t SHA256_ROUND_WIDTH = sizeof(Sha256RoundCols); -static constexpr size_t SHA256_DIGEST_WIDTH = sizeof(Sha256DigestCols); -static constexpr size_t SHA256VM_ROUND_WIDTH = sizeof(Sha256VmRoundCols); -static constexpr size_t SHA256VM_DIGEST_WIDTH = sizeof(Sha256VmDigestCols); - -static constexpr size_t SHA256_WIDTH = - (SHA256_ROUND_WIDTH > SHA256_DIGEST_WIDTH) ? SHA256_ROUND_WIDTH : SHA256_DIGEST_WIDTH; -static constexpr size_t SHA256VM_WIDTH = - (SHA256VM_ROUND_WIDTH > SHA256VM_DIGEST_WIDTH) ? SHA256VM_ROUND_WIDTH : SHA256VM_DIGEST_WIDTH; -static constexpr size_t SHA256_INNER_COLUMN_OFFSET = sizeof(Sha256VmControlCols); - -#define SHA256_WRITE_ROUND(row, FIELD, VALUE) COL_WRITE_VALUE(row, Sha256VmRoundCols, FIELD, VALUE) -#define SHA256_WRITE_DIGEST(row, FIELD, VALUE) \ - COL_WRITE_VALUE(row, Sha256VmDigestCols, FIELD, VALUE) -#define SHA256_WRITE_ARRAY_ROUND(row, FIELD, VALUES) \ - COL_WRITE_ARRAY(row, Sha256VmRoundCols, FIELD, VALUES) -#define SHA256_WRITE_ARRAY_DIGEST(row, FIELD, VALUES) \ - COL_WRITE_ARRAY(row, Sha256VmDigestCols, FIELD, VALUES) -#define SHA256_FILL_ZERO_ROUND(row, FIELD) COL_FILL_ZERO(row, Sha256VmRoundCols, FIELD) -#define SHA256_FILL_ZERO_DIGEST(row, FIELD) COL_FILL_ZERO(row, Sha256VmDigestCols, FIELD) -#define SHA256_SLICE_ROUND(row, FIELD) row.slice_from(COL_INDEX(Sha256VmRoundCols, FIELD)) -#define SHA256_SLICE_DIGEST(row, FIELD) row.slice_from(COL_INDEX(Sha256VmDigestCols, FIELD)) - -#define SHA256INNER_WRITE_ROUND(row, FIELD, VALUE) \ - COL_WRITE_VALUE(row, Sha256RoundCols, FIELD, VALUE) -#define SHA256INNER_WRITE_DIGEST(row, FIELD, VALUE) \ - COL_WRITE_VALUE(row, Sha256DigestCols, FIELD, VALUE) -#define SHA256INNER_WRITE_ARRAY_ROUND(row, FIELD, VALUES) \ - COL_WRITE_ARRAY(row, Sha256RoundCols, FIELD, VALUES) -#define SHA256INNER_WRITE_ARRAY_DIGEST(row, FIELD, VALUES) \ - COL_WRITE_ARRAY(row, Sha256DigestCols, FIELD, VALUES) -#define SHA256INNER_FILL_ZERO_ROUND(row, FIELD) COL_FILL_ZERO(row, Sha256RoundCols, FIELD) -#define SHA256INNER_FILL_ZERO_DIGEST(row, FIELD) COL_FILL_ZERO(row, Sha256DigestCols, FIELD) \ No newline at end of file diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/tracegen.cuh b/crates/circuits/sha256-air/cuda/include/sha256-air/tracegen.cuh deleted file mode 100644 index ead1830a8f..0000000000 --- a/crates/circuits/sha256-air/cuda/include/sha256-air/tracegen.cuh +++ /dev/null @@ -1,666 +0,0 @@ -#pragma once - -#include "columns.cuh" -#include "primitives/constants.h" -#include "primitives/encoder.cuh" -#include "primitives/histogram.cuh" -#include "primitives/trace_access.h" -#include "utils.cuh" - -using namespace riscv; -using namespace sha256; - -__device__ void generate_carry_ae(RowSlice local_row, RowSlice next_row) { - Fp a_bits[SHA256_ROUNDS_PER_ROW * 2][SHA256_WORD_BITS]; - Fp e_bits[SHA256_ROUNDS_PER_ROW * 2][SHA256_WORD_BITS]; - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - a_bits[i][bit] = local_row[COL_INDEX(Sha256RoundCols, work_vars.a[i][bit])]; - e_bits[i][bit] = local_row[COL_INDEX(Sha256RoundCols, work_vars.e[i][bit])]; - } - } - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - a_bits[i + SHA256_ROUNDS_PER_ROW][bit] = - next_row[COL_INDEX(Sha256RoundCols, work_vars.a[i][bit])]; - e_bits[i + SHA256_ROUNDS_PER_ROW][bit] = - next_row[COL_INDEX(Sha256RoundCols, work_vars.e[i][bit])]; - } - } - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - Fp sig_a_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - sig_a_bits[bit] = - (a_bits[i + 3][(bit + 2) & 31] + a_bits[i + 3][(bit + 13) & 31] - - Fp(2) * a_bits[i + 3][(bit + 2) & 31] * a_bits[i + 3][(bit + 13) & 31]) + - a_bits[i + 3][(bit + 22) & 31] - - Fp(2) * - (a_bits[i + 3][(bit + 2) & 31] + a_bits[i + 3][(bit + 13) & 31] - - Fp(2) * a_bits[i + 3][(bit + 2) & 31] * a_bits[i + 3][(bit + 13) & 31]) * - a_bits[i + 3][(bit + 22) & 31]; - } - - Fp sig_e_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - sig_e_bits[bit] = - (e_bits[i + 3][(bit + 6) & 31] + e_bits[i + 3][(bit + 11) & 31] - - Fp(2) * e_bits[i + 3][(bit + 6) & 31] * e_bits[i + 3][(bit + 11) & 31]) + - e_bits[i + 3][(bit + 25) & 31] - - Fp(2) * - (e_bits[i + 3][(bit + 6) & 31] + e_bits[i + 3][(bit + 11) & 31] - - Fp(2) * e_bits[i + 3][(bit + 6) & 31] * e_bits[i + 3][(bit + 11) & 31]) * - e_bits[i + 3][(bit + 25) & 31]; - } - - Fp maj_abc_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - maj_abc_bits[bit] = - a_bits[i + 3][bit] * a_bits[i + 2][bit] + a_bits[i + 3][bit] * a_bits[i + 1][bit] + - a_bits[i + 2][bit] * a_bits[i + 1][bit] - - Fp(2) * a_bits[i + 3][bit] * a_bits[i + 2][bit] * a_bits[i + 1][bit]; - } - - Fp ch_efg_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - ch_efg_bits[bit] = e_bits[i + 3][bit] * e_bits[i + 2][bit] + e_bits[i + 1][bit] - - e_bits[i + 3][bit] * e_bits[i + 1][bit]; - } - for (int j = 0; j < SHA256_WORD_U16S; j++) { - Fp t1_limb_sum = Fp::zero(); -#pragma unroll 1 - for (int bit = 0; bit < 16; bit++) { - t1_limb_sum += (e_bits[i][j * 16 + bit] + sig_e_bits[j * 16 + bit] + - ch_efg_bits[j * 16 + bit]) * - Fp(1 << bit); - } - Fp t2_limb_sum = Fp::zero(); -#pragma unroll 1 - for (int bit = 0; bit < 16; bit++) { - t2_limb_sum += - (sig_a_bits[j * 16 + bit] + maj_abc_bits[j * 16 + bit]) * Fp(1 << bit); - } - - Fp d_limb = Fp::zero(); - Fp cur_a_limb = Fp::zero(); - Fp cur_e_limb = Fp::zero(); -#pragma unroll 1 - for (int bit = 0; bit < 16; bit++) { - d_limb += (a_bits[i][j * 16 + bit] * Fp(1 << bit)); - cur_a_limb += (a_bits[i + 4][j * 16 + bit] * Fp(1 << bit)); - cur_e_limb += (e_bits[i + 4][j * 16 + bit] * Fp(1 << bit)); - } - Fp prev_carry_e = - (j == 0) ? Fp::zero() - : next_row[COL_INDEX(Sha256RoundCols, work_vars.carry_e[i][j - 1])]; - Fp carry_e_numerator = d_limb + t1_limb_sum + prev_carry_e - cur_e_limb; - Fp carry_e = carry_e_numerator.mul_2exp_neg_n(16); - - Fp prev_carry_a = - (j == 0) ? Fp::zero() - : next_row[COL_INDEX(Sha256RoundCols, work_vars.carry_a[i][j - 1])]; - - Fp carry_a_numerator = t1_limb_sum + t2_limb_sum + prev_carry_a - cur_a_limb; - Fp carry_a = carry_a_numerator.mul_2exp_neg_n(16); - SHA256INNER_WRITE_ROUND(next_row, work_vars.carry_e[i][j], carry_e); - SHA256INNER_WRITE_ROUND(next_row, work_vars.carry_a[i][j], carry_a); - } - } -} - -__device__ void generate_intermed_4(RowSlice local_row, RowSlice next_row) { - Fp w_bits[SHA256_ROUNDS_PER_ROW * 2][SHA256_WORD_BITS]; - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_bits[i][bit] = local_row[COL_INDEX(Sha256RoundCols, message_schedule.w[i][bit])]; - } - } - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_bits[i + SHA256_ROUNDS_PER_ROW][bit] = - next_row[COL_INDEX(Sha256RoundCols, message_schedule.w[i][bit])]; - } - } - - Fp w_limbs[SHA256_ROUNDS_PER_ROW * 2][SHA256_WORD_U16S]; - for (int i = 0; i < SHA256_ROUNDS_PER_ROW * 2; i++) { - for (int j = 0; j < SHA256_WORD_U16S; j++) { - w_limbs[i][j] = Fp::zero(); - for (int bit = 0; bit < 16; bit++) { - w_limbs[i][j] = w_limbs[i][j] + w_bits[i][j * 16 + bit] * Fp(1 << bit); - } - } - } - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - Fp sig_w_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - sig_w_bits[bit] = - (w_bits[i + 1][(bit + 7) & 31] + w_bits[i + 1][(bit + 18) & 31] - - Fp(2) * w_bits[i + 1][(bit + 7) & 31] * w_bits[i + 1][(bit + 18) & 31]) + - ((bit + 3 < 32) ? w_bits[i + 1][bit + 3] : Fp::zero()) - - Fp(2) * - (w_bits[i + 1][(bit + 7) & 31] + w_bits[i + 1][(bit + 18) & 31] - - Fp(2) * w_bits[i + 1][(bit + 7) & 31] * w_bits[i + 1][(bit + 18) & 31]) * - ((bit + 3 < 32) ? w_bits[i + 1][bit + 3] : Fp::zero()); - } - - Fp sig_w_limbs[SHA256_WORD_U16S]; - for (int j = 0; j < SHA256_WORD_U16S; j++) { - sig_w_limbs[j] = Fp::zero(); - for (int bit = 0; bit < 16; bit++) { - sig_w_limbs[j] = sig_w_limbs[j] + sig_w_bits[j * 16 + bit] * Fp(1 << bit); - } - } - - for (int j = 0; j < SHA256_WORD_U16S; j++) { - SHA256INNER_WRITE_ROUND( - next_row, schedule_helper.intermed_4[i][j], w_limbs[i][j] + sig_w_limbs[j] - ); - } - } -} - -__device__ void generate_intermed_12(RowSlice local_row, RowSlice next_row) { - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - Fp sig_w_2_limbs[SHA256_WORD_U16S]; - Fp w_i_plus_2_bits[SHA256_WORD_BITS]; - - if (i + 2 < SHA256_ROUNDS_PER_ROW) { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_i_plus_2_bits[bit] = - local_row[COL_INDEX(Sha256RoundCols, message_schedule.w[i + 2][bit])]; - } - } else { - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_i_plus_2_bits[bit] = next_row[COL_INDEX( - Sha256RoundCols, message_schedule.w[i + 2 - SHA256_ROUNDS_PER_ROW][bit] - )]; - } - } - - Fp sig_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - sig_bits[bit] = - (w_i_plus_2_bits[(bit + 17) & 31] + w_i_plus_2_bits[(bit + 19) & 31] - - Fp(2) * w_i_plus_2_bits[(bit + 17) & 31] * w_i_plus_2_bits[(bit + 19) & 31]) + - ((bit + 10 < 32) ? w_i_plus_2_bits[bit + 10] : Fp::zero()) - - Fp(2) * - (w_i_plus_2_bits[(bit + 17) & 31] + w_i_plus_2_bits[(bit + 19) & 31] - - Fp(2) * w_i_plus_2_bits[(bit + 17) & 31] * w_i_plus_2_bits[(bit + 19) & 31]) * - ((bit + 10 < 32) ? w_i_plus_2_bits[bit + 10] : Fp::zero()); - } - - for (int j = 0; j < SHA256_WORD_U16S; j++) { - sig_w_2_limbs[j] = Fp::zero(); - for (int bit = 0; bit < 16; bit++) { - sig_w_2_limbs[j] += sig_bits[j * 16 + bit] * Fp(1 << bit); - } - } - - Fp w_7_limbs[SHA256_WORD_U16S]; - if (i < 3) { - w_7_limbs[0] = local_row[COL_INDEX(Sha256RoundCols, schedule_helper.w_3[i][0])]; - w_7_limbs[1] = local_row[COL_INDEX(Sha256RoundCols, schedule_helper.w_3[i][1])]; - } else { - Fp w_i_minus_3_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_i_minus_3_bits[bit] = - local_row[COL_INDEX(Sha256RoundCols, message_schedule.w[i - 3][bit])]; - } - for (int j = 0; j < SHA256_WORD_U16S; j++) { - w_7_limbs[j] = Fp::zero(); - for (int bit = 0; bit < 16; bit++) { - w_7_limbs[j] += w_i_minus_3_bits[j * 16 + bit] * Fp(1 << bit); - } - } - } - - Fp w_cur_limbs[SHA256_WORD_U16S]; - Fp w_cur_bits[SHA256_WORD_BITS]; - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - w_cur_bits[bit] = next_row[COL_INDEX(Sha256RoundCols, message_schedule.w[i][bit])]; - } - for (int j = 0; j < SHA256_WORD_U16S; j++) { - w_cur_limbs[j] = Fp::zero(); - for (int bit = 0; bit < 16; bit++) { - w_cur_limbs[j] += w_cur_bits[j * 16 + bit] * Fp(1 << bit); - } - } - - for (int j = 0; j < SHA256_WORD_U16S; j++) { - Fp carry = - next_row[COL_INDEX(Sha256RoundCols, message_schedule.carry_or_buffer[i][j * 2])] + - Fp(2) * next_row[COL_INDEX( - Sha256RoundCols, message_schedule.carry_or_buffer[i][j * 2 + 1] - )]; - - Fp prev_carry = Fp::zero(); - if (j > 0) { - prev_carry = - next_row[COL_INDEX( - Sha256RoundCols, message_schedule.carry_or_buffer[i][j * 2 - 2] - )] + - Fp(2) * next_row[COL_INDEX( - Sha256RoundCols, message_schedule.carry_or_buffer[i][j * 2 - 1] - )]; - } - - Fp sum = - sig_w_2_limbs[j] + w_7_limbs[j] - carry * Fp(1 << 16) - w_cur_limbs[j] + prev_carry; - SHA256INNER_WRITE_ROUND(local_row, schedule_helper.intermed_12[i][j], -sum); - } - } -} - -__device__ void get_block_hash( - uint32_t hash[SHA256_HASH_WORDS], - const uint8_t input[SHA256_BLOCK_U8S] -) { - uint32_t work_vars[SHA256_HASH_WORDS]; - memcpy(work_vars, hash, SHA256_HASH_WORDS * sizeof(uint32_t)); - - uint32_t w[64]; - for (int i = 0; i < 16; i++) { - w[i] = u32_from_bytes_be(input + i * 4); - } - for (int i = 16; i < 64; i++) { - w[i] = small_sig1(w[i - 2]) + w[i - 7] + small_sig0(w[i - 15]) + w[i - 16]; - } - - for (int i = 0; i < 64; i++) { - uint32_t t1 = work_vars[7] + big_sig1(work_vars[4]) + - ch(work_vars[4], work_vars[5], work_vars[6]) + SHA256_K[i] + w[i]; - uint32_t t2 = big_sig0(work_vars[0]) + maj(work_vars[0], work_vars[1], work_vars[2]); - - uint32_t a = work_vars[0]; - uint32_t b = work_vars[1]; - uint32_t c = work_vars[2]; - uint32_t d = work_vars[3]; - uint32_t e = work_vars[4]; - uint32_t f = work_vars[5]; - uint32_t g = work_vars[6]; - uint32_t h = work_vars[7]; - - h = g; - g = f; - f = e; - e = d + t1; - d = c; - c = b; - b = a; - a = t1 + t2; - - work_vars[0] = a; - work_vars[1] = b; - work_vars[2] = c; - work_vars[3] = d; - work_vars[4] = e; - work_vars[5] = f; - work_vars[6] = g; - work_vars[7] = h; - } - - for (int i = 0; i < SHA256_HASH_WORDS; i++) { - hash[i] += work_vars[i]; - } -} - -__device__ void generate_block_trace( - Fp *trace, - size_t trace_height, - const uint32_t input[SHA256_BLOCK_WORDS], - uint32_t *bitwise_lookup_ptr, - uint32_t bitwise_num_bits, - const uint32_t prev_hash[SHA256_HASH_WORDS], - bool is_last_block, - uint32_t global_block_idx, - uint32_t local_block_idx -) { - BitwiseOperationLookup bitwise_lookup(bitwise_lookup_ptr, bitwise_num_bits); - Encoder row_idx_encoder(18, 2, false); - - uint32_t message_schedule[64]; - uint32_t work_vars[SHA256_HASH_WORDS]; - - memcpy(message_schedule, input, SHA256_BLOCK_WORDS * sizeof(uint32_t)); - memcpy(work_vars, prev_hash, SHA256_HASH_WORDS * sizeof(uint32_t)); - - for (int i = 0; i < SHA256_ROWS_PER_BLOCK; i++) { - RowSlice row_slice(trace + i, trace_height); - - if (i < 16) { - SHA256INNER_WRITE_ROUND(row_slice, flags.is_round_row, Fp::one()); - SHA256INNER_WRITE_ROUND( - row_slice, flags.is_first_4_rows, (i < 4) ? Fp::one() : Fp::zero() - ); - SHA256INNER_WRITE_ROUND(row_slice, flags.is_digest_row, Fp::zero()); - SHA256INNER_WRITE_ROUND( - row_slice, flags.is_last_block, is_last_block ? Fp::one() : Fp::zero() - ); - - RowSlice row_idx_flags = - row_slice.slice_from(COL_INDEX(Sha256RoundCols, flags.row_idx)); - row_idx_encoder.write_flag_pt(row_idx_flags, i); - - SHA256INNER_WRITE_ROUND(row_slice, flags.global_block_idx, global_block_idx); - SHA256INNER_WRITE_ROUND(row_slice, flags.local_block_idx, local_block_idx); - - if (i < 4) { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - COL_WRITE_BITS( - row_slice, - Sha256RoundCols, - message_schedule.w[j], - input[i * SHA256_ROUNDS_PER_ROW + j] - ); - } - } else { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - int idx = i * SHA256_ROUNDS_PER_ROW + j; - - uint32_t w = small_sig1(message_schedule[idx - 2]) + message_schedule[idx - 7] + - small_sig0(message_schedule[idx - 15]) + - message_schedule[idx - 16]; - - COL_WRITE_BITS(row_slice, Sha256RoundCols, message_schedule.w[j], w); - - for (int k = 0; k < SHA256_WORD_U16S; k++) { - uint32_t sum = u32_to_u16_limb(small_sig1(message_schedule[idx - 2]), k) + - u32_to_u16_limb(message_schedule[idx - 7], k) + - u32_to_u16_limb(small_sig0(message_schedule[idx - 15]), k) + - u32_to_u16_limb(message_schedule[idx - 16], k); - - if (k > 0) { - sum += row_slice[COL_INDEX( - Sha256RoundCols, - message_schedule.carry_or_buffer[j][k * 2 - 2] - )] - .asUInt32() + - 2 * row_slice[COL_INDEX( - Sha256RoundCols, - message_schedule.carry_or_buffer[j][k * 2 - 1] - )] - .asUInt32(); - } - - uint32_t carry = (sum - u32_to_u16_limb(w, k)) >> 16; - SHA256INNER_WRITE_ROUND( - row_slice, message_schedule.carry_or_buffer[j][k * 2], Fp(carry & 1) - ); - SHA256INNER_WRITE_ROUND( - row_slice, - message_schedule.carry_or_buffer[j][k * 2 + 1], - Fp(carry >> 1) - ); - } - message_schedule[idx] = w; - } - } - - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - int idx = i * SHA256_ROUNDS_PER_ROW + j; - - uint32_t t1 = work_vars[7] + big_sig1(work_vars[4]) + - ch(work_vars[4], work_vars[5], work_vars[6]) + SHA256_K[idx] + - message_schedule[idx]; - uint32_t t2 = - big_sig0(work_vars[0]) + maj(work_vars[0], work_vars[1], work_vars[2]); - uint32_t e = work_vars[3] + t1; - uint32_t a = t1 + t2; - - COL_WRITE_BITS(row_slice, Sha256RoundCols, work_vars.a[j], a); - COL_WRITE_BITS(row_slice, Sha256RoundCols, work_vars.e[j], e); - - uint32_t carry_a_values[SHA256_WORD_U16S] = {0}; - uint32_t carry_e_values[SHA256_WORD_U16S] = {0}; - - for (int k = 0; k < SHA256_WORD_U16S; k++) { - uint32_t t1_limb = - u32_to_u16_limb(work_vars[7], k) + - u32_to_u16_limb(big_sig1(work_vars[4]), k) + - u32_to_u16_limb(ch(work_vars[4], work_vars[5], work_vars[6]), k) + - u32_to_u16_limb(SHA256_K[idx], k) + - u32_to_u16_limb(message_schedule[idx], k); - - uint32_t t2_limb = - u32_to_u16_limb(big_sig0(work_vars[0]), k) + - u32_to_u16_limb(maj(work_vars[0], work_vars[1], work_vars[2]), k); - - uint32_t e_limb = t1_limb + u32_to_u16_limb(work_vars[3], k); - uint32_t a_limb = t1_limb + t2_limb; - - if (k > 0) { - a_limb += carry_a_values[k - 1]; - e_limb += carry_e_values[k - 1]; - } - - carry_a_values[k] = (a_limb - u32_to_u16_limb(a, k)) >> 16; - carry_e_values[k] = (e_limb - u32_to_u16_limb(e, k)) >> 16; - - SHA256INNER_WRITE_ROUND(row_slice, work_vars.carry_a[j][k], carry_a_values[k]); - SHA256INNER_WRITE_ROUND(row_slice, work_vars.carry_e[j][k], carry_e_values[k]); - - bitwise_lookup.add_range(carry_a_values[k], carry_e_values[k]); - } - - work_vars[7] = work_vars[6]; - work_vars[6] = work_vars[5]; - work_vars[5] = work_vars[4]; - work_vars[4] = e; - work_vars[3] = work_vars[2]; - work_vars[2] = work_vars[1]; - work_vars[1] = work_vars[0]; - work_vars[0] = a; - } - - if (i == 0) { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW - 1; j++) { - for (int k = 0; k < SHA256_WORD_U16S; k++) { - SHA256INNER_WRITE_ROUND(row_slice, schedule_helper.w_3[j][k], Fp::zero()); - } - } - - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - for (int k = 0; k < SHA256_WORD_U16S; k++) { - SHA256INNER_WRITE_ROUND( - row_slice, schedule_helper.intermed_4[j][k], Fp::zero() - ); - SHA256INNER_WRITE_ROUND( - row_slice, schedule_helper.intermed_8[j][k], Fp::zero() - ); - SHA256INNER_WRITE_ROUND( - row_slice, schedule_helper.intermed_12[j][k], Fp::zero() - ); - } - } - } else if (i > 0) { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - uint32_t idx = i * SHA256_ROUNDS_PER_ROW + j; - - uint32_t w_4 = message_schedule[idx - 4]; - uint32_t sig_0_w_3 = small_sig0(message_schedule[idx - 3]); - - SHA256INNER_WRITE_ROUND( - row_slice, - schedule_helper.intermed_4[j][0], - Fp(u32_to_u16_limb(w_4, 0) + u32_to_u16_limb(sig_0_w_3, 0)) - ); - SHA256INNER_WRITE_ROUND( - row_slice, - schedule_helper.intermed_4[j][1], - Fp(u32_to_u16_limb(w_4, 1) + u32_to_u16_limb(sig_0_w_3, 1)) - ); - - if (j < SHA256_ROUNDS_PER_ROW - 1) { - uint32_t w_3 = message_schedule[idx - 3]; - SHA256INNER_WRITE_ROUND( - row_slice, schedule_helper.w_3[j][0], u32_to_u16_limb(w_3, 0) - ); - SHA256INNER_WRITE_ROUND( - row_slice, schedule_helper.w_3[j][1], u32_to_u16_limb(w_3, 1) - ); - } - } - } - } else { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW - 1; j++) { - uint32_t w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - SHA256INNER_WRITE_DIGEST( - row_slice, schedule_helper.w_3[j][0], u32_to_u16_limb(w_3, 0) - ); - SHA256INNER_WRITE_DIGEST( - row_slice, schedule_helper.w_3[j][1], u32_to_u16_limb(w_3, 1) - ); - } - - SHA256INNER_WRITE_DIGEST(row_slice, flags.is_round_row, Fp::zero()); - SHA256INNER_WRITE_DIGEST(row_slice, flags.is_first_4_rows, Fp::zero()); - SHA256INNER_WRITE_DIGEST(row_slice, flags.is_digest_row, Fp::one()); - SHA256INNER_WRITE_DIGEST( - row_slice, flags.is_last_block, is_last_block ? Fp::one() : Fp::zero() - ); - - RowSlice row_idx_flags = - row_slice.slice_from(COL_INDEX(Sha256DigestCols, flags.row_idx)); - row_idx_encoder.write_flag_pt(row_idx_flags, 16); - - SHA256INNER_WRITE_DIGEST(row_slice, flags.global_block_idx, global_block_idx); - SHA256INNER_WRITE_DIGEST(row_slice, flags.local_block_idx, local_block_idx); - - uint32_t final_hash[SHA256_HASH_WORDS]; - for (int j = 0; j < SHA256_HASH_WORDS; j++) { - final_hash[j] = work_vars[j] + prev_hash[j]; - } - - for (int j = 0; j < SHA256_HASH_WORDS; j++) { - uint8_t *hash_bytes = (uint8_t *)&final_hash[j]; - SHA256INNER_WRITE_ARRAY_DIGEST(row_slice, final_hash[j], hash_bytes); - -#pragma unroll - for (int chunk = 0; chunk < SHA256_WORD_U8S; chunk += 2) { - bitwise_lookup.add_range( - (uint32_t)hash_bytes[chunk], (uint32_t)hash_bytes[chunk + 1] - ); - } - } - - for (int j = 0; j < SHA256_HASH_WORDS; j++) { - SHA256INNER_WRITE_DIGEST( - row_slice, prev_hash[j][0], u32_to_u16_limb(prev_hash[j], 0) - ); - SHA256INNER_WRITE_DIGEST( - row_slice, prev_hash[j][1], u32_to_u16_limb(prev_hash[j], 1) - ); - } - - uint32_t hash[SHA256_HASH_WORDS]; - if (is_last_block) { - for (int j = 0; j < SHA256_HASH_WORDS; j++) { - hash[j] = SHA256_H[j]; - } - } else { - for (int j = 0; j < SHA256_HASH_WORDS; j++) { - hash[j] = final_hash[j]; - } - } - - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - COL_WRITE_BITS( - row_slice, Sha256DigestCols, hash.a[j], hash[SHA256_ROUNDS_PER_ROW - j - 1] - ); - COL_WRITE_BITS( - row_slice, Sha256DigestCols, hash.e[j], hash[SHA256_ROUNDS_PER_ROW - j + 3] - ); - } - } - } - for (int i = 0; i < SHA256_ROWS_PER_BLOCK - 1; i++) { - RowSlice local_row(trace + i, trace_height); - RowSlice next_row(trace + i + 1, trace_height); - - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - for (int k = 0; k < SHA256_WORD_U16S; k++) { - Fp intermed_4_val = - local_row[COL_INDEX(Sha256RoundCols, schedule_helper.intermed_4[j][k])]; - SHA256INNER_WRITE_ROUND(next_row, schedule_helper.intermed_8[j][k], intermed_4_val); - } - } - - if (i >= 2 && i <= 13) { - for (int j = 0; j < SHA256_ROUNDS_PER_ROW; j++) { - for (int k = 0; k < SHA256_WORD_U16S; k++) { - Fp intermed_8_val = - local_row[COL_INDEX(Sha256RoundCols, schedule_helper.intermed_8[j][k])]; - SHA256INNER_WRITE_ROUND( - next_row, schedule_helper.intermed_12[j][k], intermed_8_val - ); - } - } - } - - if (i == SHA256_ROWS_PER_BLOCK - 2) { - - generate_carry_ae(local_row, next_row); - generate_intermed_4(local_row, next_row); - } - - if (i <= 2) { - generate_intermed_12(local_row, next_row); - } - } -} - -__device__ void generate_default_row(RowSlice row_slice) { - Encoder row_idx_encoder(18, 2, false); - RowSlice row_idx_flags = row_slice.slice_from(COL_INDEX(Sha256RoundCols, flags.row_idx)); - row_idx_encoder.write_flag_pt(row_idx_flags, 17); - - for (int i = 0; i < SHA256_ROUNDS_PER_ROW; i++) { - uint32_t a_word = SHA256_H[3 - i]; - uint32_t e_word = SHA256_H[7 - i]; - - for (int bit = 0; bit < SHA256_WORD_BITS; bit++) { - SHA256INNER_WRITE_ROUND(row_slice, work_vars.a[i][bit], Fp((a_word >> bit) & 1)); - SHA256INNER_WRITE_ROUND(row_slice, work_vars.e[i][bit], Fp((e_word >> bit) & 1)); - } - -#pragma unroll - for (int j = 0; j < SHA256_WORD_U16S; j++) { - SHA256INNER_WRITE_ROUND( - row_slice, work_vars.carry_a[i][j], SHA256_INVALID_CARRY_A[i][j] - ); - SHA256INNER_WRITE_ROUND( - row_slice, work_vars.carry_e[i][j], SHA256_INVALID_CARRY_E[i][j] - ); - } - } -} - -__device__ void generate_missing_cells(Fp *trace_chunk, size_t trace_height) { - RowSlice row15(trace_chunk + 15, trace_height); - RowSlice row16(trace_chunk + 16, trace_height); - RowSlice row17(trace_chunk + 17, trace_height); - - generate_intermed_12(row15, row16); - generate_intermed_12(row16, row17); - generate_intermed_4(row16, row17); -} - -__global__ void sha256_second_pass_dependencies( - Fp *inner_trace_start, - size_t trace_height, - size_t total_sha256_blocks -) { - uint32_t sha256_block_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (sha256_block_idx >= total_sha256_blocks) { - return; - } - - Fp *block_start = inner_trace_start + (sha256_block_idx * SHA256_ROWS_PER_BLOCK); - generate_missing_cells(block_start, trace_height); -} diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/utils.cuh b/crates/circuits/sha256-air/cuda/include/sha256-air/utils.cuh deleted file mode 100644 index d93f75a1d0..0000000000 --- a/crates/circuits/sha256-air/cuda/include/sha256-air/utils.cuh +++ /dev/null @@ -1,177 +0,0 @@ -#pragma once - -#include "primitives/constants.h" -#include "primitives/utils.cuh" -#include "system/memory/offline_checker.cuh" - -using namespace riscv; -using namespace sha256; - -__device__ __host__ inline uint32_t get_sha256_num_blocks(uint32_t len) { - uint32_t bit_len = len * 8; - uint32_t padded_bit_len = bit_len + 1 + 64; - return (padded_bit_len + 511) >> 9; -} - -struct Sha256VmRecordHeader { - uint32_t from_pc; - uint32_t timestamp; - uint32_t rd_ptr; - uint32_t rs1_ptr; - uint32_t rs2_ptr; - uint32_t dst_ptr; - uint32_t src_ptr; - uint32_t len; - MemoryReadAuxRecord register_reads_aux[SHA256_REGISTER_READS]; - MemoryWriteBytesAuxRecord write_aux; -}; - -struct Sha256VmRecordMut { - Sha256VmRecordHeader *header; - uint8_t *input; - MemoryReadAuxRecord *read_aux; - - __device__ __host__ __forceinline__ static uint32_t next_multiple_of( - uint32_t value, - uint32_t alignment - ) { - return ((value + alignment - 1) / alignment) * alignment; - } - - __device__ __host__ __forceinline__ Sha256VmRecordMut(uint8_t *record_buf) { - // Use memcpy for safe unaligned access instead of reinterpret_cast - header = reinterpret_cast(record_buf); - - uint32_t offset = sizeof(Sha256VmRecordHeader); - - input = record_buf + offset; - uint32_t num_blocks = get_sha256_num_blocks(header->len); - uint32_t input_size = num_blocks * SHA256_BLOCK_U8S; - - offset += input_size; - offset = next_multiple_of(offset, alignof(MemoryReadAuxRecord)); - - read_aux = reinterpret_cast(record_buf + offset); - } -}; - -__device__ static constexpr uint32_t SHA256_K[64] = { - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -}; - -__device__ static constexpr uint32_t SHA256_H[8] = { - 0x6a09e667, - 0xbb67ae85, - 0x3c6ef372, - 0xa54ff53a, - 0x510e527f, - 0x9b05688c, - 0x1f83d9ab, - 0x5be0cd19, -}; - -__device__ static constexpr uint32_t SHA256_INVALID_CARRY_A[SHA256_ROUNDS_PER_ROW] - [SHA256_WORD_U16S] = { - {1230919683, 1162494304}, - {266373122, 1282901987}, - {1519718403, 1008990871}, - {923381762, 330807052}, -}; - -__device__ static constexpr uint32_t SHA256_INVALID_CARRY_E[SHA256_ROUNDS_PER_ROW] - [SHA256_WORD_U16S] = { - {204933122, 1994683449}, - {443873282, 1544639095}, - {719953922, 1888246508}, - {194580482, 1075725211}, -}; - -__device__ inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); } - -__device__ inline uint32_t maj(uint32_t x, uint32_t y, uint32_t z) { - return (x & y) ^ (x & z) ^ (y & z); -} - -__device__ inline uint32_t big_sig0(uint32_t x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); } - -__device__ inline uint32_t big_sig1(uint32_t x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); } - -__device__ inline uint32_t small_sig0(uint32_t x) { return rotr(x, 7) ^ rotr(x, 18) ^ (x >> 3); } - -__device__ inline uint32_t small_sig1(uint32_t x) { return rotr(x, 17) ^ rotr(x, 19) ^ (x >> 10); } - -enum class PaddingFlags : uint32_t { - NotConsidered = 0, - NotPadding = 1, - FirstPadding0 = 2, - FirstPadding1 = 3, - FirstPadding2 = 4, - FirstPadding3 = 5, - FirstPadding4 = 6, - FirstPadding5 = 7, - FirstPadding6 = 8, - FirstPadding7 = 9, - FirstPadding8 = 10, - FirstPadding9 = 11, - FirstPadding10 = 12, - FirstPadding11 = 13, - FirstPadding12 = 14, - FirstPadding13 = 15, - FirstPadding14 = 16, - FirstPadding15 = 17, - FirstPadding0_LastRow = 18, - FirstPadding1_LastRow = 19, - FirstPadding2_LastRow = 20, - FirstPadding3_LastRow = 21, - FirstPadding4_LastRow = 22, - FirstPadding5_LastRow = 23, - FirstPadding6_LastRow = 24, - FirstPadding7_LastRow = 25, - EntirePaddingLastRow = 26, - EntirePadding = 27, - COUNT = 28, -}; - -__device__ inline MemoryReadAuxRecord *get_read_aux_record( - const Sha256VmRecordMut *record, - uint32_t block_idx, - uint32_t read_row_idx -) { - return &record->read_aux[block_idx * SHA256_NUM_READ_ROWS + read_row_idx]; -} - -__device__ inline uint16_t u32_to_u16_limb(uint32_t value, uint32_t limb_idx) { - return ((uint16_t *)&value)[limb_idx]; -} - -__device__ inline void sha256_pad_input( - const uint8_t *input, - uint32_t len, - uint8_t *padded_output, - uint32_t num_blocks -) { -#pragma unroll - for (uint32_t i = 0; i < len; i++) { - padded_output[i] = input[i]; - } - - padded_output[len] = 0x80; - - uint32_t total_len = num_blocks * SHA256_BLOCK_U8S; - for (uint32_t i = len + 1; i < total_len - 8; i++) { - padded_output[i] = 0; - } - - uint64_t bit_len = static_cast(len) * 8; -#pragma unroll - for (int i = 0; i < 8; i++) { - padded_output[total_len - 8 + i] = static_cast(bit_len >> (8 * (7 - i))); - } -} \ No newline at end of file diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs deleted file mode 100644 index b27af6ffa9..0000000000 --- a/crates/circuits/sha256-air/src/air.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::{array, borrow::Borrow, cmp::max, iter::once}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, - encoder::Encoder, - utils::{not, select}, - SubAir, -}; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_air::{AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, -}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, SHA256_H, - SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use crate::{constraint_word_addition, u32_into_u16s}; - -/// Expects the message to be padded to a multiple of 512 bits -#[derive(Clone, Debug)] -pub struct Sha256Air { - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - pub row_idx_encoder: Encoder, - /// Internal bus for self-interactions in this AIR. - bus: PermutationCheckBus, -} - -impl Sha256Air { - pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { - Self { - bitwise_lookup_bus, - row_idx_encoder: Encoder::new(18, 2, false), - bus: PermutationCheckBus::new(self_bus_idx), - } - } -} - -impl BaseAir for Sha256Air { - fn width(&self) -> usize { - max( - Sha256RoundCols::::width(), - Sha256DigestCols::::width(), - ) - } -} - -impl SubAir for Sha256Air { - /// The start column for the sub-air to use - type AirContext<'a> - = usize - where - Self: 'a, - AB: 'a, - ::Var: 'a, - ::Expr: 'a; - - fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) - where - ::Var: 'a, - ::Expr: 'a, - { - self.eval_row(builder, start_col); - self.eval_transitions(builder, start_col); - } -} - -impl Sha256Air { - /// Implements the single row constraints (i.e. imposes constraints only on local) - /// Implements some sanity constraints on the row index, flags, and work variables - fn eval_row(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - - // Doesn't matter which column struct we use here as we are only interested in the common - // columns - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - let flags = &local_cols.flags; - builder.assert_bool(flags.is_round_row); - builder.assert_bool(flags.is_first_4_rows); - builder.assert_bool(flags.is_digest_row); - builder.assert_bool(flags.is_round_row + flags.is_digest_row); - builder.assert_bool(flags.is_last_block); - - self.row_idx_encoder - .eval(builder, &local_cols.flags.row_idx); - builder.assert_one( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=17), - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=3), - flags.is_first_4_rows, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=15), - flags.is_round_row, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[16]), - flags.is_digest_row, - ); - // If padding row we want the row_idx to be 17 - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[17]), - flags.is_padding_row(), - ); - - // Constrain a, e, being composed of bits: we make sure a and e are always in the same place - // in the trace matrix Note: this has to be true for every row, even padding rows - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_BITS { - builder.assert_bool(local_cols.hash.a[i][j]); - builder.assert_bool(local_cols.hash.e[i][j]); - } - } - } - - /// Implements constraints for a digest row that ensure proper state transitions between blocks - /// This validates that: - /// The work variables are correctly initialized for the next message block - /// For the last message block, the initial state matches SHA256_H constants - fn eval_digest_row( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256DigestCols, - ) { - // Check that if this is the last row of a message or an inpadding row, the hash should be - // the [SHA256_H] - for i in 0..SHA256_ROUNDS_PER_ROW { - let a = next.hash.a[i].map(|x| x.into()); - let e = next.hash.e[i].map(|x| x.into()); - for j in 0..SHA256_WORD_U16S { - let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); - let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); - - // If it is a padding row or the last row of a message, the `hash` should be the - // [SHA256_H] - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - a_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], - ), - ); - - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - e_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], - ), - ); - } - } - - // Check if last row of a non-last block, the `hash` should be equal to the final hash of - // the current block - for i in 0..SHA256_ROUNDS_PER_ROW { - let prev_a = next.hash.a[i].map(|x| x.into()); - let prev_e = next.hash.e[i].map(|x| x.into()); - let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into()); - - let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into()); - for j in 0..SHA256_WORD_U8S { - let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); - let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_a_limb, cur_a[j].clone()); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_e_limb, cur_e[j].clone()); - } - } - - // Assert that the previous hash + work vars == final hash. - // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` - // where addition is done modulo 2^32 - for i in 0..SHA256_HASH_WORDS { - let mut carry = AB::Expr::ZERO; - for j in 0..SHA256_WORD_U16S { - let work_var_limb = if i < SHA256_ROUNDS_PER_ROW { - compose::( - &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16], - 1, - ) - } else { - compose::( - &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16], - 1, - ) - }; - let final_hash_limb = - compose::(&next.final_hash[i][j * 2..(j + 1) * 2], 8); - - carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) - * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb); - builder - .when(next.flags.is_digest_row) - .assert_bool(carry.clone()); - } - // constrain the final hash limbs two at a time since we can do two checks per - // interaction - for chunk in next.final_hash[i].chunks(2) { - self.bitwise_lookup_bus - .send_range(chunk[0], chunk[1]) - .eval(builder, next.flags.is_digest_row); - } - } - } - - fn eval_transitions(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - let next = main.row_slice(1); - - // Doesn't matter what column structs we use here - let local_cols: &Sha256RoundCols = - local[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - let next_cols: &Sha256RoundCols = - next[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - - let local_is_padding_row = local_cols.flags.is_padding_row(); - // Note that there will always be a padding row in the trace since the unpadded height is a - // multiple of 17. So the next row is padding iff the current block is the last - // block in the trace. - let next_is_padding_row = next_cols.flags.is_padding_row(); - - // We check that the very last block has `is_last_block` set to true, which guarantees that - // there is at least one complete message. If other digest rows have `is_last_block` set to - // true, then the trace will be interpreted as containing multiple messages. - builder - .when(next_is_padding_row.clone()) - .when(local_cols.flags.is_digest_row) - .assert_one(local_cols.flags.is_last_block); - // If we are in a round row, the next row cannot be a padding row - builder - .when(local_cols.flags.is_round_row) - .assert_zero(next_is_padding_row.clone()); - // The first row must be a round row - builder - .when_first_row() - .assert_one(local_cols.flags.is_round_row); - // If we are in a padding row, the next row must also be a padding row - builder - .when_transition() - .when(local_is_padding_row.clone()) - .assert_one(next_is_padding_row.clone()); - // If we are in a digest row, the next row cannot be a digest row - builder - .when(local_cols.flags.is_digest_row) - .assert_zero(next_cols.flags.is_digest_row); - // Constrain how much the row index changes by - // round->round: 1 - // round->digest: 1 - // digest->round: -16 - // digest->padding: 1 - // padding->padding: 0 - // Other transitions are not allowed by the above constraints - let delta = local_cols.flags.is_round_row * AB::Expr::ONE - + local_cols.flags.is_digest_row - * next_cols.flags.is_round_row - * AB::Expr::from_canonical_u32(16) - * AB::Expr::NEG_ONE - + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; - - let local_row_idx = self.row_idx_encoder.flag_with_val::( - &local_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - let next_row_idx = self.row_idx_encoder.flag_with_val::( - &next_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - - builder - .when_transition() - .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); - builder.when_first_row().assert_zero(local_row_idx); - - // Constrain the global block index - // We set the global block index to 0 for padding rows - // Starting with 1 so it is not the same as the padding rows - - // Global block index is 1 on first row - builder - .when_first_row() - .assert_one(local_cols.flags.global_block_idx); - - // Global block index is constant on all rows in a block - builder.when(local_cols.flags.is_round_row).assert_eq( - local_cols.flags.global_block_idx, - next_cols.flags.global_block_idx, - ); - // Global block index increases by 1 between blocks - builder - .when_transition() - .when(local_cols.flags.is_digest_row) - .when(next_cols.flags.is_round_row) - .assert_eq( - local_cols.flags.global_block_idx + AB::Expr::ONE, - next_cols.flags.global_block_idx, - ); - // Global block index is 0 on padding rows - builder - .when(local_is_padding_row.clone()) - .assert_zero(local_cols.flags.global_block_idx); - - // Constrain the local block index - // We set the local block index to 0 for padding rows - - // Local block index is constant on all rows in a block - // and its value on padding rows is equal to its value on the first block - builder.when(not(local_cols.flags.is_digest_row)).assert_eq( - local_cols.flags.local_block_idx, - next_cols.flags.local_block_idx, - ); - // Local block index increases by 1 between blocks in the same message - builder - .when(local_cols.flags.is_digest_row) - .when(not(local_cols.flags.is_last_block)) - .assert_eq( - local_cols.flags.local_block_idx + AB::Expr::ONE, - next_cols.flags.local_block_idx, - ); - // Local block index is 0 on padding rows - // Combined with the above, this means that the local block index is 0 in the first block - builder - .when(local_cols.flags.is_digest_row) - .when(local_cols.flags.is_last_block) - .assert_zero(next_cols.flags.local_block_idx); - - self.eval_message_schedule::(builder, local_cols, next_cols); - self.eval_work_vars::(builder, local_cols, next_cols); - let next_cols: &Sha256DigestCols = - next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_digest_row(builder, local_cols, next_cols); - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_prev_hash::(builder, local_cols, next_is_padding_row); - } - - /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` - /// Note: the constraining is done by interactions with the chip itself on every digest row - fn eval_prev_hash( - &self, - builder: &mut AB, - local: &Sha256DigestCols, - is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, - * not the last block of the message */ - ) { - // Constrain that next block's `prev_hash` is equal to the current block's `hash` - let composed_hash: [[::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] = - array::from_fn(|i| { - let hash_bits = if i < SHA256_ROUNDS_PER_ROW { - local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into()) - } else { - local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into()) - }; - array::from_fn(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) - }); - // Need to handle the case if this is the very last block of the trace matrix - let next_global_block_idx = select( - is_last_block_of_trace, - AB::Expr::ONE, - local.flags.global_block_idx + AB::Expr::ONE, - ); - // The following interactions constrain certain values from block to block - self.bus.send( - builder, - composed_hash - .into_iter() - .flatten() - .chain(once(next_global_block_idx)), - local.flags.is_digest_row, - ); - - self.bus.receive( - builder, - local - .prev_hash - .into_iter() - .flatten() - .map(|x| x.into()) - .chain(once(local.flags.global_block_idx.into())), - local.flags.is_digest_row, - ); - } - - /// Constrain the message schedule additions for `next` row - /// Note: For every addition we need to constrain the following for each of [SHA256_WORD_U16S] - /// limbs sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + - /// carry_w[t][i-1] - carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_message_schedule( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx - let w = [local.message_schedule.w, next.message_schedule.w].concat(); - - // Constrain `w_3` for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW - 1 { - // here we constrain the w_3 of the i_th word of the next row - // w_3 of next is w[i+4-3] = w[i+1] - let w_3 = w[i + 1].map(|x| x.into()); - let expected_w_3 = next.schedule_helper.w_3[i]; - for j in 0..SHA256_WORD_U16S { - let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); - builder - .when(local.flags.is_round_row) - .assert_eq(w_3_limb, expected_w_3[j].into()); - } - } - - // Constrain intermed for `next` row - // We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for - // other rows Other rows should put the needed value in intermed_12 to make the - // below summation constraint hold - let is_row_3_14 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 3..=14); - // We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other - // rows - let is_row_2_13 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 2..=13); - for i in 0..SHA256_ROUNDS_PER_ROW { - // w_idx - let w_idx = w[i].map(|x| x.into()); - // sig_0(w_{idx+1}) - let sig_w = small_sig0_field::(&w[i + 1]); - for j in 0..SHA256_WORD_U16S { - let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); - let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); - - // We would like to constrain this only on rows 0..16, but we can't do a conditional - // check because the degree is already 3. So we must fill in - // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on - // these rows. - builder.when_transition().assert_eq( - next.schedule_helper.intermed_4[i][j], - w_idx_limb + sig_w_limb, - ); - - builder.when(is_row_2_13.clone()).assert_eq( - next.schedule_helper.intermed_8[i][j], - local.schedule_helper.intermed_4[i][j], - ); - - builder.when(is_row_3_14.clone()).assert_eq( - next.schedule_helper.intermed_12[i][j], - local.schedule_helper.intermed_8[i][j], - ); - } - } - - // Constrain the message schedule additions for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW { - // Note, here by w_{t} we mean the i_th word of the `next` row - // w_{t-7} - let w_7 = if i < 3 { - local.schedule_helper.w_3[i].map(|x| x.into()) - } else { - let w_3 = w[i - 3].map(|x| x.into()); - array::from_fn(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) - }; - // sig_0(w_{t-15}) + w_{t-16} - let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into()); - - let carries = array::from_fn(|j| { - next.message_schedule.carry_or_buffer[i][j * 2] - + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1] - }); - - // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` - // We would like to constrain this only on rows 4..16, but we can't do a conditional - // check because the degree of sum is already 3 So we must fill in - // `intermed_12` with dummy values on rows 0..3 and 15 and 16 to ensure the constraint - // holds on rows 0..4 and 16. Note that the dummy value goes in the previous - // row to make the current row's constraint hold. - constraint_word_addition( - // Note: here we can't do a conditional check because the degree of sum is already - // 3 - &mut builder.when_transition(), - &[&small_sig1_field::(&w[i + 2])], - &[&w_7, &intermed_16], - &w[i + 4], - &carries, - ); - - for j in 0..SHA256_WORD_U16S { - // When on rows 4..16 message schedule carries should be 0 or 1 - let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows; - builder - .when(is_row_4_15.clone()) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]); - builder - .when(is_row_4_15) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]); - } - // Constrain w being composed of bits - for j in 0..SHA256_WORD_BITS { - builder - .when(next.flags.is_round_row) - .assert_bool(next.message_schedule.w[i][j]); - } - } - } - - /// Constrain the work vars on `next` row according to the sha256 documentation - /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_work_vars( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - let a = [local.work_vars.a, next.work_vars.a].concat(); - let e = [local.work_vars.e, next.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_U16S { - // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in - // [0, 2^8) is enough to prevent overflow and ensure the soundness - // of the addition we want to check - self.bitwise_lookup_bus - .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j]) - .eval(builder, local.flags.is_round_row); - } - - let w_limbs = array::from_fn(|j| { - compose::(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1) - * next.flags.is_round_row - }); - let k_limbs = array::from_fn(|j| { - self.row_idx_encoder.flag_with_val::( - &next.flags.row_idx, - &(0..16) - .map(|rw_idx| { - ( - rw_idx, - u32_into_u16s(SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i])[j] - as usize, - ) - }) - .collect::>(), - ) - }); - - // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_a` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), // sig_1 of previous `e` - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - &big_sig0_field::(&a[i + 3]), // sig_0 of previous `a` - &maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]), /* Maj of previous - * a, b, c */ - ], - &[&w_limbs, &k_limbs], // K and W - &a[i + 4], // new `a` - &next.work_vars.carry_a[i], // carries of addition - ); - - // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_e` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &a[i].map(|x| x.into()), // previous `d` - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), /* sig_1 of previous - * `e` */ - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - ], - &[&w_limbs, &k_limbs], // K and W - &e[i + 4], // new `e` - &next.work_vars.carry_e[i], // carries of addition - ); - } - } -} diff --git a/crates/circuits/sha256-air/src/columns.rs b/crates/circuits/sha256-air/src/columns.rs deleted file mode 100644 index 1c735394c3..0000000000 --- a/crates/circuits/sha256-air/src/columns.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit_primitives::{utils::not, AlignedBorrow}; -use openvm_stark_backend::p3_field::FieldAlgebra; - -use super::{ - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// In each SHA256 block: -/// - First 16 rows use Sha256RoundCols -/// - Final row uses Sha256DigestCols -/// -/// Note that for soundness, we require that there is always a padding row after the last digest row -/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus -/// not a power of 2. -/// -/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields: -/// - flags -/// - work_vars/hash (same type, different name) -/// - schedule_helper -/// -/// This design allows for: -/// 1. Common constraints to work on either struct type by accessing these shared fields -/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional -/// constraints -/// -/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256RoundCols { - pub flags: Sha256FlagsCols, - /// Stores the current state of the working variables - pub work_vars: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - pub message_schedule: Sha256MessageScheduleCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256DigestCols { - pub flags: Sha256FlagsCols, - /// Will serve as previous hash values for the next block. - /// - on non-last blocks, this is the final hash of the current block - /// - on last blocks, this is the initial state constants, SHA256_H. - /// The work variables constraints are applied on all rows, so `carry_a` and `carry_e` - /// must be filled in with dummy values to ensure these constraints hold. - pub hash: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - /// The actual final hash values of the given block - /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block - pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS], - /// The final hash of the previous block - /// Note: will be constrained using interactions with the chip itself - pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageScheduleCols { - /// The message schedule words as 32-bit integers - /// The first 16 words will be the message data - pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used - /// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as - /// individual bits - pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256WorkVarsCols { - /// `a` and `e` after each iteration as 32-bits - pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// The carry's used for addition during each iteration when computing `a` and `e` - pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -/// These are the columns that are used to help with the message schedule additions -/// Note: these need to be correctly assigned for every row even on padding rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageHelperCols { - /// The following are used to move data forward to constrain the message schedule additions - /// The value of `w` (message schedule word) from 3 rounds ago - /// In general, `w_i` means `w` from `i` rounds ago - pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1], - /// Here intermediate(i) = w_i + sig_0(w_{i+1}) - /// Intermed_t represents the intermediate t rounds ago - /// This is needed to constrain the message schedule, since we can only constrain on two rows - /// at a time - pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256FlagsCols { - /// A flag that indicates if the current row is among the first 16 rows of a block. - pub is_round_row: T, - /// A flag that indicates if the current row is among the first 4 rows of a block. - pub is_first_4_rows: T, - /// A flag that indicates if the current row is the last (17th) row of a block. - pub is_digest_row: T, - // A flag that indicates if the current row is the last block of the message. - // This flag is only used in digest rows. - pub is_last_block: T, - /// We will encode the row index [0..17) using 5 cells - pub row_idx: [T; SHA256_ROW_VAR_CNT], - /// The index of the current block in the trace starting at 1. - /// Set to 0 on padding rows. - pub global_block_idx: T, - /// The index of the current block in the current message starting at 0. - /// Resets after every message. - /// Set to 0 on padding rows. - pub local_block_idx: T, -} - -impl> Sha256FlagsCols { - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_not_padding_row(&self) -> O { - self.is_round_row + self.is_digest_row - } - - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_padding_row(&self) -> O - where - O: FieldAlgebra, - { - not(self.is_not_padding_row()) - } -} diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha256-air/src/lib.rs deleted file mode 100644 index 48bdaee5f9..0000000000 --- a/crates/circuits/sha256-air/src/lib.rs +++ /dev/null @@ -1,15 +0,0 @@ -//! Implementation of the SHA256 compression function without padding -//! This this AIR doesn't constrain any of the message padding - -mod air; -mod columns; -mod trace; -mod utils; - -pub use air::*; -pub use columns::*; -pub use trace::*; -pub use utils::*; - -#[cfg(test)] -mod tests; diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs deleted file mode 100644 index 7ad0229185..0000000000 --- a/crates/circuits/sha256-air/src/tests.rs +++ /dev/null @@ -1,163 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use openvm_circuit::arch::{ - instructions::riscv::RV32_CELL_BITS, - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, - }, - SubAir, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::{cpu::CpuBackend, types::AirProvingContext}, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - utils::disable_debug_builder, - verifier::VerificationError, - AirRef, Chip, -}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; - -use crate::{ - Sha256Air, Sha256DigestCols, Sha256FillerHelper, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_WIDTH, SHA256_WORD_U8S, -}; - -// A wrapper AIR purely for testing purposes -#[derive(Clone, Debug)] -pub struct Sha256TestAir { - pub sub_air: Sha256Air, -} - -impl BaseAirWithPublicValues for Sha256TestAir {} -impl PartitionedBaseAir for Sha256TestAir {} -impl BaseAir for Sha256TestAir { - fn width(&self) -> usize { - >::width(&self.sub_air) - } -} - -impl Air for Sha256TestAir { - fn eval(&self, builder: &mut AB) { - self.sub_air.eval(builder, 0); - } -} - -const SELF_BUS_IDX: BusIndex = 28; -type F = BabyBear; -type RecordType = Vec<([u8; SHA256_BLOCK_U8S], bool)>; - -// A wrapper Chip purely for testing purposes -pub struct Sha256TestChip { - pub step: Sha256FillerHelper, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, -} - -impl Chip> for Sha256TestChip -where - Val: PrimeField32, -{ - fn generate_proving_ctx(&self, records: RecordType) -> AirProvingContext> { - let trace = crate::generate_trace::>( - &self.step, - self.bitwise_lookup_chip.as_ref(), - SHA256_WIDTH, - records, - ); - AirProvingContext::simple_no_pis(Arc::new(trace)) - } -} - -#[allow(clippy::type_complexity)] -fn create_air_with_air_ctx() -> ( - (AirRef, AirProvingContext>), - ( - BitwiseOperationLookupAir, - SharedBitwiseOperationLookupChip, - ), -) -where - Val: PrimeField32, -{ - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|i| { - ( - array::from_fn(|_| rng.gen::()), - rng.gen::() || i == len - 1, - ) - }) - .collect(); - - let air = Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }; - let chip = Sha256TestChip { - step: Sha256FillerHelper::new(), - bitwise_lookup_chip: bitwise_chip.clone(), - }; - let air_ctx = chip.generate_proving_ctx(random_records); - - ((Arc::new(air), air_ctx), (bitwise_chip.air, bitwise_chip)) -} - -#[test] -fn rand_sha256_test() { - let tester = VmChipTestBuilder::default(); - let (air_ctx, bitwise) = create_air_with_air_ctx(); - let tester = tester - .build() - .load_air_proving_ctx(air_ctx) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn negative_sha256_test_bad_final_hash() { - let tester = VmChipTestBuilder::default(); - let ((air, mut air_ctx), bitwise) = create_air_with_air_ctx(); - - // Set the final_hash to all zeros - let modify_trace = |trace: &mut RowMajorMatrix| { - trace.row_chunks_exact_mut(1).for_each(|row| { - let mut row_slice = row.row_slice(0).to_vec(); - let cols: &mut Sha256DigestCols = row_slice[..SHA256_DIGEST_WIDTH].borrow_mut(); - if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { - for i in 0..SHA256_HASH_WORDS { - for j in 0..SHA256_WORD_U8S { - cols.final_hash[i][j] = F::ZERO; - } - } - row.values.copy_from_slice(&row_slice); - } - }); - }; - - // Modify the air_ctx - let trace = Option::take(&mut air_ctx.common_main).unwrap(); - let mut trace = Arc::into_inner(trace).unwrap(); - modify_trace(&mut trace); - air_ctx.common_main = Some(Arc::new(trace)); - - disable_debug_builder(); - let tester = tester - .build() - .load_air_proving_ctx((air, air_ctx)) - .load_periphery(bitwise) - .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); -} diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs deleted file mode 100644 index 8cbaebbc55..0000000000 --- a/crates/circuits/sha256-air/src/trace.rs +++ /dev/null @@ -1,558 +0,0 @@ -use std::{array, borrow::BorrowMut, ops::Range}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder, - utils::next_power_of_two_or_zero, -}; -use openvm_stark_backend::{ - p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, -}; -use sha2::{compress256, digest::generic_array::GenericArray}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array, - maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, -}; -use crate::{ - big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A, - SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// A helper struct for the SHA256 trace generation. -/// Also, separates the inner AIR from the trace generation. -pub struct Sha256FillerHelper { - pub row_idx_encoder: Encoder, -} - -impl Default for Sha256FillerHelper { - fn default() -> Self { - Self::new() - } -} - -/// The trace generation of SHA256 should be done in two passes. -/// The first pass should do `get_block_trace` for every block and generate the invalid rows through -/// `get_default_row` The second pass should go through all the blocks and call -/// `generate_missing_cells` -impl Sha256FillerHelper { - pub fn new() -> Self { - Self { - row_idx_encoder: Encoder::new(18, 2, false), - } - } - /// This function takes the input_message (padding not handled), the previous hash, - /// and returns the new hash after processing the block input - pub fn get_block_hash( - prev_hash: &[u32; SHA256_HASH_WORDS], - input: [u8; SHA256_BLOCK_U8S], - ) -> [u32; SHA256_HASH_WORDS] { - let mut new_hash = *prev_hash; - let input_array = [GenericArray::from(input)]; - compress256(&mut new_hash, &input_array); - new_hash - } - - /// This function takes a 512-bit chunk of the input message (padding not handled), the previous - /// hash, a flag indicating if it's the last block, the global block index, the local block - /// index, and the buffer values that will be put in rows 0..4. - /// Will populate the given `trace` with the trace of the block, where the width of the trace is - /// `trace_width` and the starting column for the `Sha256Air` is `trace_start_col`. - /// **Note**: this function only generates some of the required trace. Another pass is required, - /// refer to [`Self::generate_missing_cells`] for details. - #[allow(clippy::too_many_arguments)] - pub fn generate_block_trace( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - prev_hash: &[u32; SHA256_HASH_WORDS], - is_last_block: bool, - global_block_idx: u32, - local_block_idx: u32, - ) { - #[cfg(debug_assertions)] - { - assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); - assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - if local_block_idx == 0 { - assert!(*prev_hash == SHA256_H); - } - } - let get_range = |start: usize, len: usize| -> Range { start..start + len }; - let mut message_schedule = [0u32; 64]; - message_schedule[..input.len()].copy_from_slice(input); - let mut work_vars = *prev_hash; - for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { - // doing the 64 rounds in 16 rows - if i < 16 { - let cols: &mut Sha256RoundCols = - row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - cols.flags.is_round_row = F::ONE; - cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; - cols.flags.is_digest_row = F::ZERO; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - - // W_idx = M_idx - if i < 4 { - for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = - u32_into_bits_field::(input[i * SHA256_ROUNDS_PER_ROW + j]); - } - } - // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} - else { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let nums: [u32; 4] = [ - small_sig1(message_schedule[idx - 2]), - message_schedule[idx - 7], - small_sig0(message_schedule[idx - 15]), - message_schedule[idx - 16], - ]; - let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = u32_into_bits_field::(w); - - let nums_limbs = nums.map(u32_into_u16s); - let w_limbs = u32_into_u16s(w); - - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); - if k > 0 { - sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2] - + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1]) - .as_canonical_u32(); - } - let carry = (sum - w_limbs[k]) >> 16; - cols.message_schedule.carry_or_buffer[j][k * 2] = - F::from_canonical_u32(carry & 1); - cols.message_schedule.carry_or_buffer[j][k * 2 + 1] = - F::from_canonical_u32(carry >> 1); - } - // update the message schedule - message_schedule[idx] = w; - } - } - // fill in the work variables - for j in 0..SHA256_ROUNDS_PER_ROW { - // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx - let t1 = [ - work_vars[7], - big_sig1(work_vars[4]), - ch(work_vars[4], work_vars[5], work_vars[6]), - SHA256_K[i * SHA256_ROUNDS_PER_ROW + j], - limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())), - ]; - let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // t2 = SIG0(a) + maj(a, b, c) - let t2 = [ - big_sig0(work_vars[0]), - maj(work_vars[0], work_vars[1], work_vars[2]), - ]; - - let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // e = d + t1 - let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = u32_into_bits_field::(e); - let e_limbs = u32_into_u16s(e); - // a = t1 + t2 - let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = u32_into_bits_field::(a); - let a_limbs = u32_into_u16s(a); - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - - let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k]; - let mut a_limb = t1_limb + t2_limb; - if k > 0 { - a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); - e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32(); - } - let carry_a = (a_limb - a_limbs[k]) >> 16; - let carry_e = (e_limb - e_limbs[k]) >> 16; - cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a); - cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e); - bitwise_lookup_chip.request_range(carry_a, carry_e); - } - - // update working variables - work_vars[7] = work_vars[6]; - work_vars[6] = work_vars[5]; - work_vars[5] = work_vars[4]; - work_vars[4] = e; - work_vars[3] = work_vars[2]; - work_vars[2] = work_vars[1]; - work_vars[1] = work_vars[0]; - work_vars[0] = a; - } - - // filling w_3 and intermed_4 here and the rest later - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_u16s(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3])); - cols.schedule_helper.intermed_4[j] = - array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); - if j < SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[idx - 3]; - cols.schedule_helper.w_3[j] = - u32_into_u16s(w_3).map(F::from_canonical_u32); - } - } - } - } - // generate the digest row - else { - let cols: &mut Sha256DigestCols = - row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); - for j in 0..SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32); - } - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ONE; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - let final_hash: [u32; SHA256_HASH_WORDS] = - array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| final_hash[i].to_le_bytes()); - // need to ensure final hash limbs are bytes, in order for - // prev_hash[i] + work_vars[i] == final_hash[i] - // to be constrained correctly - for word in final_hash_limbs.iter() { - for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32); - } - } - cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j])) - }); - cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32)); - let hash = if is_last_block { - SHA256_H.map(u32_into_bits_field::) - } else { - cols.final_hash - .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8))) - .map(u32_into_bits_field::) - }; - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - } - } - - for i in 0..SHA256_ROWS_PER_BLOCK - 1 { - let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; - let (local, next) = rows.split_at_mut(trace_width); - let local_cols: &mut Sha256RoundCols = - local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - let next_cols: &mut Sha256RoundCols = - next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - next_cols.schedule_helper.intermed_8[j] = - local_cols.schedule_helper.intermed_4[j]; - if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) { - next_cols.schedule_helper.intermed_12[j] = - local_cols.schedule_helper.intermed_8[j]; - } - } - } - if i == SHA256_ROWS_PER_BLOCK - 2 { - // `next` is a digest row. - // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and - // `e` hold. - Self::generate_carry_ae(local_cols, next_cols); - // Fill in row 16's `intermed_4` with dummy values so the message schedule - // constraints holds on that row - Self::generate_intermed_4(local_cols, next_cols); - } - if i <= 2 { - // i is in 0..3. - // Fill in `local.intermed_12` with dummy values so the message schedule constraints - // hold on rows 1..4. - Self::generate_intermed_12(local_cols, next_cols); - } - } - } - - /// This function will fill in the cells that we couldn't do during the first pass. - /// This function should be called only after `generate_block_trace` was called for all blocks - /// And [`Self::generate_default_row`] is called for all invalid rows - /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` - /// and the starting column for the `Sha256Air` is `trace_start_col`. - /// Note: `trace` needs to be the rows 1..17 of a block and the first row of the next block - pub fn generate_missing_cells( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - ) { - // Here row_17 = next blocks row 0 - let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width]; - let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width); - let (row_16, row_17) = row_16_17.split_at_mut(trace_width); - let cols_15: &mut Sha256RoundCols = - row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_16: &mut Sha256RoundCols = - row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_17: &mut Sha256RoundCols = - row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - // Fill in row 15's `intermed_12` with dummy values so the message schedule constraints - // holds on row 16 - Self::generate_intermed_12(cols_15, cols_16); - // Fill in row 16's `intermed_12` with dummy values so the message schedule constraints - // holds on the next block's row 0 - Self::generate_intermed_12(cols_16, cols_17); - // Fill in row 0's `intermed_4` with dummy values so the message schedule constraints holds - // on that row - Self::generate_intermed_4(cols_16, cols_17); - } - - /// Fills the `cols` as a padding row - /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row( - self: &Sha256FillerHelper, - cols: &mut Sha256RoundCols, - ) { - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32); - - let hash = SHA256_H.map(u32_into_bits_field::); - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - - cols.work_vars.carry_a = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j])) - }); - cols.work_vars.carry_e = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j])) - }); - } - - /// The following functions do the calculations in native field since they will be called on - /// padding rows which can overflow and we need to make sure it matches the AIR constraints - /// Puts the correct carrys in the `next_row`, the resulting carrys can be out of bound - fn generate_carry_ae( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat(); - let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let cur_a = a[i + 4]; - let sig_a = big_sig0_field::(&a[i + 3]); - let maj_abc = maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]); - let d = a[i]; - let cur_e = e[i + 4]; - let sig_e = big_sig1_field::(&e[i + 3]); - let ch_efg = ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]); - let h = e[i]; - - let t1 = [h, sig_e, ch_efg]; - let t2 = [sig_a, maj_abc]; - for j in 0..SHA256_WORD_U16S { - let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let d_limb = compose::(&d[j * 16..(j + 1) * 16], 1); - let cur_a_limb = compose::(&cur_a[j * 16..(j + 1) * 16], 1); - let cur_e_limb = compose::(&cur_e[j * 16..(j + 1) * 16], 1); - let sum = d_limb - + t1_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_e[i][j - 1] - } - - cur_e_limb; - let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); - - let sum = t1_limb_sum - + t2_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_a[i][j - 1] - } - - cur_a_limb; - let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); - next_cols.work_vars.carry_e[i][j] = carry_e; - next_cols.work_vars.carry_a[i][j] = carry_a; - } - } - } - - /// Puts the correct intermed_4 in the `next_row` - fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } - } - - /// Puts the needed intermed_12 in the `local_row` - fn generate_intermed_12( - local_cols: &mut Sha256RoundCols, - next_cols: &Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - // sig_1(w_{t-2}) - let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| { - compose::(&small_sig1_field::(&w[i + 2])[j * 16..(j + 1) * 16], 1) - }); - // w_{t-7} - let w_7 = if i < 3 { - local_cols.schedule_helper.w_3[i] - } else { - w_limbs[i - 3] - }; - // w_t - let w_cur = w_limbs[i + 4]; - for j in 0..SHA256_WORD_U16S { - let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2] - + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1]; - let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] - + if j > 0 { - next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2] - + F::from_canonical_u32(2) - * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1] - } else { - F::ZERO - }; - local_cols.schedule_helper.intermed_12[i][j] = -sum; - } - } - } -} - -/// Generates a trace for a standalone SHA256 computation (currently only used for testing) -/// `records` consists of pairs of `(input_block, is_last_block)`. -pub fn generate_trace( - step: &Sha256FillerHelper, - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - width: usize, - records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -) -> RowMajorMatrix { - let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; - let height = next_power_of_two_or_zero(non_padded_height); - let mut values = F::zero_vec(height * width); - - struct BlockContext { - prev_hash: [u32; 8], - local_block_idx: u32, - global_block_idx: u32, - input: [u8; SHA256_BLOCK_U8S], - is_last_block: bool, - } - let mut block_ctx: Vec = Vec::with_capacity(records.len()); - let mut prev_hash = SHA256_H; - let mut local_block_idx = 0; - let mut global_block_idx = 1; - for (input, is_last_block) in records { - block_ctx.push(BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - }); - global_block_idx += 1; - if is_last_block { - local_block_idx = 0; - prev_hash = SHA256_H; - } else { - local_block_idx += 1; - prev_hash = Sha256FillerHelper::get_block_hash(&prev_hash, input); - } - } - // first pass - values - .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(block_ctx) - .for_each(|(block, ctx)| { - let BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - } = ctx; - let input_words = array::from_fn(|i| { - limbs_into_u32::(array::from_fn(|j| { - input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 - })) - }); - step.generate_block_trace( - block, - width, - 0, - &input_words, - bitwise_lookup_chip, - &prev_hash, - is_last_block, - global_block_idx, - local_block_idx, - ); - }); - // second pass: padding rows - values[width * non_padded_height..] - .par_chunks_mut(width) - .for_each(|row| { - let cols: &mut Sha256RoundCols = row.borrow_mut(); - step.generate_default_row(cols); - }); - // second pass: non-padding rows - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - step.generate_missing_cells(chunk, width, 0); - }); - RowMajorMatrix::new(values, width) -} diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs deleted file mode 100644 index ba598f2604..0000000000 --- a/crates/circuits/sha256-air/src/utils.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::array; - -pub use openvm_circuit_primitives::utils::compose; -use openvm_circuit_primitives::{ - encoder::Encoder, - utils::{not, select}, -}; -use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; - -use super::{Sha256DigestCols, Sha256RoundCols}; - -// ==== Do not change these constants! ==== -/// Number of bits in a SHA256 word -pub const SHA256_WORD_BITS: usize = 32; -/// Number of 16-bit limbs in a SHA256 word -pub const SHA256_WORD_U16S: usize = SHA256_WORD_BITS / 16; -/// Number of 8-bit limbs in a SHA256 word -pub const SHA256_WORD_U8S: usize = SHA256_WORD_BITS / 8; -/// Number of words in a SHA256 block -pub const SHA256_BLOCK_WORDS: usize = 16; -/// Number of cells in a SHA256 block -pub const SHA256_BLOCK_U8S: usize = SHA256_BLOCK_WORDS * SHA256_WORD_U8S; -/// Number of bits in a SHA256 block -pub const SHA256_BLOCK_BITS: usize = SHA256_BLOCK_WORDS * SHA256_WORD_BITS; -/// Number of rows per block -pub const SHA256_ROWS_PER_BLOCK: usize = 17; -/// Number of rounds per row -pub const SHA256_ROUNDS_PER_ROW: usize = 4; -/// Number of words in a SHA256 hash -pub const SHA256_HASH_WORDS: usize = 8; -/// Number of vars needed to encode the row index with [Encoder] -pub const SHA256_ROW_VAR_CNT: usize = 5; -/// Width of the Sha256RoundCols -pub const SHA256_ROUND_WIDTH: usize = Sha256RoundCols::::width(); -/// Width of the Sha256DigestCols -pub const SHA256_DIGEST_WIDTH: usize = Sha256DigestCols::::width(); -/// Size of the buffer of the first 4 rows of a block (each row's size) -pub const SHA256_BUFFER_SIZE: usize = SHA256_ROUNDS_PER_ROW * SHA256_WORD_U16S * 2; -/// Width of the Sha256Cols -pub const SHA256_WIDTH: usize = if SHA256_ROUND_WIDTH > SHA256_DIGEST_WIDTH { - SHA256_ROUND_WIDTH -} else { - SHA256_DIGEST_WIDTH -}; -/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows -/// To optimize the trace generation of invalid rows, we have those values precomputed here -pub(crate) const SHA256_INVALID_CARRY_A: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [1230919683, 1162494304], - [266373122, 1282901987], - [1519718403, 1008990871], - [923381762, 330807052], -]; -pub(crate) const SHA256_INVALID_CARRY_E: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [204933122, 1994683449], - [443873282, 1544639095], - [719953922, 1888246508], - [194580482, 1075725211], -]; -/// SHA256 constant K's -pub const SHA256_K: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -]; - -/// SHA256 initial hash values -pub const SHA256_H: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Returns the number of blocks required to hash a message of length `len` -pub fn get_sha256_num_blocks(len: u32) -> u32 { - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] - ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS) as u32 -} - -/// Convert a u32 into a list of bits in little endian then convert each bit into a field element -pub fn u32_into_bits_field(num: u32) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| F::from_bool((num >> i) & 1 == 1)) -} - -/// Convert a u32 into a an array of 2 16-bit limbs in little endian -pub fn u32_into_u16s(num: u32) -> [u32; 2] { - [num & 0xffff, num >> 16] -} - -/// Convert a list of limbs in little endian into a u32 -pub fn limbs_into_u32(limbs: [u32; NUM_LIMBS]) -> u32 { - let limb_bits = 32 / NUM_LIMBS; - limbs - .iter() - .rev() - .fold(0, |acc, &limb| (acc << limb_bits) | limb) -} - -/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn rotr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| bits[(i + n) % SHA256_WORD_BITS].clone().into()) -} - -/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn shr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - if i + n < SHA256_WORD_BITS { - bits[i + n].clone().into() - } else { - F::ZERO - } - }) -} - -/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean -#[inline] -pub(crate) fn xor_bit( - x: impl Into, - y: impl Into, - z: impl Into, -) -> F { - let (x, y, z) = (x.into(), y.into(), z.into()); - (x.clone() * y.clone() * z.clone()) - + (x.clone() * not::(y.clone()) * not::(z.clone())) - + (not::(x.clone()) * y.clone() * not::(z.clone())) - + (not::(x) * not::(y) * z) -} - -/// Computes x ^ y ^ z, where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn xor( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Choose function from SHA256 -#[inline] -pub fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ ((!x) & z) -} - -/// Computes Ch(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn ch_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Majority function from SHA256 -pub fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -/// Computes Maj(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn maj_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - let (x, y, z) = ( - x[i].clone().into(), - y[i].clone().into(), - z[i].clone().into(), - ); - x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() - F::TWO * x * y * z - }) -} - -/// Big sigma_0 function from SHA256 -pub fn big_sig0(x: u32) -> u32 { - x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) -} - -/// Computes BigSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) -} - -/// Big sigma_1 function from SHA256 -pub fn big_sig1(x: u32) -> u32 { - x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) -} - -/// Computes BigSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) -} - -/// Small sigma_0 function from SHA256 -pub fn small_sig0(x: u32) -> u32 { - x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) -} - -/// Computes SmallSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) -} - -/// Small sigma_1 function from SHA256 -pub fn small_sig1(x: u32) -> u32 { - x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) -} - -/// Computes SmallSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) -} - -/// Wrapper of `get_flag_pt` to get the flag pointer as an array -pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> [u32; N] { - encoder.get_flag_pt(flag_idx).try_into().unwrap() -} - -/// Constrain the addition of [SHA256_WORD_BITS] bit words in 16-bit limbs -/// It takes in the terms some in bits some in 16-bit limbs, -/// the expected sum in bits and the carries -pub fn constraint_word_addition( - builder: &mut AB, - terms_bits: &[&[impl Into + Clone; SHA256_WORD_BITS]], - terms_limb: &[&[impl Into + Clone; SHA256_WORD_U16S]], - expected_sum: &[impl Into + Clone; SHA256_WORD_BITS], - carries: &[impl Into + Clone; SHA256_WORD_U16S], -) { - for i in 0..SHA256_WORD_U16S { - let mut limb_sum = if i == 0 { - AB::Expr::ZERO - } else { - carries[i - 1].clone().into() - }; - for term in terms_bits { - limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); - } - for term in terms_limb { - limb_sum += term[i].clone().into(); - } - let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) - + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); - builder.assert_eq(limb_sum, expected_sum_limb); - } -} diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index bb9e0e4f89..a498372a36 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,8 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-native-circuit = { workspace = true } @@ -86,7 +86,7 @@ tco = [ "openvm-circuit/tco", "openvm-rv32im-circuit/tco", "openvm-native-circuit/tco", - "openvm-sha256-circuit/tco", + "openvm-sha2-circuit/tco", "openvm-keccak256-circuit/tco", "openvm-bigint-circuit/tco", "openvm-algebra-circuit/tco", @@ -97,7 +97,7 @@ aot = [ "openvm-circuit/aot", "openvm-rv32im-circuit/aot", "openvm-native-circuit/aot", - "openvm-sha256-circuit/aot", + "openvm-sha2-circuit/aot", "openvm-keccak256-circuit/aot", "openvm-bigint-circuit/aot", "openvm-algebra-circuit/aot", @@ -129,7 +129,7 @@ cuda = [ "openvm-bigint-circuit/cuda", "openvm-ecc-circuit/cuda", "openvm-keccak256-circuit/cuda", - "openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", "openvm-pairing-circuit/cuda", "openvm-native-circuit/cuda", "openvm-rv32im-circuit/cuda", diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 9699b1ed34..d3c4ec7297 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -33,8 +33,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2CpuProverExt, Sha2Executor}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -55,7 +55,7 @@ cfg_if::cfg_if! { use openvm_keccak256_circuit::Keccak256GpuProverExt; use openvm_native_circuit::NativeGpuProverExt; use openvm_rv32im_circuit::Rv32ImGpuProverExt; - use openvm_sha256_circuit::Sha256GpuProverExt; + use openvm_sha2_circuit::Sha2GpuProverExt; pub use SdkVmGpuBuilder as SdkVmBuilder; } else { pub use SdkVmCpuBuilder as SdkVmBuilder; @@ -81,7 +81,7 @@ pub struct SdkVmConfig { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -118,7 +118,7 @@ impl SdkVmConfig { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) - .sha256(Default::default()) + .sha2(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ bn_config.modulus.clone(), @@ -199,8 +199,8 @@ impl TranspilerConfig for SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } - if self.sha256.is_some() { - transpiler = transpiler.with_extension(Sha256TranspilerExtension); + if self.sha2.is_some() { + transpiler = transpiler.with_extension(Sha2TranspilerExtension); } if self.native.is_some() { transpiler = transpiler.with_extension(LongFormTranspilerExtension); @@ -269,7 +269,7 @@ impl SdkVmConfig { let rv32i = config.rv32i.map(|_| Rv32I); let io = config.io.map(|_| Rv32Io); let keccak = config.keccak.map(|_| Keccak256); - let sha256 = config.sha256.map(|_| Sha256); + let sha2 = config.sha2.map(|_| Sha2); let native = config.native.map(|_| Native); let castf = config.castf.map(|_| CastFExtension); let rv32m = config.rv32m; @@ -284,7 +284,7 @@ impl SdkVmConfig { rv32i, io, keccak, - sha256, + sha2, native, castf, rv32m, @@ -315,8 +315,8 @@ pub struct SdkVmConfigInner { pub io: Option, #[extension(executor = "Keccak256Executor")] pub keccak: Option, - #[extension(executor = "Sha256Executor")] - pub sha256: Option, + #[extension(executor = "Sha2Executor")] + pub sha2: Option, #[extension(executor = "NativeExecutor")] pub native: Option, #[extension(executor = "CastFExtensionExecutor")] @@ -392,8 +392,8 @@ where if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256CpuProverExt, keccak, inventory)?; } - if let Some(sha256) = &config.sha256 { - VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha256, inventory)?; + if let Some(sha2) = &config.sha2 { + VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha2, inventory)?; } if let Some(native) = &config.native { VmProverExtension::::extend_prover(&NativeCpuProverExt, native, inventory)?; @@ -456,8 +456,8 @@ impl VmBuilder for SdkVmGpuBuilder { if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256GpuProverExt, keccak, inventory)?; } - if let Some(sha256) = &config.sha256 { - VmProverExtension::::extend_prover(&Sha256GpuProverExt, sha256, inventory)?; + if let Some(sha2) = &config.sha2 { + VmProverExtension::::extend_prover(&Sha2GpuProverExt, sha2, inventory)?; } if let Some(native) = &config.native { VmProverExtension::::extend_prover(&NativeGpuProverExt, native, inventory)?; @@ -566,8 +566,8 @@ impl From for UnitStruct { } } -impl From for UnitStruct { - fn from(_: Sha256) -> Self { +impl From for UnitStruct { + fn from(_: Sha2) -> Self { UnitStruct {} } } @@ -592,7 +592,7 @@ struct SdkVmConfigWithDefaultDeser { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -611,7 +611,7 @@ impl From for SdkVmConfig { rv32i: config.rv32i, io: config.io, keccak: config.keccak, - sha256: config.sha256, + sha2: config.sha2, native: config.native, castf: config.castf, rv32m: config.rv32m, diff --git a/crates/sdk/src/config/openvm_standard.toml b/crates/sdk/src/config/openvm_standard.toml index f1f9267191..eac2c0a81b 100644 --- a/crates/sdk/src/config/openvm_standard.toml +++ b/crates/sdk/src/config/openvm_standard.toml @@ -3,7 +3,7 @@ [app_vm_config.io] [app_vm_config.keccak] -[app_vm_config.sha256] +[app_vm_config.sha2] [app_vm_config.bigint] [app_vm_config.modular] diff --git a/crates/toolchain/openvm/src/io/mod.rs b/crates/toolchain/openvm/src/io/mod.rs index eb00a9d3cd..05f073073e 100644 --- a/crates/toolchain/openvm/src/io/mod.rs +++ b/crates/toolchain/openvm/src/io/mod.rs @@ -6,7 +6,7 @@ use core::alloc::Layout; use core::fmt::Write; #[cfg(target_os = "zkvm")] -use openvm_rv32im_guest::{hint_buffer_u32, hint_input, hint_store_u32}; +use openvm_rv32im_guest::{hint_buffer_chunked, hint_input, hint_store_u32}; use serde::de::DeserializeOwned; #[cfg(not(target_os = "zkvm"))] @@ -83,7 +83,7 @@ pub(crate) fn read_vec_by_len(len: usize) -> Vec { // The heap-embedded-alloc uses linked list allocator, which has a minimum alignment of // `sizeof(usize) * 2 = 8` on 32-bit architectures: https://github.com/rust-osdev/linked-list-allocator/blob/b5caf3271259ddda60927752fa26527e0ccd2d56/src/hole.rs#L429 let mut bytes = Vec::with_capacity(capacity); - hint_buffer_u32!(bytes.as_mut_ptr(), num_words); + hint_buffer_chunked(bytes.as_mut_ptr(), num_words as usize); // SAFETY: We populate a `Vec` by hintstore-ing `num_words` 4 byte words. We set the // length to `len` and don't care about the extra `capacity - len` bytes stored. unsafe { diff --git a/crates/toolchain/openvm/src/io/read.rs b/crates/toolchain/openvm/src/io/read.rs index 39b2166e39..f2eff6cfa5 100644 --- a/crates/toolchain/openvm/src/io/read.rs +++ b/crates/toolchain/openvm/src/io/read.rs @@ -2,7 +2,7 @@ use core::mem::MaybeUninit; use openvm_platform::WORD_SIZE; #[cfg(target_os = "zkvm")] -use openvm_rv32im_guest::hint_buffer_u32; +use openvm_rv32im_guest::hint_buffer_chunked; use super::hint_store_word; use crate::serde::WordRead; @@ -31,7 +31,7 @@ impl WordRead for Reader { let num_words = words.len(); if let Some(new_remaining) = self.bytes_remaining.checked_sub(num_words * WORD_SIZE) { #[cfg(target_os = "zkvm")] - hint_buffer_u32!(words.as_mut_ptr(), words.len()); + hint_buffer_chunked(words.as_mut_ptr() as *mut u8, words.len()); #[cfg(not(target_os = "zkvm"))] { for w in words.iter_mut() { @@ -51,7 +51,7 @@ impl WordRead for Reader { } let mut num_padded_bytes = bytes.len(); #[cfg(target_os = "zkvm")] - hint_buffer_u32!(bytes as *mut [u8] as *mut u32, num_padded_bytes / WORD_SIZE); + hint_buffer_chunked(bytes.as_mut_ptr(), num_padded_bytes / WORD_SIZE); #[cfg(not(target_os = "zkvm"))] { let mut words = bytes.chunks_exact_mut(WORD_SIZE); diff --git a/crates/toolchain/openvm/src/pal_abi.rs b/crates/toolchain/openvm/src/pal_abi.rs index 0ab3d3f386..3797998bb8 100644 --- a/crates/toolchain/openvm/src/pal_abi.rs +++ b/crates/toolchain/openvm/src/pal_abi.rs @@ -5,7 +5,7 @@ /// system operations in the same way: there is no operating system and even the standard /// library should be directly handled with intrinsics. use openvm_platform::{fileno::*, memory::sys_alloc_aligned, rust_rt::terminate, WORD_SIZE}; -use openvm_rv32im_guest::{hint_buffer_u32, hint_random, raw_print_str_from_bytes}; +use openvm_rv32im_guest::{hint_buffer_chunked, hint_random, raw_print_str_from_bytes}; const DIGEST_WORDS: usize = 8; @@ -73,7 +73,7 @@ pub unsafe extern "C" fn sys_sha_buffer( #[no_mangle] pub unsafe extern "C" fn sys_rand(recv_buf: *mut u32, words: usize) { hint_random(words); - hint_buffer_u32!(recv_buf, words); + hint_buffer_chunked(recv_buf as *mut u8, words); } /// # Safety diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 234dfbd5b9..0b7a13bfe9 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -38,6 +38,12 @@ pub enum ExecutionError { DisabledOperation { pc: u32, opcode: VmOpcode }, #[error("at pc = {pc}")] HintOutOfBounds { pc: u32 }, + #[error("at pc {pc}, hint buffer num_words {num_words} exceeds MAX_HINT_BUFFER_WORDS {max_hint_buffer_words}")] + HintBufferTooLarge { + pc: u32, + num_words: u32, + max_hint_buffer_words: u32, + }, #[error("at pc {pc}, tried to publish into index {public_value_index} when num_public_values = {num_public_values}")] PublicValueIndexOutOfBounds { pc: u32, diff --git a/crates/vm/src/arch/testing/cpu.rs b/crates/vm/src/arch/testing/cpu.rs index 105962bc11..bea9317927 100644 --- a/crates/vm/src/arch/testing/cpu.rs +++ b/crates/vm/src/arch/testing/cpu.rs @@ -533,6 +533,25 @@ where self } + pub fn load_periphery_and_prank_trace( + mut self, + (air, chip): (A, C), + modify_trace: P, + ) -> Self + where + A: AnyRap + 'static, + C: Chip<(), CpuBackend>, + P: Fn(&mut RowMajorMatrix>), + { + let mut ctx = chip.generate_proving_ctx(()); + let trace: Arc>> = Option::take(&mut ctx.common_main).unwrap(); + let mut trace = Arc::into_inner(trace).unwrap(); + modify_trace(&mut trace); + ctx.common_main = Some(Arc::new(trace)); + self.air_ctxs.push((Arc::new(air), ctx)); + self + } + /// Given a function to produce an engine from the max trace height, /// runs a simple test on that engine pub fn test E>( diff --git a/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx b/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx index edfac8d542..17edf81d98 100644 --- a/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx +++ b/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx @@ -3,7 +3,7 @@ OpenVM ships with a set of pre-built extensions maintained by the OpenVM team. Below, we highlight six of these extensions designed to accelerate common arithmetic and cryptographic operations that are notoriously expensive to execute. Some of these extensions have corresponding guest libraries which provide convenient, high-level interfaces for your guest program to interact with the extension. - [`openvm-keccak-guest`](/book/acceleration-using-extensions/keccak) - Keccak256 hash function. See the [Keccak256 guest library](/book/guest-libraries/keccak256) for usage details. -- [`openvm-sha256-guest`](/book/acceleration-using-extensions/sha-256) - SHA-256 hash function. See the [SHA2 guest library](/book/guest-libraries/sha2) for usage details. +- [`openvm-sha2-guest`](/book/acceleration-using-extensions/sha-2) - SHA-2 family of hash functions. See the [SHA-2 guest library](/book/guest-libraries/sha2) for usage details. - [`openvm-bigint-guest`](/book/acceleration-using-extensions/big-integer) - Big integer arithmetic for 256-bit signed and unsigned integers. See the [Ruint guest library](/book/guest-libraries/ruint) for using accelerated 256-bit integer ops in rust. - [`openvm-algebra-guest`](/book/acceleration-using-extensions/algebra) - Modular arithmetic and complex field extensions. - [`openvm-ecc-guest`](/book/acceleration-using-extensions/elliptic-curve-cryptography) - Elliptic curve cryptography. See the [K256](/book/guest-libraries/k256) and [P256](/book/guest-libraries/p256) guest libraries for using this extension over the respective curves. @@ -43,9 +43,7 @@ range_tuple_checker_sizes = [256, 8192] [app_vm_config.io] [app_vm_config.keccak] - -[app_vm_config.sha256] - +[app_vm_config.sha2] [app_vm_config.native] [app_vm_config.bigint] diff --git a/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx b/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx new file mode 100644 index 0000000000..30daa44f6a --- /dev/null +++ b/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx @@ -0,0 +1,45 @@ +# SHA-2 + +The SHA-2 extension guest provides functions that are meant to be linked to other external libraries. The external libraries can use these functions as a hook for SHA-2 intrinsics. This is enabled only when the target is `zkvm`. We support the SHA-256 and SHA-512 hash functions. + +We provide the following functions to compute the SHA-2 compression function. +- `zkvm_shaXXX_impl(state: *const u8, input: *const u8, output: *mut u8)` where `XXX` is `256` or `512`. These functions have `C` ABI. They take in a pointer to the current hasher state (`state`), a pointer to the next block of the message (`input`), and a pointer where the new hasher state will be written (`output`). `state` is expected to be a pointer to 8 little-endian words, even though its type is `*const u8`. For Sha256 that means it is a pointer to `[u32; 8]`, and for Sha512 it's `[u64; 8]`. + +In the external library, you can do something like the following: + +```rust +extern "C" { + fn zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8); +} + +fn sha256(input: &[u8]) -> [u8; 32] { + #[cfg(target_os = "zkvm")] + { + let mut state = [0u32; 8]; + let padded_input = add_padding(input); + padded_input + .chunks_exact(32) + .for_each(|input_block| { + unsafe { + zkvm_sha256_impl(state.as_ptr() as *const u8, input_block.as_ptr(), output.as_mut_ptr() as *mut u8); + } + }) + state + .map(|word| word.to_be_bytes()) + .collect::>() + .try_into() + .unwrap() + } + #[cfg(not(target_os = "zkvm"))] { + // Regular SHA-256 implementation + } +} +``` + +### Config parameters + +For the guest program to build successfully add the following to your `.toml` file: + +```toml +[app_vm_config.sha2] +``` diff --git a/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx b/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx deleted file mode 100644 index a4a7f46261..0000000000 --- a/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx +++ /dev/null @@ -1,35 +0,0 @@ -# SHA-256 - -The SHA-256 extension guest provides a function that is meant to be linked to other external libraries. The external libraries can use this function as a hook for the SHA-256 intrinsic. This is enabled only when the target is `zkvm`. - -- `zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8)`: This function has `C` ABI. It takes in a pointer to the input, the length of the input, and a pointer to the output buffer. - -In the external library, you can do the following: - -```rust -extern "C" { - fn zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8); -} - -fn sha256(input: &[u8]) -> [u8; 32] { - #[cfg(target_os = "zkvm")] - { - let mut output = [0u8; 32]; - unsafe { - zkvm_sha256_impl(input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8); - } - output - } - #[cfg(not(target_os = "zkvm"))] { - // Regular SHA-256 implementation - } -} -``` - -### Config parameters - -For the guest program to build successfully add the following to your `.toml` file: - -```toml -[app_vm_config.sha256] -``` diff --git a/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx b/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx index 8a69d4231b..4a9fbc7842 100644 --- a/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx +++ b/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx @@ -38,7 +38,7 @@ Note that to use `Sdk::riscv32()` or `Sdk::standard()` the `app_vm_config` field ``` -Observe that this standard `openvm.toml` also enables normal Rust and `openvm::io` functions (via the `rv32i`, `rv32m`, and `io` extensions). `keccak` and `sha256` enable intrinsic instructions for the [Keccak](/book/acceleration-using-extensions/keccak) and [SHA-256](/book/acceleration-using-extensions/sha-256) hashes respectively, and `bigint` supports [Big Integer](/book/acceleration-using-extensions/big-integer) operations. +Observe that this standard `openvm.toml` also enables normal Rust and `openvm::io` functions (via the `rv32i`, `rv32m`, and `io` extensions). `keccak` and `sha2` enable intrinsic instructions for the [Keccak](/book/acceleration-using-extensions/keccak) and [SHA-2](/book/acceleration-using-extensions/sha-2) hashes respectively, and `bigint` supports [Big Integer](/book/acceleration-using-extensions/big-integer) operations. [Modular](/book/acceleration-using-extensions/algebra) operations for the BN254, Secp256k1 (i.e. K256), Secp256r1 (i.e. P256), and BLS12-381 curves' scalar and coordinate field moduli are also supported, as well as [Complex Field Extension](/book/acceleration-using-extensions/algebra#complex-field-extension) operations over the BN254 and BLS12-381 coordinate fields. [Elliptic Curve Cryptography](/book/acceleration-using-extensions/elliptic-curve-cryptography) operations are also supported for the BN254, Secp256k1, Secp256r1, and BLS12-381 curves, and [Elliptic Curve Pairing](/book/acceleration-using-extensions/elliptic-curve-pairing) checks are supported for the BN254 and BLS12-381 curves. diff --git a/docs/vocs/docs/pages/book/getting-started/introduction.mdx b/docs/vocs/docs/pages/book/getting-started/introduction.mdx index 6dc47c4d71..8c01bdab13 100644 --- a/docs/vocs/docs/pages/book/getting-started/introduction.mdx +++ b/docs/vocs/docs/pages/book/getting-started/introduction.mdx @@ -12,7 +12,7 @@ OpenVM is an open-source zero-knowledge virtual machine (zkVM) framework focused - RISC-V support via RV32IM - A native field arithmetic extension for proof recursion and aggregation - - The Keccak-256 and SHA2-256 hash functions + - The Keccak-256, SHA-256 and SHA-512 hash functions - Int256 arithmetic - Modular arithmetic over arbitrary fields - Elliptic curve operations, including multi-scalar multiplication and ECDSA signature verification, including for the secp256k1 and secp256r1 curves diff --git a/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx b/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx index 0a626fdd82..7710c35b02 100644 --- a/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx +++ b/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx @@ -3,30 +3,33 @@ The OpenVM SHA-2 guest library provides access to a set of accelerated SHA-2 family hash functions. Currently, it supports the following: - SHA-256 +- SHA-512 +- SHA-384 -## SHA-256 - -Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on SHA-256. - -For SHA-256, the SHA2 guest library provides two functions for use in your guest code: - -- `sha256(input: &[u8]) -> [u8; 32]`: Computes the SHA-256 hash of the input data and returns it as an array of 32 bytes. -- `set_sha256(input: &[u8], output: &mut [u8; 32])`: Sets the output to the SHA-256 hash of the input data into the provided output buffer. - -See the full example [here](https://github.com/openvm-org/openvm/blob/main/examples/sha256/src/main.rs). +Refer to [the FIPS publication](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on the SHA-2 family of hash functions. +The SHA-2 guest library provides the `Sha256`, `Sha512`, and `Sha384` structs for use in your guest code. +These structs mimic the identically-named structs found in the `sha2` crate, in that they implement the [`sha2::Digest` trait](https://docs.rs/sha2/latest/sha2/trait.Digest.html). +Importantly, the `sha2::Digest` trait has the following methods, which can be used to incrementally hash a stream of bytes. +```Rust +fn update(&mut self, data: impl AsRef<[u8]>); +fn finalize(self) -> GenericArray; +``` ### Example +The following example can be run as guest code or host code (i.e. run in the zkvm or natively). +To run with guest code, use `cargo openvm run`, and for host code use `cargo run`. +The implementations for `Sha256`, `Sha512`, and `Sha384` fall back to using `sha2` when running as host code. ```rust -// [!include ~/snippets/examples/sha256/src/main.rs:imports] -// [!include ~/snippets/examples/sha256/src/main.rs:main] +[!include ~/snippets/examples/sha2/src/main.rs:imports] +[!include ~/snippets/examples/sha2/src/main.rs:main] ``` -To be able to import the `sha256` function, add the following to your `Cargo.toml` file: +To be able to import the `shaXXX` functions and run the example, add the following to your `Cargo.toml` file: ```toml -openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git", tag = "v1.4.2" } -hex = { version = "0.4.3" } +openvm = { git = "https://github.com/openvm-org/openvm.git" } +openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git" } ``` ### Config parameters @@ -34,4 +37,4 @@ hex = { version = "0.4.3" } For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx b/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx index 6417d8a234..e0f8e9ac89 100644 --- a/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx +++ b/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx @@ -107,7 +107,7 @@ The chips that fall into these categories are: | FriReducedOpeningChip | – | – | Case 1. | | NativePoseidon2Chip | – | – | Case 1. | | Rv32HintStoreChip | – | – | Case 1. | -| Sha256VmChip | – | – | Case 1. | +| Sha2VmChip | – | – | Case 1. | The PhantomChip satisfies the condition because `1 < 3`. diff --git a/docs/vocs/docs/pages/specs/openvm/isa.mdx b/docs/vocs/docs/pages/specs/openvm/isa.mdx index 14b71fa05c..d6358592d0 100644 --- a/docs/vocs/docs/pages/specs/openvm/isa.mdx +++ b/docs/vocs/docs/pages/specs/openvm/isa.mdx @@ -6,7 +6,7 @@ This specification describes the overall architecture and default VM extensions - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Native](#native-extension): An extension supporting native field arithmetic for proof recursion and aggregation. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256 and SHA-512 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex @@ -35,6 +35,7 @@ OpenVM depends on the following parameters, some of which are fixed and some of | `addr_space_height` | The base 2 log of the number of writable address spaces supported. | Configurable, must satisfy `addr_space_height <= F::bits() - 2` | | `pointer_max_bits` | The maximum number of bits in a pointer. | Configurable, must satisfy `pointer_max_bits <= F::bits() - 2` | | `num_public_values` | The number of user public values. | Configurable. If continuation is enabled, it must equal `8` times a power of two(which is nonzero). | +| `MAX_HINT_BUFFER_WORDS_BITS` | The maximum number of bits for hint buffer word count. This determines `MAX_HINT_BUFFER_WORDS = 2^MAX_HINT_BUFFER_WORDS_BITS - 1` = 262,143 words (≈1MB), the maximum words per `HINT_BUFFER_RV32` instruction. | Fixed to 18. | We explain these parameters in subsequent sections. @@ -428,9 +429,11 @@ with user input-output. | Name | Operands | Description | | ---------------- | --------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | HINT_STOREW_RV32 | `_,b,_,1,2` | `[r32{0}(b):4]_2 = next 4 bytes from hint stream`. Only valid if next 4 values in hint stream are bytes. | -| HINT_BUFFER_RV32 | `a,b,_,1,2` | `[r32{0}(b):4 * l]_2 = next 4 * l bytes from hint stream` where `l = r32{0}(a)`. Only valid if next `4 * l` values in hint stream are bytes. Very important: `l` should not be 0. The pointer address `r32{0}(b)` does not need to be a multiple of `4`. | +| HINT_BUFFER_RV32 | `a,b,_,1,2` | `[r32{0}(b):4 * l]_2 = next 4 * l bytes from hint stream` where `l = r32{0}(a)`. Only valid if next `4 * l` values in hint stream are bytes. `l` must be non-zero and <= `MAX_HINT_BUFFER_WORDS` (262,143 words ≈ 1MB). The pointer address `r32{0}(b)` does not need to be a multiple of `4`. | | REVEAL_RV32 | `a,b,c,1,3,_,g` | Pseudo-instruction for `STOREW_RV32 a,b,c,1,3,_,g` writing to the user IO address space `3`. Only valid when continuations are enabled. | +> **Note:** The `MAX_HINT_BUFFER_WORDS` bound on `HINT_BUFFER_RV32` is enforced by both the executor and AIR constraints. The SDK's `hint_buffer_chunked` function automatically splits larger reads into multiple `HINT_BUFFER_RV32` instructions. + #### Phantom Sub-Instructions The RV32IM extension defines the following phantom sub-instructions. @@ -547,14 +550,15 @@ all memory cells are constrained to be bytes. | -------------- | ----------- | ----------------------------------------------------------------------------------------------------------------- | | KECCAK256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = keccak256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Performs memory accesses with block size `4`. | -### SHA2-256 Extension +### SHA-2 Extension -The SHA2-256 extension supports the SHA2-256 hash function. The extension operates on address spaces `1` and `2`, +The SHA-2 extension supports the SHA-256 and SHA-512 hash functions. The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. | Name | Operands | Description | | ----------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| SHA256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = sha256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `16` and writes with block size `32`. | +| SHA256_UPDATE_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = compress256([r32{0}(b):32]_2, [r32{0}(c):64]_2)`. where `compress256(state, input)` is the SHA-256 block compression function which returns the updated state. This instruction performs memory reads and writes with block size `4`. | +| SHA512_UPDATE_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = compress512([r32{0}(b):64]_2, [r32{0}(c):128]_2)`. where `compress512(state, input)` is the SHA-512 block compression function which returns the updated state. This instruction performs memory reads and writes with block size `4`. | ### BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx b/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx index fa44faebb0..02276a7e43 100644 --- a/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx +++ b/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx @@ -130,13 +130,14 @@ In the tables below, we provide the mapping between the `LocalOpcode` and `Phant | ------------- | ---------- | ------------- | | Keccak | `Rv32KeccakOpcode::KECCAK256` | KECCAK256_RV32 | -## SHA2-256 Extension +## SHA-2 Extension #### Instructions -| VM Extension | `LocalOpcode` | ISA Instruction | -| ------------- | ---------- | ------------- | -| SHA2-256 | `Rv32Sha256Opcode::SHA256` | SHA256_RV32 | +| VM Extension | `LocalOpcode` | ISA Instruction | +| ------------- | ------------------------ | ------------------------ | +| SHA-2 | `Rv32Sha2Opcode::SHA256` | SHA256_UPDATE_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA512` | SHA512_UPDATE_RV32 | ## BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx b/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx index ebbb17c5f4..6c5b44d347 100644 --- a/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx +++ b/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx @@ -5,7 +5,7 @@ The default VM extensions that support transpilation are: - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256 and SHA-512 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex field extensions. This extension respects the RISC-V memory format. - [Elliptic curve](#elliptic-curve-extension): An extension for elliptic curve operations over Weierstrass curves, including addition and doubling. This can be used to implement multi-scalar multiplication and ECDSA scalar multiplication. This extension respects the RISC-V memory format. @@ -85,11 +85,12 @@ implementation is here. But we use `funct3 = 111` because the native extension h | ----------- | --- | ----------- | ------ | ------ | ------------------------------------------- | | keccak256 | R | 0001011 | 100 | 0x0 | `[rd:32]_2 = keccak256([rs1..rs1 + rs2]_2)` | -## SHA2-256 Extension +## SHA-2 Extension -| RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | -| ----------- | --- | ----------- | ------ | ------ | ---------------------------------------- | -| sha256 | R | 0001011 | 100 | 0x1 | `[rd:32]_2 = sha256([rs1..rs1 + rs2]_2)` | +| RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | +| ------------- | --- | ----------- | ------ | ------ | -------------------------------------------------- | +| sha256_update | R | 0001011 | 100 | 0x1 | `[rd:32]_2 = compress256([rs1:32]_2, [rs2:64]_2)` | +| sha512_update | R | 0001011 | 100 | 0x2 | `[rd:64]_2 = compress512([rs1:64]_2, [rs2:128]_2)` | ## BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/transpiler.mdx b/docs/vocs/docs/pages/specs/reference/transpiler.mdx index ebb31770b5..6fc1c6b479 100644 --- a/docs/vocs/docs/pages/specs/reference/transpiler.mdx +++ b/docs/vocs/docs/pages/specs/reference/transpiler.mdx @@ -151,11 +151,12 @@ Each VM extension's behavior is specified below. | ----------- | -------------------------------------------------- | | keccak256 | KECCAK256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | -### SHA2-256 Extension +### SHA-2 Extension -| RISC-V Inst | OpenVM Instruction | -| ----------- | ----------------------------------------------- | -| sha256 | SHA256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| RISC-V Inst | OpenVM Instruction | +| ------------- | ------------------------------------------------------ | +| sha256_update | SHA256_UPDATE_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha512_update | SHA512_UPDATE_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### BigInt Extension diff --git a/examples/sha256/Cargo.toml b/examples/sha2/Cargo.toml similarity index 63% rename from examples/sha256/Cargo.toml rename to examples/sha2/Cargo.toml index 0b5a44bc3e..b093a850ae 100644 --- a/examples/sha256/Cargo.toml +++ b/examples/sha2/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "sha256-example" +name = "sha2-example" version = "0.0.0" edition = "2021" @@ -7,16 +7,14 @@ edition = "2021" members = [] [dependencies] -openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ - "std", -] } +openvm = { git = "https://github.com/openvm-org/openvm.git" } openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git" } -hex = { version = "0.4.3" } [features] default = [] +std = ["openvm/std"] # remove this if copying example outside of monorepo [patch."https://github.com/openvm-org/openvm.git"] openvm = { path = "../../crates/toolchain/openvm" } -openvm-sha2 = { path = "../../guest-libs/sha2" } +openvm-sha2 = { path = "../../guest-libs/sha2" } \ No newline at end of file diff --git a/examples/sha256/openvm.toml b/examples/sha2/openvm.toml similarity index 73% rename from examples/sha256/openvm.toml rename to examples/sha2/openvm.toml index 656bf52414..fb0cbe8cd4 100644 --- a/examples/sha256/openvm.toml +++ b/examples/sha2/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] \ No newline at end of file diff --git a/examples/sha2/src/main.rs b/examples/sha2/src/main.rs new file mode 100644 index 0000000000..c11221d33d --- /dev/null +++ b/examples/sha2/src/main.rs @@ -0,0 +1,41 @@ +// [!region imports] +#![cfg_attr(all(target_os = "zkvm", not(feature = "std")), no_main)] +#![cfg_attr(all(target_os = "zkvm", not(feature = "std")), no_std)] + +extern crate alloc; + +use alloc::{format, string::String}; + +use openvm_sha2::{Digest, Sha256, Sha384, Sha512}; + +openvm::entry!(main); +// [!endregion imports] + +// [!region main] +fn println(s: String) { + #[cfg(target_os = "zkvm")] + openvm::io::println(s); + #[cfg(not(target_os = "zkvm"))] + println!("{}", s); +} +pub fn main() { + let mut sha256 = Sha256::new(); + sha256.update(b"Hello, world!"); + sha256.update(b"some other input"); + let output = sha256.finalize(); + println(format!("SHA-256: {:?}", output)); + + let mut sha512 = Sha512::new(); + sha512.update(b"Hello, world!"); + sha512.update(b"some other input"); + let output = sha512.finalize(); + println(format!("SHA-512: {:?}", output)); + + let mut sha384 = Sha384::new(); + sha384.update(b"Hello, world!"); + sha384.update(b"some other input"); + let output = sha384.finalize(); + println(format!("SHA-384: {:?}", output)); +} + +// [!endregion main] diff --git a/examples/sha256/src/main.rs b/examples/sha256/src/main.rs deleted file mode 100644 index 670ac5c011..0000000000 --- a/examples/sha256/src/main.rs +++ /dev/null @@ -1,24 +0,0 @@ -// [!region imports] -use core::hint::black_box; - -use hex::FromHex; -use openvm as _; -use openvm_sha2::sha256; -// [!endregion imports] - -// [!region main] -pub fn main() { - let test_vectors = [( - "", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - )]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} -// [!endregion main] diff --git a/extensions/algebra/moduli-macros/src/lib.rs b/extensions/algebra/moduli-macros/src/lib.rs index 4ea8af0211..0266b7468e 100644 --- a/extensions/algebra/moduli-macros/src/lib.rs +++ b/extensions/algebra/moduli-macros/src/lib.rs @@ -875,15 +875,15 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { - use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_store_u32! and hint_buffer_u32! + use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_store_u32! and hint_buffer_chunked let is_square = core::mem::MaybeUninit::::uninit(); - let sqrt = core::mem::MaybeUninit::<#struct_name>::uninit(); + let mut sqrt = core::mem::MaybeUninit::<#struct_name>::uninit(); unsafe { #hint_sqrt_extern_func(self as *const #struct_name as usize); let is_square_ptr = is_square.as_ptr() as *const u32; openvm_rv32im_guest::hint_store_u32!(is_square_ptr); - openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4); + openvm_rv32im_guest::hint_buffer_chunked(sqrt.as_mut_ptr() as *mut u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4 as usize); let is_square = is_square.assume_init(); if is_square == 0 || is_square == 1 { Some((is_square == 1, sqrt.assume_init())) @@ -902,14 +902,14 @@ pub fn moduli_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { - use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_buffer_u32! + use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_buffer_chunked let mut non_qr_uninit = core::mem::MaybeUninit::::uninit(); let mut non_qr; unsafe { #hint_non_qr_extern_func(); - let ptr = non_qr_uninit.as_ptr() as *const u8; - openvm_rv32im_guest::hint_buffer_u32!(ptr, ::NUM_LIMBS / 4); + let ptr = non_qr_uninit.as_mut_ptr() as *mut u8; + openvm_rv32im_guest::hint_buffer_chunked(ptr, ::NUM_LIMBS / 4 as usize); non_qr = non_qr_uninit.assume_init(); } // ensure non_qr < modulus diff --git a/extensions/rv32im/circuit/cuda/src/hintstore.cu b/extensions/rv32im/circuit/cuda/src/hintstore.cu index ce09a22477..b4e3a1b607 100644 --- a/extensions/rv32im/circuit/cuda/src/hintstore.cu +++ b/extensions/rv32im/circuit/cuda/src/hintstore.cu @@ -6,6 +6,8 @@ using namespace riscv; using namespace program; +using hintstore::MAX_HINT_BUFFER_WORDS; +using hintstore::MAX_HINT_BUFFER_WORDS_BITS; template struct Rv32HintStoreCols { // common @@ -87,11 +89,25 @@ struct Rv32HintStore { COL_WRITE_ARRAY(row, Rv32HintStoreCols, mem_ptr_limbs, mem_ptr_limbs); if (local_idx == 0) { + // The overflow check for mem_ptr + num_words * 4 is not needed because + // 4 * MAX_HINT_BUFFER_WORDS < 2^pointer_max_bits guarantees no overflow + assert(MAX_HINT_BUFFER_WORDS_BITS + 2 < pointer_max_bits); + + // Range check for mem_ptr (using pointer_max_bits) uint32_t msl_rshift = (RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS; uint32_t msl_lshift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits; + + // Range check for num_words (using MAX_HINT_BUFFER_WORDS_BITS) + // These constraints only work for MAX_HINT_BUFFER_WORDS_BITS in [16, 23] + assert(MAX_HINT_BUFFER_WORDS_BITS >= 16 && MAX_HINT_BUFFER_WORDS_BITS <= 23); + + assert(record.num_words <= MAX_HINT_BUFFER_WORDS); + uint32_t rem_words_limb2_lshift = (RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS - MAX_HINT_BUFFER_WORDS_BITS; + + // Combined range check for mem_ptr and num_words bitwise_lookup.add_range( (record.mem_ptr >> msl_rshift) << msl_lshift, - (record.num_words >> msl_rshift) << msl_lshift + ((record.num_words >> 16) & 0xFF) << rem_words_limb2_lshift ); mem_helper.fill( row.slice_from(COL_INDEX(Rv32HintStoreCols, mem_ptr_aux_cols)), diff --git a/extensions/rv32im/circuit/src/hintstore/execution.rs b/extensions/rv32im/circuit/src/hintstore/execution.rs index 47e68e3084..631cdb79a1 100644 --- a/extensions/rv32im/circuit/src/hintstore/execution.rs +++ b/extensions/rv32im/circuit/src/hintstore/execution.rs @@ -14,6 +14,7 @@ use openvm_instructions::{ use openvm_rv32im_transpiler::{ Rv32HintStoreOpcode, Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW}, + MAX_HINT_BUFFER_WORDS, }; use openvm_stark_backend::p3_field::PrimeField32; @@ -172,6 +173,15 @@ unsafe fn execute_e12_impl MAX_HINT_BUFFER_WORDS as u32 { + return Err(ExecutionError::HintBufferTooLarge { + pc, + num_words, + max_hint_buffer_words: MAX_HINT_BUFFER_WORDS as u32, + }); + } + if exec_state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { let err = ExecutionError::HintOutOfBounds { pc }; return Err(err); diff --git a/extensions/rv32im/circuit/src/hintstore/mod.rs b/extensions/rv32im/circuit/src/hintstore/mod.rs index 35955bb979..b9cac88249 100644 --- a/extensions/rv32im/circuit/src/hintstore/mod.rs +++ b/extensions/rv32im/circuit/src/hintstore/mod.rs @@ -25,6 +25,7 @@ use openvm_instructions::{ use openvm_rv32im_transpiler::{ Rv32HintStoreOpcode, Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW}, + MAX_HINT_BUFFER_WORDS, MAX_HINT_BUFFER_WORDS_BITS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -202,19 +203,29 @@ impl Air for Rv32HintStoreAir { ) .eval(builder, is_start.clone()); - // Preventing mem_ptr and rem_words overflow - // Constraining mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1] < 2^(pointer_max_bits - - // (RV32_REGISTER_NUM_LIMBS - 1)*RV32_CELL_BITS) which implies mem_ptr <= - // 2^pointer_max_bits Similarly for rem_words <= 2^pointer_max_bits + // Preventing rem_words overflow: rem_words < 2^MAX_HINT_BUFFER_WORDS_BITS + // These constraints only work for MAX_HINT_BUFFER_WORDS_BITS in [16, 23] + debug_assert!( + (16..=23).contains(&MAX_HINT_BUFFER_WORDS_BITS), + "MAX_HINT_BUFFER_WORDS_BITS must be in [16, 23] for these constraints to work" + ); + // For MAX_HINT_BUFFER_WORDS_BITS = 18, this requires: + // - limbs[3] = 0 (since 2^18 < 2^24) + // - limbs[2] < 4 (since 2^18 = 4 * 2^16) + builder.assert_zero(local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1]); + + // Preventing mem_ptr overflow: mem_ptr < 2^pointer_max_bits + // (rem_words overflow is handled below with the stricter MAX_HINT_BUFFER_WORDS_BITS bound) self.bitwise_operation_lookup_bus .send_range( local_cols.mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1] * AB::F::from_canonical_usize( 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits), ), - local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1] + local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 2] * AB::F::from_canonical_usize( - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits), + 1 << ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS + - MAX_HINT_BUFFER_WORDS_BITS), ), ) .eval(builder, is_start.clone()); @@ -409,6 +420,15 @@ where read_rv32_register(state.memory.data(), a) }; + // Bounds check: num_words must not exceed MAX_HINT_BUFFER_WORDS + if num_words > MAX_HINT_BUFFER_WORDS as u32 { + return Err(ExecutionError::HintBufferTooLarge { + pc: *state.pc, + num_words, + max_hint_buffer_words: MAX_HINT_BUFFER_WORDS as u32, + }); + } + let record = state.ctx.alloc(MultiRowLayout::new(Rv32HintStoreMetadata { num_words: num_words as usize, })); @@ -508,6 +528,10 @@ impl TraceFiller for Rv32HintStoreFiller { let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits) as u32; + // Scale factors for rem_words range check (using MAX_HINT_BUFFER_WORDS_BITS) + let rem_words_limb2_lshift: u32 = + ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS - MAX_HINT_BUFFER_WORDS_BITS) as u32; + chunks .par_iter_mut() .zip(sizes.par_iter()) @@ -526,9 +550,17 @@ impl TraceFiller for Rv32HintStoreFiller { }), ) }; + // Range check for mem_ptr (using pointer_max_bits) + // (num_words overflow check is handled below with the stricter + // MAX_HINT_BUFFER_WORDS_BITS bound) + // Range check for num_words (using MAX_HINT_BUFFER_WORDS_BITS) + debug_assert!( + num_words <= MAX_HINT_BUFFER_WORDS as u32, + "num_words must be <= MAX_HINT_BUFFER_WORDS" + ); self.bitwise_lookup_chip.request_range( (record.inner.mem_ptr >> msl_rshift) << msl_lshift, - (num_words >> msl_rshift) << msl_lshift, + ((num_words >> 16) & 0xFF) << rem_words_limb2_lshift, ); let mut timestamp = record.inner.timestamp + num_words * 3; diff --git a/extensions/rv32im/circuit/src/hintstore/tests.rs b/extensions/rv32im/circuit/src/hintstore/tests.rs index e79066aae6..61019a1104 100644 --- a/extensions/rv32im/circuit/src/hintstore/tests.rs +++ b/extensions/rv32im/circuit/src/hintstore/tests.rs @@ -19,7 +19,10 @@ use openvm_instructions::{ riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, LocalOpcode, }; -use openvm_rv32im_transpiler::Rv32HintStoreOpcode::{self, *}; +use openvm_rv32im_transpiler::{ + Rv32HintStoreOpcode::{self, *}, + MAX_HINT_BUFFER_WORDS, +}; use openvm_stark_backend::{ p3_field::FieldAlgebra, p3_matrix::{ @@ -194,6 +197,94 @@ fn rand_hintstore_test() { // part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[test] +#[should_panic(expected = "HintBufferTooLarge")] +fn test_hint_buffer_exceeds_max_words() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + + let (mut harness, _bitwise) = create_harness::>(&mut tester); + + let num_words = (MAX_HINT_BUFFER_WORDS + 1) as u32; + + let a = gen_pointer(&mut rng, RV32_REGISTER_NUM_LIMBS); + tester.write( + RV32_REGISTER_AS as usize, + a, + num_words.to_le_bytes().map(F::from_canonical_u8), + ); + + let mem_ptr = gen_pointer(&mut rng, 4) as u32; + let b = gen_pointer(&mut rng, RV32_REGISTER_NUM_LIMBS); + tester.write(1, b, mem_ptr.to_le_bytes().map(F::from_canonical_u8)); + + for _ in 0..num_words { + let data = rng.next_u32().to_le_bytes().map(F::from_canonical_u8); + tester.streams_mut().hint_stream.extend(data); + } + + tester.execute( + &mut harness.executor, + &mut harness.arena, + &Instruction::from_usize( + HINT_BUFFER.global_opcode(), + [a, b, 0, RV32_REGISTER_AS as usize, RV32_MEMORY_AS as usize], + ), + ); +} + +#[test] +fn test_hint_buffer_rem_words_range_check() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + + let (mut harness, bitwise) = create_harness(&mut tester); + + // Build a small, valid buffer instruction with 1 word so trace has 1 row. + let num_words: u32 = 1; + let a = gen_pointer(&mut rng, RV32_REGISTER_NUM_LIMBS); + tester.write( + RV32_REGISTER_AS as usize, + a, + num_words.to_le_bytes().map(F::from_canonical_u8), + ); + + let mem_ptr = gen_pointer(&mut rng, 4) as u32; + let b = gen_pointer(&mut rng, RV32_REGISTER_NUM_LIMBS); + tester.write(1, b, mem_ptr.to_le_bytes().map(F::from_canonical_u8)); + + for _ in 0..num_words { + let data = rng.next_u32().to_le_bytes().map(F::from_canonical_u8); + tester.streams_mut().hint_stream.extend(data); + } + + tester.execute( + &mut harness.executor, + &mut harness.arena, + &Instruction::from_usize( + HINT_BUFFER.global_opcode(), + [a, b, 0, RV32_REGISTER_AS as usize, RV32_MEMORY_AS as usize], + ), + ); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); + let cols: &mut Rv32HintStoreCols = trace_row.as_mut_slice().borrow_mut(); + // Force `rem_words` to overflow MAX_HINT_BUFFER_WORDS_BITS on the start row. + cols.rem_words_limbs = [F::ZERO, F::ZERO, F::ZERO, F::from_canonical_u8(1)]; + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(harness, modify_trace) + .load_periphery(bitwise) + .finalize(); + + tester.simple_test_with_expected_error(get_verification_error(false)); +} + #[allow(clippy::too_many_arguments)] fn run_negative_hintstore_test( opcode: Rv32HintStoreOpcode, diff --git a/extensions/rv32im/guest/src/io.rs b/extensions/rv32im/guest/src/io.rs index 664b9b1117..535959d4cc 100644 --- a/extensions/rv32im/guest/src/io.rs +++ b/extensions/rv32im/guest/src/io.rs @@ -1,5 +1,5 @@ #![allow(unused_imports)] -use crate::{PhantomImm, PHANTOM_FUNCT3, SYSTEM_OPCODE}; +use crate::{PhantomImm, MAX_HINT_BUFFER_WORDS, PHANTOM_FUNCT3, SYSTEM_OPCODE}; /// Store the next 4 bytes from the hint stream to [[rd]_1]_2. #[macro_export] @@ -21,8 +21,8 @@ macro_rules! hint_buffer_u32 { ($x:expr, $len:expr) => { if $len != 0 { openvm_custom_insn::custom_insn_i!( - opcode = openvm_rv32im_guest::SYSTEM_OPCODE, - funct3 = openvm_rv32im_guest::HINT_FUNCT3, + opcode = $crate::SYSTEM_OPCODE, + funct3 = $crate::HINT_FUNCT3, rd = In $x, rs1 = In $len, imm = Const 1, @@ -31,6 +31,18 @@ macro_rules! hint_buffer_u32 { }; } +/// Read hint buffer with automatic chunking for large reads. +/// Splits reads larger than MAX_HINT_BUFFER_WORDS into multiple instructions. +#[inline(always)] +pub fn hint_buffer_chunked(mut ptr: *mut u8, mut num_words: usize) { + while num_words > 0 { + let chunk = core::cmp::min(num_words, MAX_HINT_BUFFER_WORDS); + hint_buffer_u32!(ptr, chunk); + ptr = ptr.wrapping_add(chunk * 4); + num_words -= chunk; + } +} + /// Reset the hint stream with the next hint. #[inline(always)] pub fn hint_input() { diff --git a/extensions/rv32im/guest/src/lib.rs b/extensions/rv32im/guest/src/lib.rs index 99f1a6f97f..cea29068e2 100644 --- a/extensions/rv32im/guest/src/lib.rs +++ b/extensions/rv32im/guest/src/lib.rs @@ -25,6 +25,16 @@ pub const REVEAL_FUNCT3: u8 = 0b010; pub const PHANTOM_FUNCT3: u8 = 0b011; pub const CSRRW_FUNCT3: u8 = 0b001; +/// Maximum number of bits for hint buffer size. +/// IMPORTANT: Must be synced with MAX_HINT_BUFFER_WORDS_BITS constant for cuda +/// `crates/circuits/primitives/cuda/include/primitives/constants.h` +// For the constraints, they are configured for a range of MAX_HINT_BUFFER_WORDS_BITS between +// [16,23] +pub const MAX_HINT_BUFFER_WORDS_BITS: usize = 18; +/// Maximum number of words that can be read in a single HINT_BUFFER instruction. +/// AIR constraint requires rem_words < 2^MAX_HINT_BUFFER_WORDS_BITS, so max is one less +pub const MAX_HINT_BUFFER_WORDS: usize = (1 << MAX_HINT_BUFFER_WORDS_BITS) - 1; // 262,143 words ≈ 1MB + /// imm options for system phantom instructions #[derive(Debug, Copy, Clone, PartialEq, Eq, FromRepr)] #[repr(u16)] diff --git a/extensions/rv32im/tests/programs/examples/hint_large_buffer.rs b/extensions/rv32im/tests/programs/examples/hint_large_buffer.rs new file mode 100644 index 0000000000..64472b0f25 --- /dev/null +++ b/extensions/rv32im/tests/programs/examples/hint_large_buffer.rs @@ -0,0 +1,25 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +use openvm::io::read_vec; +use openvm_rv32im_guest::MAX_HINT_BUFFER_WORDS; + +openvm::entry!(main); + +pub fn main() { + let vec = read_vec(); + + // Create a hint buffer larger than MAX_HINT_BUFFER_WORDS, to test chunking + let expected_words = MAX_HINT_BUFFER_WORDS + 100; + let expected_len = expected_words * 4; + + if vec.len() != expected_len { + openvm::process::panic(); + } + + for (i, item) in vec.iter().enumerate() { + if *item != (i as u8) { + openvm::process::panic(); + } + } +} diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index ff141398f5..c4302ae808 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -13,7 +13,7 @@ mod tests { }; use openvm_instructions::{exe::VmExe, instruction::Instruction, LocalOpcode, SystemOpcode}; use openvm_rv32im_circuit::{Rv32IBuilder, Rv32IConfig, Rv32ImBuilder, Rv32ImConfig}; - use openvm_rv32im_guest::hint_load_by_key_encode; + use openvm_rv32im_guest::{hint_load_by_key_encode, MAX_HINT_BUFFER_WORDS}; use openvm_rv32im_transpiler::{ DivRemOpcode, MulHOpcode, MulOpcode, Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -169,6 +169,37 @@ mod tests { Ok(()) } + /// NOTE: This test is slow because it processes > 1MB of data. It is marked #[ignore] + /// and can be run with: cargo test -p openvm-rv32im-integration-tests test_hint_buffer_chunking + /// -- --ignored + #[test] + #[ignore = "slow test: processes >1MB of data"] + fn test_hint_buffer_chunking() -> Result<()> { + let config = test_rv32im_config(); + let elf = build_example_program_at_path(get_programs_dir!(), "hint_large_buffer", &config)?; + let exe = VmExe::from_elf( + elf, + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension), + )?; + + // Create input buffer larger than MAX_HINT_BUFFER_WORDS + // This will require chunking to succeed + let expected_words = MAX_HINT_BUFFER_WORDS + 100; + let expected_len = expected_words * 4; + + // Create data with a pattern that can be verified + let data: Vec = (0..expected_len) + .map(|i| F::from_canonical_u8((i % 256) as u8)) + .collect(); + + let input = vec![data]; + air_test_with_min_segments(Rv32ImBuilder, config, exe, input, 1); + Ok(()) + } + #[test] fn test_read() -> Result<()> { let config = test_rv32im_config(); diff --git a/extensions/rv32im/transpiler/src/lib.rs b/extensions/rv32im/transpiler/src/lib.rs index 03a354517e..218202c369 100644 --- a/extensions/rv32im/transpiler/src/lib.rs +++ b/extensions/rv32im/transpiler/src/lib.rs @@ -9,6 +9,7 @@ use openvm_rv32im_guest::{ NATIVE_STOREW_FUNCT3, NATIVE_STOREW_FUNCT7, PHANTOM_FUNCT3, REVEAL_FUNCT3, RV32M_FUNCT7, RV32_ALU_OPCODE, SYSTEM_OPCODE, TERMINATE_FUNCT3, }; +pub use openvm_rv32im_guest::{MAX_HINT_BUFFER_WORDS, MAX_HINT_BUFFER_WORDS_BITS}; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::{ util::{nop, unimp}, diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha2/circuit/Cargo.toml similarity index 80% rename from extensions/sha256/circuit/Cargo.toml rename to extensions/sha2/circuit/Cargo.toml index 7b3a8d8edd..60d2951687 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha2/circuit/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version.workspace = true authors.workspace = true edition.workspace = true -description = "OpenVM circuit extension for sha256" +description = "OpenVM circuit extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } @@ -11,20 +11,23 @@ openvm-stark-sdk = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } openvm-cuda-common = { workspace = true, optional = true } openvm-circuit-primitives = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } -openvm-sha256-air = { workspace = true } + +openvm-sha2-air = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true -sha2 = { version = "0.10", default-features = false } -strum = { workspace = true } +sha2 = { version = "0.10", features = ["compress"] } +ndarray = { workspace = true, default-features = false } cfg-if.workspace = true +itertools = { workspace = true } [dev-dependencies] hex = { workspace = true } @@ -62,3 +65,6 @@ touchemall = [ "openvm-cuda-common/touchemall", "openvm-rv32im-circuit/touchemall", ] + +[package.metadata.cargo-shear] +ignored = ["ndarray"] \ No newline at end of file diff --git a/extensions/sha2/circuit/README.md b/extensions/sha2/circuit/README.md new file mode 100644 index 0000000000..d98d9d6a4c --- /dev/null +++ b/extensions/sha2/circuit/README.md @@ -0,0 +1,93 @@ +# SHA-2 VM Extension + +This crate contains circuits for the SHA-2 family of hash functions. +We support constraining the block compression functions for SHA-256 and SHA-512. +It is also possible to use this crate to constrain the SHA-384 algorithm, as described in the next section. + +## SHA-2 Algorithms Summary + +The SHA-256, SHA-512, and SHA-384 algorithms are similar in structure. +We will first describe the SHA-256 algorithm, and then describe the differences between the three algorithms. + +See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for reference. In particular, sections 6.2, 6.4, and 6.5. + +In short the SHA-256 algorithm works as follows. +1. Pad the message to 512 bits and split it into 512-bit 'blocks'. +2. Initialize a hash state consisting of eight 32-bit words to a specific constant value. +3. For each block, + 1. split the message into 16 32-bit words and produce 48 more words based on them. The 16 message words together with the 48 additional words are called the 'message schedule'. + 2. apply a scrambling function 64 times to the hash state to update it based on the message schedule. We call each update a 'round'. + 3. add the previous block's final hash state to the current hash state (modulo $2^{32}$). +4. The output is the final hash state + +The differences with the SHA-512 algorithm are that: +- SHA-512 uses 64-bit words, 1024-bit blocks, performs 80 rounds, and produces a 512-bit output. +- all the arithmetic is done modulo $2^{64}$. +- the initial hash state is different. + +The SHA-384 algorithm is a truncation of the SHA-512 output to 384 bits, and the only difference is that the initial hash state is different. +In particular, SHA-384 shares its compression function with SHA-512, so users may write guest code that supports SHA-384 by using this crate's SHA-512 compression function. + +## Design Overview + +We support the `SHA256_UPDATE` and `SHA512_UPDATE` intrinsic instructions, each of which takes three operands, `dst`, `state`, and `input`. +- `input` is a pointer to exactly one block of message bytes (the size of a block varies for each SHA-2 variant.) +- `state` is a pointer to the current hasher state (8 words in big endian). Word size varies by SHA-2 variant. +- `dst` is a pointer where the updated state should be stored. `dst` may be equal to `state`. + +The `SHA256_UPDATE` and `SHA512_UPDATE` instructions write the value of the updated hasher state after consuming the input to `dst`. +Note that these instructions do not pad the input. + +### Chips + +The SHA-2 extension consists of 2 chips: `Sha2MainChip` and `Sha2BlockHasherChip`. + +The main chip constraints reading `state` and `input` from memory and writing the new state into `dst`. +The main chip sends the operands to the the block hasher chip via interactions. +The trace of the main chip consists of one row per instruction. + +The block hasher chip constraints the SHA-2 compression algorithm for one block at a time. +More specifically, the trace of the block hasher chip consists of groups of 17 consecutive rows (21 for SHA-512) which together constrain the 64 (resp. 80) rounds of the SHA-2 compression algorithm for one block of input +Each block receives the previous hasher state and the input bytes from the main chip and sends the updated state to the main chip via interactions. +Note that the block hasher chip consists of a SubAir, which constrains all the SHA-2 logic, while the block hasher chip itself only constrains its interactions with the main chip. + + +### Air Design + +We reuse the same AIR code to produce circuits for SHA-256 and SHA-512. +To achieve this, we parameterize the AIR by constants (such as the word size, number of rounds, and block size) that are specific to each algorithm. + +The block hasher AIR consists of $R+1$ rows for each instruction, and no more rows +(for SHA-256, $R = 16$ and for SHA-512 and SHA-384, $R = 20$). +The first $R$ rows of each block are called 'round rows', and each of them constrains four rounds of the hash algorithm. +Each row constrains updates to the working variables on each round, and also constrains the message schedule words based on previous rounds. +The final row of each block is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. + +### Storing working variables + +One optimization is that we only keep track of the `a` and `e` working variables. +It turns out that if we have their values over four consecutive rounds, we can reconstruct all eight variables at the end of the four rounds. +This is because there is overlap between the values of the working variables in adjacent rounds. +If the state is visualized as an array, `s_0 = [a, b, c, d, e, f, g, h]`, then the new state, `s_1`, after one round is produced by a right-shift and an addition. +More formally, +``` +s_1 = (s_0 >> 1) + [T_1 + T_2, 0, 0, 0, T_1, 0, 0, 0] + = [0, a, b, c, d, e, f, g] + [T_1 + T_2, 0, 0, 0, T_1, 0, 0, 0] + = [T_1 + T_2, a, b, c, d + T_1, e, f, g] +``` +where `T_1` and `T_2` are certain functions of the working variables and message data (see the FIPS spec). +So if `a_i` and `e_i` denote the values of `a` and `e` after the `i`th round, for `0 <= i < 4`, then the state `s_3` after the fourth round can be written as `s_3 = [a_3, a_2, a_1, a_0, e_3, e_2, e_1, e_0]`. + +### Message schedule constraints + +The algorithm for computing the message schedule involves message schedule words from 16 rounds ago. +Since we can only constrain two rows at a time, we cannot access data from more than four rounds ago for the first round in each row. +So, we maintain intermediate values that we call `intermed_4`, `intermed_8` and `intermed_12`, where `intermed_i = w_i + sig_0(w_{i+1})` where `w_i` is the value of `w` from `i` rounds ago and `sig_0` denotes the `sigma_0` function from the FIPS spec. +Since we can reliably constrain values from four rounds ago, we can build up `intermed_16` from these values, which is needed for computing the message schedule. + + +### Dummy values + +Some constraints have degree three, and so we cannot restrict them to particular rows due to the limitation of the maximum constraint degree. +We must enforce them on all rows, and in order to ensure they hold on the remaining rows we must fill in some cells with appropriate dummy values. +We use this trick in several places in this chip. diff --git a/extensions/sha256/circuit/build.rs b/extensions/sha2/circuit/build.rs similarity index 59% rename from extensions/sha256/circuit/build.rs rename to extensions/sha2/circuit/build.rs index bcb991c8b9..3826770526 100644 --- a/extensions/sha256/circuit/build.rs +++ b/extensions/sha2/circuit/build.rs @@ -7,18 +7,13 @@ fn main() { if !cuda_available() { return; // Skip CUDA compilation } - let builder: CudaBuilder = CudaBuilder::new() .include_from_dep("DEP_CUDA_COMMON_INCLUDE") .include("../../../crates/circuits/primitives/cuda/include") - .include("../../../crates/circuits/sha256-air/cuda/include") .include("../../../crates/vm/cuda/include") - .watch("cuda") - .watch("../../../crates/circuits/primitives/cuda") - .watch("../../../crates/circuits/sha256-air/cuda") - .watch("../../../crates/vm/cuda") - .library_name("tracegen_gpu_sha256") - .file("cuda/src/sha256.cu"); + .include("cuda/include") + .library_name("tracegen_gpu_sha2") + .files_from_glob("cuda/src/*.cu"); builder.emit_link_directives(); builder.build(); diff --git a/extensions/sha2/circuit/cuda/include/block_hasher/columns.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/columns.cuh new file mode 100644 index 0000000000..622ffe6cff --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/block_hasher/columns.cuh @@ -0,0 +1,110 @@ +#pragma once + +#include +#include + +// Column structs matching the new block-hasher AIR (request_id + inner round/digest columns). +template struct Sha2FlagsCols { + T is_round_row; + T is_first_4_rows; + T is_digest_row; + T row_idx[V::ROW_VAR_CNT]; + T global_block_idx; +}; + +template struct Sha2MessageHelperCols { + T w_3[V::ROUNDS_PER_ROW_MINUS_ONE][V::WORD_U16S]; + T intermed_4[V::ROUNDS_PER_ROW][V::WORD_U16S]; + T intermed_8[V::ROUNDS_PER_ROW][V::WORD_U16S]; + T intermed_12[V::ROUNDS_PER_ROW][V::WORD_U16S]; +}; + +template struct Sha2MessageScheduleCols { + T w[V::ROUNDS_PER_ROW][V::WORD_BITS]; + T carry_or_buffer[V::ROUNDS_PER_ROW][V::WORD_U8S]; +}; + +template struct Sha2WorkVarsCols { + T a[V::ROUNDS_PER_ROW][V::WORD_BITS]; + T e[V::ROUNDS_PER_ROW][V::WORD_BITS]; + T carry_a[V::ROUNDS_PER_ROW][V::WORD_U16S]; + T carry_e[V::ROUNDS_PER_ROW][V::WORD_U16S]; +}; + +template struct Sha2RoundCols { + Sha2FlagsCols flags; + Sha2WorkVarsCols work_vars; + Sha2MessageHelperCols schedule_helper; + Sha2MessageScheduleCols message_schedule; +}; + +template struct Sha2DigestCols { + Sha2FlagsCols flags; + Sha2WorkVarsCols hash; + Sha2MessageHelperCols schedule_helper; + T final_hash[V::HASH_WORDS][V::WORD_U8S]; + T prev_hash[V::HASH_WORDS][V::WORD_U16S]; +}; + +template struct Sha2BlockHasherRoundCols { + T request_id; + Sha2RoundCols inner; +}; + +template struct Sha2BlockHasherDigestCols { + T request_id; + Sha2DigestCols inner; +}; + +template struct Sha2Layout { + static constexpr size_t ROUND_WIDTH = sizeof(Sha2BlockHasherRoundCols); + static constexpr size_t DIGEST_WIDTH = sizeof(Sha2BlockHasherDigestCols); + static constexpr size_t WIDTH = (ROUND_WIDTH > DIGEST_WIDTH) ? ROUND_WIDTH : DIGEST_WIDTH; + static constexpr size_t INNER_OFFSET = sizeof(uint8_t); // request_id + static constexpr size_t INNER_COLUMN_OFFSET = sizeof(uint8_t); +}; + +#define SHA2_COL_INDEX(V, STRUCT, FIELD) \ + (reinterpret_cast(&(reinterpret_cast *>(0)->FIELD))) +#define SHA2_COL_ARRAY_LEN(V, STRUCT, FIELD) \ + (sizeof((reinterpret_cast *>(0)->FIELD))) +#define SHA2_WRITE_VALUE(V, ROW, STRUCT, FIELD, VALUE) \ + (ROW).write(SHA2_COL_INDEX(V, STRUCT, FIELD), VALUE) +#define SHA2_WRITE_ARRAY(V, ROW, STRUCT, FIELD, VALUES) \ + (ROW).write_array( \ + SHA2_COL_INDEX(V, STRUCT, FIELD), SHA2_COL_ARRAY_LEN(V, STRUCT, FIELD), VALUES \ + ) +#define SHA2_WRITE_BITS(V, ROW, STRUCT, FIELD, VALUE) \ + (ROW).write_bits(SHA2_COL_INDEX(V, STRUCT, FIELD), VALUE) +#define SHA2_FILL_ZERO(V, ROW, STRUCT, FIELD) \ + (ROW).fill_zero(SHA2_COL_INDEX(V, STRUCT, FIELD), SHA2_COL_ARRAY_LEN(V, STRUCT, FIELD)) +#define SHA2_SLICE_FROM(V, ROW, STRUCT, FIELD) (ROW).slice_from(SHA2_COL_INDEX(V, STRUCT, FIELD)) + +#define SHA2_WRITE_ROUND(V, row, FIELD, VALUE) \ + SHA2_WRITE_VALUE(V, row, Sha2BlockHasherRoundCols, FIELD, VALUE) +#define SHA2_WRITE_DIGEST(V, row, FIELD, VALUE) \ + SHA2_WRITE_VALUE(V, row, Sha2BlockHasherDigestCols, FIELD, VALUE) +#define SHA2_WRITE_ARRAY_ROUND(V, row, FIELD, VALUES) \ + SHA2_WRITE_ARRAY(V, row, Sha2BlockHasherRoundCols, FIELD, VALUES) +#define SHA2_WRITE_ARRAY_DIGEST(V, row, FIELD, VALUES) \ + SHA2_WRITE_ARRAY(V, row, Sha2BlockHasherDigestCols, FIELD, VALUES) +#define SHA2_FILL_ZERO_ROUND(V, row, FIELD) SHA2_FILL_ZERO(V, row, Sha2BlockHasherRoundCols, FIELD) +#define SHA2_FILL_ZERO_DIGEST(V, row, FIELD) \ + SHA2_FILL_ZERO(V, row, Sha2BlockHasherDigestCols, FIELD) +#define SHA2_SLICE_ROUND(V, row, FIELD) SHA2_SLICE_FROM(V, row, Sha2BlockHasherRoundCols, FIELD) +#define SHA2_SLICE_DIGEST(V, row, FIELD) SHA2_SLICE_FROM(V, row, Sha2BlockHasherDigestCols, FIELD) + +#define SHA2INNER_WRITE_ROUND(V, row, FIELD, VALUE) \ + SHA2_WRITE_VALUE(V, row, Sha2RoundCols, FIELD, VALUE) +#define SHA2INNER_WRITE_DIGEST(V, row, FIELD, VALUE) \ + SHA2_WRITE_VALUE(V, row, Sha2DigestCols, FIELD, VALUE) +#define SHA2INNER_WRITE_ARRAY_ROUND(V, row, FIELD, VALUES) \ + SHA2_WRITE_ARRAY(V, row, Sha2RoundCols, FIELD, VALUES) +#define SHA2INNER_WRITE_ARRAY_DIGEST(V, row, FIELD, VALUES) \ + SHA2_WRITE_ARRAY(V, row, Sha2DigestCols, FIELD, VALUES) +#define SHA2INNER_FILL_ZERO_ROUND(V, row, FIELD) SHA2_FILL_ZERO(V, row, Sha2RoundCols, FIELD) +#define SHA2INNER_FILL_ZERO_DIGEST(V, row, FIELD) SHA2_FILL_ZERO(V, row, Sha2DigestCols, FIELD) +#define SHA2INNER_WRITE_BITS_ROUND(V, row, FIELD, VALUE) \ + SHA2_WRITE_BITS(V, row, Sha2RoundCols, FIELD, VALUE) +#define SHA2INNER_WRITE_BITS_DIGEST(V, row, FIELD, VALUE) \ + SHA2_WRITE_BITS(V, row, Sha2DigestCols, FIELD, VALUE) diff --git a/extensions/sha2/circuit/cuda/include/block_hasher/record.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/record.cuh new file mode 100644 index 0000000000..4ee1bc4d0f --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/block_hasher/record.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include "system/memory/offline_checker.cuh" +#include "variant.cuh" +#include + +namespace sha2 { + +// GPU view of the per-block record produced by the executor (matches Sha2RecordMut in Rust). +template struct Sha2BlockRecordHeader { + uint32_t variant; + uint32_t from_pc; + uint32_t timestamp; + uint32_t dst_reg_ptr; + uint32_t state_reg_ptr; + uint32_t input_reg_ptr; + uint32_t dst_ptr; + uint32_t state_ptr; + uint32_t input_ptr; + MemoryReadAuxRecord register_reads_aux[sha2::SHA2_REGISTER_READS]; +}; + +template struct Sha2BlockRecordMut { + Sha2BlockRecordHeader *header; + uint8_t *message_bytes; + uint8_t *prev_state; + uint8_t *new_state; + MemoryReadAuxRecord *input_reads_aux; + MemoryReadAuxRecord *state_reads_aux; + MemoryWriteBytesAuxRecord *write_aux; + + __device__ __host__ __forceinline__ static uint32_t next_multiple_of( + uint32_t value, + uint32_t alignment + ) { + return ((value + alignment - 1) / alignment) * alignment; + } + + __device__ __host__ __forceinline__ Sha2BlockRecordMut(uint8_t *record_buf) { + header = reinterpret_cast *>(record_buf); + uint32_t offset = sizeof(Sha2BlockRecordHeader); + + message_bytes = record_buf + offset; + offset += V::BLOCK_U8S; + + prev_state = record_buf + offset; + offset += V::STATE_BYTES; + + new_state = record_buf + offset; + offset += V::STATE_BYTES; + + offset = next_multiple_of(offset, alignof(MemoryReadAuxRecord)); + input_reads_aux = reinterpret_cast(record_buf + offset); + offset += V::BLOCK_READS * sizeof(MemoryReadAuxRecord); + + offset = next_multiple_of(offset, alignof(MemoryReadAuxRecord)); + state_reads_aux = reinterpret_cast(record_buf + offset); + offset += V::STATE_READS * sizeof(MemoryReadAuxRecord); + + offset = + next_multiple_of(offset, alignof(MemoryWriteBytesAuxRecord)); + write_aux = reinterpret_cast *>( + record_buf + offset + ); + } +}; + +} // namespace sha2 diff --git a/extensions/sha2/circuit/cuda/include/block_hasher/tracegen.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/tracegen.cuh new file mode 100644 index 0000000000..2da86fd2a0 --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/block_hasher/tracegen.cuh @@ -0,0 +1,71 @@ +#pragma once + +#include "columns.cuh" +#include "fp.h" +#include "primitives/trace_access.h" +#include +#include + +// NOTE: This is a stubbed tracegen implementation to get the CUDA pipeline compiling again. +// It fills rows with zeros and basic flags/request_id; the full round/digest population +// matching the Rust AIR still needs to be implemented. +namespace sha2 { + +template +__device__ inline void write_round_stub( + RowSlice row, + uint32_t request_id, + uint32_t global_block_idx, + uint32_t local_row_idx +) { + row.fill_zero(0, Sha2Layout::ROUND_WIDTH); + RowSlice inner = row.slice_from(Sha2Layout::INNER_OFFSET); + // Mark round rows within the block + if (local_row_idx < V::ROUND_ROWS) { + SHA2INNER_WRITE_ROUND(V, inner, flags.is_round_row, Fp::one()); + SHA2INNER_WRITE_ROUND( + V, + inner, + flags.is_first_4_rows, + (local_row_idx < static_cast(V::MESSAGE_ROWS)) ? Fp::one() : Fp::zero() + ); + SHA2INNER_WRITE_ROUND(V, inner, flags.is_digest_row, Fp::zero()); + SHA2INNER_WRITE_ROUND(V, inner, flags.global_block_idx, global_block_idx); + } else { + // digest rows + SHA2INNER_WRITE_DIGEST(V, inner, flags.is_round_row, Fp::zero()); + SHA2INNER_WRITE_DIGEST(V, inner, flags.is_first_4_rows, Fp::zero()); + SHA2INNER_WRITE_DIGEST(V, inner, flags.is_digest_row, Fp::one()); + SHA2INNER_WRITE_DIGEST(V, inner, flags.global_block_idx, global_block_idx); + } + // Write request_id in the wrapper column + SHA2_WRITE_ROUND(V, row, request_id, Fp(request_id)); +} + +template +__global__ void sha2_block_tracegen_stub( + Fp *trace, + size_t trace_height, + uint8_t *records, + size_t num_records, + size_t *record_offsets +) { + uint32_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx >= trace_height) { + return; + } + + RowSlice row(trace + row_idx, trace_height); + row.fill_zero(0, Sha2Layout::WIDTH); + + uint32_t record_idx = row_idx / V::ROWS_PER_BLOCK; + uint32_t local_row = row_idx % V::ROWS_PER_BLOCK; + if (record_idx >= num_records) { + return; + } + + // Basic request_id and flags; actual round data is left zeroed for now. + write_round_stub(row, record_idx, record_idx + 1, local_row); +} + +} // namespace sha2 diff --git a/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh new file mode 100644 index 0000000000..ec36a9dd1c --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/block_hasher/variant.cuh @@ -0,0 +1,297 @@ +#pragma once + +#include +#include + +namespace sha2 { + +// Common VM constants across SHA-2 variants. +inline constexpr size_t SHA2_REGISTER_READS = 3; +inline constexpr size_t SHA2_READ_SIZE = 4; +inline constexpr size_t SHA2_WRITE_SIZE = 4; +inline constexpr size_t SHA2_MAIN_READ_SIZE = 4; + +template < + typename WordT, + size_t WORD_BITS_, + size_t BLOCK_WORDS_, + size_t ROUNDS_PER_ROW_, + size_t ROUNDS_PER_BLOCK_, + size_t HASH_WORDS_, + size_t ROW_VAR_CNT_, + size_t MESSAGE_LENGTH_BITS_> +struct Sha2VariantBase { + using Word = WordT; + + static constexpr size_t WORD_BITS = WORD_BITS_; + static constexpr size_t BLOCK_WORDS = BLOCK_WORDS_; + static constexpr size_t ROUNDS_PER_ROW = ROUNDS_PER_ROW_; + static constexpr size_t ROUNDS_PER_BLOCK = ROUNDS_PER_BLOCK_; + static constexpr size_t HASH_WORDS = HASH_WORDS_; + static constexpr size_t ROW_VAR_CNT = ROW_VAR_CNT_; + static constexpr size_t MESSAGE_LENGTH_BITS = MESSAGE_LENGTH_BITS_; + + static constexpr size_t WORD_U16S = WORD_BITS / 16; + static constexpr size_t WORD_U8S = WORD_BITS / 8; + static constexpr size_t WORD_BYTES = WORD_U8S; + static constexpr size_t BLOCK_U8S = BLOCK_WORDS * WORD_U8S; + static constexpr size_t BLOCK_BYTES = BLOCK_U8S; + static constexpr size_t BLOCK_BITS = BLOCK_WORDS * WORD_BITS; + static constexpr size_t ROUND_ROWS = ROUNDS_PER_BLOCK / ROUNDS_PER_ROW; + static constexpr size_t MESSAGE_ROWS = BLOCK_WORDS / ROUNDS_PER_ROW; + static constexpr size_t ROUNDS_PER_ROW_MINUS_ONE = ROUNDS_PER_ROW - 1; + + static constexpr size_t NUM_READ_ROWS = BLOCK_U8S / SHA2_READ_SIZE; + static constexpr size_t STATE_BYTES = HASH_WORDS * WORD_U8S; + static constexpr size_t BLOCK_READS = BLOCK_U8S / SHA2_MAIN_READ_SIZE; + static constexpr size_t STATE_READS = STATE_BYTES / SHA2_MAIN_READ_SIZE; + static constexpr size_t STATE_WRITES = STATE_BYTES / SHA2_WRITE_SIZE; + static constexpr size_t TIMESTAMP_DELTA = + BLOCK_READS + STATE_READS + STATE_WRITES + SHA2_REGISTER_READS; +}; + +// SHA-256 constants +static constexpr uint32_t SHA256_K_HOST[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +}; + +static constexpr uint32_t SHA256_H_HOST[8] = { + 0x6a09e667, + 0xbb67ae85, + 0x3c6ef372, + 0xa54ff53a, + 0x510e527f, + 0x9b05688c, + 0x1f83d9ab, + 0x5be0cd19, +}; + +// SHA-512 constants +static constexpr uint64_t SHA512_K_HOST[80] = { + 0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5, + 0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70, + 0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df, + 0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b, + 0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30, + 0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec, + 0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b, + 0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b, + 0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817, +}; + +static constexpr uint64_t SHA512_H_HOST[8] = { + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +}; + +// Device copies of the constants +// These are only defined when SHA2_DEFINE_DEVICE_CONSTANTS is set (in sha2_hasher.cu) +#ifdef SHA2_DEFINE_DEVICE_CONSTANTS +__device__ __constant__ uint32_t SHA256_K_DEV[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +}; +__device__ __constant__ uint32_t SHA256_H_DEV[8] = { + 0x6a09e667, + 0xbb67ae85, + 0x3c6ef372, + 0xa54ff53a, + 0x510e527f, + 0x9b05688c, + 0x1f83d9ab, + 0x5be0cd19, +}; +__device__ __constant__ uint64_t SHA512_K_DEV[80] = { + 0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5, + 0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70, + 0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df, + 0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b, + 0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30, + 0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec, + 0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b, + 0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b, + 0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817, +}; +__device__ __constant__ uint64_t SHA512_H_DEV[8] = { + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +}; +#else +extern __device__ __constant__ uint32_t SHA256_K_DEV[64]; +extern __device__ __constant__ uint32_t SHA256_H_DEV[8]; +extern __device__ __constant__ uint64_t SHA512_K_DEV[80]; +extern __device__ __constant__ uint64_t SHA512_H_DEV[8]; +#endif + +struct Sha256Variant : Sha2VariantBase { + static constexpr size_t ROWS_PER_BLOCK = 17; + static constexpr int SIGMA_A0 = 2; + static constexpr int SIGMA_A1 = 13; + static constexpr int SIGMA_A2 = 22; + static constexpr int SIGMA_E0 = 6; + static constexpr int SIGMA_E1 = 11; + static constexpr int SIGMA_E2 = 25; + static constexpr int SIGMA0_ROT1 = 7; + static constexpr int SIGMA0_ROT2 = 18; + static constexpr int SIGMA0_SHR = 3; + static constexpr int SIGMA1_ROT1 = 17; + static constexpr int SIGMA1_ROT2 = 19; + static constexpr int SIGMA1_SHR = 10; + + __device__ __host__ static inline Word K(size_t i) { +#ifdef __CUDA_ARCH__ + return SHA256_K_DEV[i]; +#else + return SHA256_K_HOST[i]; +#endif + } + __device__ __host__ static inline Word H(size_t i) { +#ifdef __CUDA_ARCH__ + return SHA256_H_DEV[i]; +#else + return SHA256_H_HOST[i]; +#endif + } +}; + +struct Sha512Variant : Sha2VariantBase { + static constexpr size_t ROWS_PER_BLOCK = 21; + static constexpr int SIGMA_A0 = 28; + static constexpr int SIGMA_A1 = 34; + static constexpr int SIGMA_A2 = 39; + static constexpr int SIGMA_E0 = 14; + static constexpr int SIGMA_E1 = 18; + static constexpr int SIGMA_E2 = 41; + static constexpr int SIGMA0_ROT1 = 1; + static constexpr int SIGMA0_ROT2 = 8; + static constexpr int SIGMA0_SHR = 7; + static constexpr int SIGMA1_ROT1 = 19; + static constexpr int SIGMA1_ROT2 = 61; + static constexpr int SIGMA1_SHR = 6; + + __device__ __host__ static inline Word K(size_t i) { +#ifdef __CUDA_ARCH__ + return SHA512_K_DEV[i]; +#else + return SHA512_K_HOST[i]; +#endif + } + __device__ __host__ static inline Word H(size_t i) { +#ifdef __CUDA_ARCH__ + return SHA512_H_DEV[i]; +#else + return SHA512_H_HOST[i]; +#endif + } +}; + +template +__device__ __host__ __forceinline__ WordT rotr_generic(WordT value, int n); + +template <> inline __device__ __host__ uint32_t rotr_generic(uint32_t value, int n) { + return (value >> n) | (value << (32 - n)); +} + +template <> inline __device__ __host__ uint64_t rotr_generic(uint64_t value, int n) { + return (value >> n) | (value << (64 - n)); +} + +template +__device__ __host__ __forceinline__ typename V::Word big_sig0(typename V::Word x) { + return rotr_generic(x, V::SIGMA_A0) ^ + rotr_generic(x, V::SIGMA_A1) ^ + rotr_generic(x, V::SIGMA_A2); +} + +template +__device__ __host__ __forceinline__ typename V::Word big_sig1(typename V::Word x) { + return rotr_generic(x, V::SIGMA_E0) ^ + rotr_generic(x, V::SIGMA_E1) ^ + rotr_generic(x, V::SIGMA_E2); +} + +template +__device__ __host__ __forceinline__ typename V::Word small_sig0(typename V::Word x) { + return rotr_generic(x, V::SIGMA0_ROT1) ^ + rotr_generic(x, V::SIGMA0_ROT2) ^ (x >> V::SIGMA0_SHR); +} + +template +__device__ __host__ __forceinline__ typename V::Word small_sig1(typename V::Word x) { + return rotr_generic(x, V::SIGMA1_ROT1) ^ + rotr_generic(x, V::SIGMA1_ROT2) ^ (x >> V::SIGMA1_SHR); +} + +template +__device__ __host__ __forceinline__ typename V::Word ch( + typename V::Word x, + typename V::Word y, + typename V::Word z +) { + return (x & y) ^ ((~x) & z); +} + +template +__device__ __host__ __forceinline__ typename V::Word maj( + typename V::Word x, + typename V::Word y, + typename V::Word z +) { + return (x & y) ^ (x & z) ^ (y & z); +} + +template __device__ __host__ __forceinline__ uint32_t get_num_blocks(uint32_t len) { + constexpr uint32_t length_bits = static_cast(V::MESSAGE_LENGTH_BITS); + uint64_t bit_len = static_cast(len) * 8; + uint64_t padded_bit_len = bit_len + 1 + length_bits; + return static_cast((padded_bit_len + (V::BLOCK_BITS - 1)) / V::BLOCK_BITS); +} + +} // namespace sha2 diff --git a/extensions/sha2/circuit/cuda/include/main/columns.cuh b/extensions/sha2/circuit/cuda/include/main/columns.cuh new file mode 100644 index 0000000000..88e054d4e3 --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/main/columns.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include "block_hasher/variant.cuh" +#include "primitives/constants.h" +#include "primitives/execution.h" +#include "system/memory/offline_checker.cuh" +#include +#include + +using namespace riscv; + +namespace sha2 { + +template struct Sha2MainBlockCols { + T request_id; + T message_bytes[V::BLOCK_U8S]; + T prev_state[V::STATE_BYTES]; + T new_state[V::STATE_BYTES]; +}; + +template struct Sha2MainInstructionCols { + T is_enabled; + ExecutionState from_state; + T dst_reg_ptr; + T state_reg_ptr; + T input_reg_ptr; + T dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; + T state_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; + T input_ptr_limbs[RV32_REGISTER_NUM_LIMBS]; +}; + +template struct Sha2MainMemoryCols { + MemoryReadAuxCols register_aux[sha2::SHA2_REGISTER_READS]; + MemoryReadAuxCols input_reads[V::BLOCK_READS]; + MemoryReadAuxCols state_reads[V::STATE_READS]; + MemoryWriteAuxCols write_aux[V::STATE_WRITES]; +}; + +template struct Sha2MainCols { + Sha2MainBlockCols block; + Sha2MainInstructionCols instruction; + Sha2MainMemoryCols mem; +}; + +template struct Sha2MainLayout { + static constexpr size_t WIDTH = sizeof(Sha2MainCols); +}; + +#define SHA2_MAIN_COL_INDEX_V(V, STRUCT, FIELD) \ + (reinterpret_cast(&(reinterpret_cast *>(0)->FIELD))) +#define SHA2_MAIN_COL_INDEX_PLAIN(STRUCT, FIELD) \ + (reinterpret_cast(&(reinterpret_cast *>(0)->FIELD))) + +#define SHA2_MAIN_COL_ARRAY_LEN_V(V, STRUCT, FIELD) \ + (sizeof((reinterpret_cast *>(0)->FIELD))) +#define SHA2_MAIN_COL_ARRAY_LEN_PLAIN(STRUCT, FIELD) \ + (sizeof((reinterpret_cast *>(0)->FIELD))) + +#define SHA2_MAIN_WRITE_VALUE_V(V, ROW, STRUCT, FIELD, VALUE) \ + (ROW).write(SHA2_MAIN_COL_INDEX_V(V, STRUCT, FIELD), VALUE) +#define SHA2_MAIN_WRITE_VALUE_PLAIN(ROW, STRUCT, FIELD, VALUE) \ + (ROW).write(SHA2_MAIN_COL_INDEX_PLAIN(STRUCT, FIELD), VALUE) + +#define SHA2_MAIN_WRITE_ARRAY_V(V, ROW, STRUCT, FIELD, VALUES) \ + (ROW).write_array( \ + SHA2_MAIN_COL_INDEX_V(V, STRUCT, FIELD), \ + SHA2_MAIN_COL_ARRAY_LEN_V(V, STRUCT, FIELD), \ + VALUES \ + ) +#define SHA2_MAIN_WRITE_ARRAY_PLAIN(ROW, STRUCT, FIELD, VALUES) \ + (ROW).write_array( \ + SHA2_MAIN_COL_INDEX_PLAIN(STRUCT, FIELD), \ + SHA2_MAIN_COL_ARRAY_LEN_PLAIN(STRUCT, FIELD), \ + VALUES \ + ) + +#define SHA2_MAIN_FILL_ZERO_V(V, ROW, STRUCT, FIELD) \ + (ROW).fill_zero(SHA2_MAIN_COL_INDEX_V(V, STRUCT, FIELD), SHA2_MAIN_COL_ARRAY_LEN_V(V, STRUCT, FIELD)) +#define SHA2_MAIN_SLICE_FROM_V(V, ROW, STRUCT, FIELD) \ + (ROW).slice_from(SHA2_MAIN_COL_INDEX_V(V, STRUCT, FIELD)) + +// Compute offset of nested struct field: offsetof(Sha2MainCols, block) + offsetof(Sha2MainBlockCols, FIELD) +#define SHA2_MAIN_COL_INDEX_BLOCK_V(V, FIELD) \ + (SHA2_MAIN_COL_INDEX_V(V, Sha2MainCols, block) + \ + SHA2_MAIN_COL_INDEX_V(V, Sha2MainBlockCols, FIELD)) +// Compute offset of nested struct field: offsetof(Sha2MainCols, instruction) + offsetof(Sha2MainInstructionCols, FIELD) +#define SHA2_MAIN_COL_INDEX_INSTR(V, FIELD) \ + (SHA2_MAIN_COL_INDEX_V(V, Sha2MainCols, instruction) + \ + SHA2_MAIN_COL_INDEX_PLAIN(Sha2MainInstructionCols, FIELD)) +// Compute offset of nested struct field: offsetof(Sha2MainCols, mem) + offsetof(Sha2MainMemoryCols, FIELD) +#define SHA2_MAIN_COL_INDEX_MEM_V(V, FIELD) \ + (SHA2_MAIN_COL_INDEX_V(V, Sha2MainCols, mem) + \ + SHA2_MAIN_COL_INDEX_V(V, Sha2MainMemoryCols, FIELD)) + +#define SHA2_MAIN_WRITE_BLOCK(V, ROW, FIELD, VALUE) \ + (ROW).write(SHA2_MAIN_COL_INDEX_BLOCK_V(V, FIELD), VALUE) +#define SHA2_MAIN_WRITE_ARRAY_BLOCK(V, ROW, FIELD, VALUES) \ + (ROW).write_array( \ + SHA2_MAIN_COL_INDEX_BLOCK_V(V, FIELD), \ + SHA2_MAIN_COL_ARRAY_LEN_V(V, Sha2MainBlockCols, FIELD), \ + VALUES \ + ) + +#define SHA2_MAIN_WRITE_INSTR(V, ROW, FIELD, VALUE) \ + (ROW).write(SHA2_MAIN_COL_INDEX_INSTR(V, FIELD), VALUE) +#define SHA2_MAIN_WRITE_ARRAY_INSTR(V, ROW, FIELD, VALUES) \ + (ROW).write_array( \ + SHA2_MAIN_COL_INDEX_INSTR(V, FIELD), \ + SHA2_MAIN_COL_ARRAY_LEN_PLAIN(Sha2MainInstructionCols, FIELD), \ + VALUES \ + ) + +#define SHA2_MAIN_SLICE_MEM(V, ROW, FIELD) \ + (ROW).slice_from(SHA2_MAIN_COL_INDEX_MEM_V(V, FIELD)) + +} // namespace sha2 diff --git a/extensions/sha2/circuit/cuda/include/main/record.cuh b/extensions/sha2/circuit/cuda/include/main/record.cuh new file mode 100644 index 0000000000..570946d8b5 --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/main/record.cuh @@ -0,0 +1,70 @@ +#pragma once + +#include "block_hasher/variant.cuh" +#include "system/memory/offline_checker.cuh" +#include + +namespace sha2 { + +template struct Sha2RecordHeader { + uint32_t variant; + uint32_t from_pc; + uint32_t timestamp; + uint32_t dst_reg_ptr; + uint32_t state_reg_ptr; + uint32_t input_reg_ptr; + uint32_t dst_ptr; + uint32_t state_ptr; + uint32_t input_ptr; + MemoryReadAuxRecord register_reads_aux[sha2::SHA2_REGISTER_READS]; +}; + +template struct Sha2RecordMut { + Sha2RecordHeader *header; + uint8_t *message_bytes; + uint8_t *prev_state; + uint8_t *new_state; + MemoryReadAuxRecord *input_reads_aux; + MemoryReadAuxRecord *state_reads_aux; + MemoryWriteBytesAuxRecord *write_aux; + + __device__ __host__ __forceinline__ static uint32_t next_multiple_of( + uint32_t value, + uint32_t alignment + ) { + return ((value + alignment - 1) / alignment) * alignment; + } + + __device__ __host__ __forceinline__ Sha2RecordMut(uint8_t *record_buf) { + header = reinterpret_cast *>(record_buf); + uint32_t offset = sizeof(Sha2RecordHeader); + + message_bytes = record_buf + offset; + offset += V::BLOCK_U8S; + + prev_state = record_buf + offset; + offset += V::STATE_BYTES; + + new_state = record_buf + offset; + offset += V::STATE_BYTES; + + offset = next_multiple_of(offset, alignof(MemoryReadAuxRecord)); + + input_reads_aux = reinterpret_cast(record_buf + offset); + offset += V::BLOCK_READS * sizeof(MemoryReadAuxRecord); + + offset = next_multiple_of(offset, alignof(MemoryReadAuxRecord)); + + state_reads_aux = reinterpret_cast(record_buf + offset); + offset += V::STATE_READS * sizeof(MemoryReadAuxRecord); + + offset = next_multiple_of( + offset, alignof(MemoryWriteBytesAuxRecord) + ); + write_aux = reinterpret_cast *>( + record_buf + offset + ); + } +}; + +} // namespace sha2 diff --git a/extensions/sha2/circuit/cuda/include/variant.cuh b/extensions/sha2/circuit/cuda/include/variant.cuh new file mode 100644 index 0000000000..7f46878412 --- /dev/null +++ b/extensions/sha2/circuit/cuda/include/variant.cuh @@ -0,0 +1,6 @@ +#pragma once + +// Expose the shared SHA-2 variant definitions at the top-level include dir so both +// main/ and block_hasher/ headers can include "variant.cuh". +#include "block_hasher/variant.cuh" + diff --git a/extensions/sha2/circuit/cuda/src/sha2_hasher.cu b/extensions/sha2/circuit/cuda/src/sha2_hasher.cu new file mode 100644 index 0000000000..fb37585362 --- /dev/null +++ b/extensions/sha2/circuit/cuda/src/sha2_hasher.cu @@ -0,0 +1,1066 @@ +#define SHA2_DEFINE_DEVICE_CONSTANTS + +#include "block_hasher/columns.cuh" +#include "block_hasher/record.cuh" +#include "block_hasher/variant.cuh" +#include "fp.h" +#include "launcher.cuh" +#include "primitives/encoder.cuh" +#include "primitives/histogram.cuh" +#include "primitives/trace_access.h" +#include +#include +#include + +using namespace riscv; +using namespace sha2; + +// === Utility helpers for SHA-2 block hasher === +template +__device__ __forceinline__ typename V::Word word_from_bytes_be(const uint8_t *bytes) { + typename V::Word acc = 0; +#pragma unroll + for (int i = 0; i < static_cast(V::WORD_U8S); i++) { + acc = (acc << 8) | static_cast(bytes[i]); + } + return acc; +} + +template +__device__ __forceinline__ typename V::Word word_from_bytes_le(const uint8_t *bytes) { + typename V::Word acc = 0; +#pragma unroll + for (int i = static_cast(V::WORD_U8S) - 1; i >= 0; i--) { + acc = (acc << 8) | static_cast(bytes[i]); + } + return acc; +} + +template +__device__ __forceinline__ uint32_t word_to_u16_limb(typename V::Word w, int limb) { + return static_cast((w >> (16 * limb)) & static_cast(0xFFFF)); +} + +template +__device__ __forceinline__ uint32_t word_to_u8_limb(typename V::Word w, int limb) { + return static_cast((w >> (8 * limb)) & static_cast(0xFF)); +} + +// === Shared helpers mirroring the CPU tracegen structure === +template struct Sha2TraceHelper { + Encoder row_idx_encoder; + + __device__ Sha2TraceHelper() + : row_idx_encoder(static_cast(V::ROWS_PER_BLOCK + 1), 2, false) {} + + __device__ __forceinline__ size_t base_a(uint32_t row_idx) const { + return SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.a[row_idx]); + } + + __device__ __forceinline__ size_t base_e(uint32_t row_idx) const { + return SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.e[row_idx]); + } + + __device__ __forceinline__ size_t base_carry_a(uint32_t row_idx) const { + return SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.carry_a[row_idx]); + } + + __device__ __forceinline__ size_t base_carry_e(uint32_t row_idx) const { + return SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.carry_e[row_idx]); + } + + __device__ __forceinline__ void read_w(RowSlice inner, uint32_t j, Fp *w_limbs) const { + size_t base = SHA2_COL_INDEX(V, Sha2RoundCols, message_schedule.w[j]); + for (int limb = 0; limb < V::WORD_U16S; limb++) { + w_limbs[limb] = Fp::zero(); + for (int bit = 0; bit < 16; bit++) { + w_limbs[limb] += inner[base + bit] * Fp(1 << bit); + } + base += 16; + } + } + + __device__ __forceinline__ Fp read_carry_fp(RowSlice inner, uint32_t i, uint32_t limb) const { + size_t base = SHA2_COL_INDEX(V, Sha2RoundCols, message_schedule.carry_or_buffer[i]); + Fp low = inner[base + limb * 2]; + Fp high = inner[base + limb * 2 + 1]; + return low + high + high; // low + 2 * high + } + + __device__ __forceinline__ void read_word_bits( + RowSlice inner, + size_t base, + Fp *dst_bits + ) const { +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst_bits[bit] = inner[base + bit]; + } + } + + __device__ __forceinline__ void read_w_bits(RowSlice inner, uint32_t j, Fp *dst_bits) const { + read_word_bits(inner, SHA2_COL_INDEX(V, Sha2RoundCols, message_schedule.w[j]), dst_bits); + } + + __device__ __forceinline__ Fp xor_fp(Fp a, Fp b) const { return a + b - Fp(2) * a * b; } + + __device__ __forceinline__ Fp xor_fp(Fp a, Fp b, Fp c) const { return xor_fp(xor_fp(a, b), c); } + + __device__ __forceinline__ Fp ch_fp(Fp x, Fp y, Fp z) const { return x * y + z - x * z; } + + __device__ __forceinline__ Fp maj_fp(Fp x, Fp y, Fp z) const { + return x * y + x * z + y * z - Fp(2) * x * y * z; + } + + __device__ __forceinline__ void rotr_bits(const Fp *src, uint32_t rot, Fp *dst) const { +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = src[(bit + rot) % V::WORD_BITS]; + } + } + + __device__ __forceinline__ void shr_bits(const Fp *src, uint32_t shift, Fp *dst) const { +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = (bit + shift < V::WORD_BITS) ? src[bit + shift] : Fp::zero(); + } + } + + __device__ __forceinline__ void big_sig0_bits(const Fp *src, Fp *dst) const { + if (V::WORD_BITS == 32) { + Fp r2[V::WORD_BITS], r13[V::WORD_BITS], r22[V::WORD_BITS]; + rotr_bits(src, 2, r2); + rotr_bits(src, 13, r13); + rotr_bits(src, 22, r22); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r2[bit], r13[bit], r22[bit]); + } + } else { + Fp r28[V::WORD_BITS], r34[V::WORD_BITS], r39[V::WORD_BITS]; + rotr_bits(src, 28, r28); + rotr_bits(src, 34, r34); + rotr_bits(src, 39, r39); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r28[bit], r34[bit], r39[bit]); + } + } + } + + __device__ __forceinline__ void big_sig1_bits(const Fp *src, Fp *dst) const { + if (V::WORD_BITS == 32) { + Fp r6[V::WORD_BITS], r11[V::WORD_BITS], r25[V::WORD_BITS]; + rotr_bits(src, 6, r6); + rotr_bits(src, 11, r11); + rotr_bits(src, 25, r25); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r6[bit], r11[bit], r25[bit]); + } + } else { + Fp r14[V::WORD_BITS], r18[V::WORD_BITS], r41[V::WORD_BITS]; + rotr_bits(src, 14, r14); + rotr_bits(src, 18, r18); + rotr_bits(src, 41, r41); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r14[bit], r18[bit], r41[bit]); + } + } + } + + __device__ __forceinline__ void small_sig0_bits(const Fp *src, Fp *dst) const { + if (V::WORD_BITS == 32) { + Fp r7[V::WORD_BITS], r18[V::WORD_BITS], s3[V::WORD_BITS]; + rotr_bits(src, 7, r7); + rotr_bits(src, 18, r18); + shr_bits(src, 3, s3); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r7[bit], r18[bit], s3[bit]); + } + } else { + Fp r1[V::WORD_BITS], r8[V::WORD_BITS], s7[V::WORD_BITS]; + rotr_bits(src, 1, r1); + rotr_bits(src, 8, r8); + shr_bits(src, 7, s7); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r1[bit], r8[bit], s7[bit]); + } + } + } + + __device__ __forceinline__ void small_sig1_bits(const Fp *src, Fp *dst) const { + if (V::WORD_BITS == 32) { + Fp r17[V::WORD_BITS], r19[V::WORD_BITS], s10[V::WORD_BITS]; + rotr_bits(src, 17, r17); + rotr_bits(src, 19, r19); + shr_bits(src, 10, s10); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r17[bit], r19[bit], s10[bit]); + } + } else { + Fp r19[V::WORD_BITS], r61[V::WORD_BITS], s6[V::WORD_BITS]; + rotr_bits(src, 19, r19); + rotr_bits(src, 61, r61); + shr_bits(src, 6, s6); +#pragma unroll + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + dst[bit] = xor_fp(r19[bit], r61[bit], s6[bit]); + } + } + } + + __device__ __forceinline__ Fp compose_u16_limb(const Fp *bits, uint32_t limb) const { + Fp acc = Fp::zero(); +#pragma unroll + for (uint32_t bit = 0; bit < 16; bit++) { + acc += bits[limb * 16 + bit] * Fp(1u << bit); + } + return acc; + } + + __device__ void write_flags_round( + RowSlice inner_row, + uint32_t row_idx, + uint32_t global_block_idx + ) const { + SHA2INNER_WRITE_ROUND(V, inner_row, flags.is_round_row, Fp::one()); + SHA2INNER_WRITE_ROUND( + V, + inner_row, + flags.is_first_4_rows, + (row_idx < static_cast(V::MESSAGE_ROWS)) ? Fp::one() : Fp::zero() + ); + SHA2INNER_WRITE_ROUND(V, inner_row, flags.is_digest_row, Fp::zero()); + RowSlice row_idx_flags = + inner_row.slice_from(SHA2_COL_INDEX(V, Sha2RoundCols, flags.row_idx)); + row_idx_encoder.write_flag_pt(row_idx_flags, row_idx); + SHA2INNER_WRITE_ROUND(V, inner_row, flags.global_block_idx, Fp(global_block_idx)); + } + + __device__ void write_flags_digest( + RowSlice inner_row, + uint32_t row_idx, + uint32_t global_block_idx + ) const { + SHA2INNER_WRITE_DIGEST(V, inner_row, flags.is_round_row, Fp::zero()); + SHA2INNER_WRITE_DIGEST(V, inner_row, flags.is_first_4_rows, Fp::zero()); + SHA2INNER_WRITE_DIGEST(V, inner_row, flags.is_digest_row, Fp::one()); + RowSlice row_idx_flags = + inner_row.slice_from(SHA2_COL_INDEX(V, Sha2DigestCols, flags.row_idx)); + row_idx_encoder.write_flag_pt(row_idx_flags, row_idx); + SHA2INNER_WRITE_DIGEST(V, inner_row, flags.global_block_idx, Fp(global_block_idx)); + } + + __device__ void generate_carry_ae(RowSlice local_inner, RowSlice next_inner) const { + Fp a_bits[2 * V::ROUNDS_PER_ROW][V::WORD_BITS]; + Fp e_bits[2 * V::ROUNDS_PER_ROW][V::WORD_BITS]; +#pragma unroll + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + read_word_bits(local_inner, base_a(i), a_bits[i]); + read_word_bits(next_inner, base_a(i), a_bits[i + V::ROUNDS_PER_ROW]); + read_word_bits(local_inner, base_e(i), e_bits[i]); + read_word_bits(next_inner, base_e(i), e_bits[i + V::ROUNDS_PER_ROW]); + } + + const Fp pow16_inv = inv(Fp(1u << 16)); + + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + Fp sig_a[V::WORD_BITS]; + Fp sig_e[V::WORD_BITS]; + Fp maj_abc[V::WORD_BITS]; + Fp ch_efg[V::WORD_BITS]; + + big_sig0_bits(a_bits[i + 3], sig_a); + big_sig1_bits(e_bits[i + 3], sig_e); + for (uint32_t bit = 0; bit < V::WORD_BITS; bit++) { + maj_abc[bit] = maj_fp(a_bits[i + 3][bit], a_bits[i + 2][bit], a_bits[i + 1][bit]); + ch_efg[bit] = ch_fp(e_bits[i + 3][bit], e_bits[i + 2][bit], e_bits[i + 1][bit]); + } + + Fp prev_carry_a = Fp::zero(); + Fp prev_carry_e = Fp::zero(); + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + Fp t1_sum = compose_u16_limb(e_bits[i], limb) + compose_u16_limb(sig_e, limb) + + compose_u16_limb(ch_efg, limb); + Fp t2_sum = compose_u16_limb(sig_a, limb) + compose_u16_limb(maj_abc, limb); + Fp d_limb = compose_u16_limb(a_bits[i], limb); + Fp cur_a_limb = compose_u16_limb(a_bits[i + 4], limb); + Fp cur_e_limb = compose_u16_limb(e_bits[i + 4], limb); + + Fp e_sum = d_limb + t1_sum + + (limb == 0 ? Fp::zero() : next_inner[base_carry_e(i) + limb - 1]); + Fp a_sum = t1_sum + t2_sum + + (limb == 0 ? Fp::zero() : next_inner[base_carry_a(i) + limb - 1]); + Fp carry_e = (e_sum - cur_e_limb) * pow16_inv; + Fp carry_a = (a_sum - cur_a_limb) * pow16_inv; + + SHA2INNER_WRITE_ROUND(V, next_inner, work_vars.carry_e[i][limb], Fp(carry_e)); + SHA2INNER_WRITE_ROUND(V, next_inner, work_vars.carry_a[i][limb], Fp(carry_a)); + + prev_carry_e = carry_e; + prev_carry_a = carry_a; + } + } + } + + __device__ void generate_intermed_4(RowSlice local_inner, RowSlice next_inner) const { + Fp w_bits[2 * V::ROUNDS_PER_ROW][V::WORD_BITS]; + Fp w_limbs[2 * V::ROUNDS_PER_ROW][V::WORD_U16S]; +#pragma unroll + for (uint32_t j = 0; j < V::ROUNDS_PER_ROW; j++) { + read_w_bits(local_inner, j, w_bits[j]); + read_w_bits(next_inner, j, w_bits[j + V::ROUNDS_PER_ROW]); + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + w_limbs[j][limb] = compose_u16_limb(w_bits[j], limb); + w_limbs[j + V::ROUNDS_PER_ROW][limb] = + compose_u16_limb(w_bits[j + V::ROUNDS_PER_ROW], limb); + } + } + + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + Fp sig_bits[V::WORD_BITS]; + Fp sig_limbs[V::WORD_U16S]; + + small_sig0_bits(w_bits[i + 1], sig_bits); + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + sig_limbs[limb] = compose_u16_limb(sig_bits, limb); + } +#pragma unroll + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + Fp val = w_limbs[i][limb] + sig_limbs[limb]; + SHA2INNER_WRITE_ROUND(V, next_inner, schedule_helper.intermed_4[i][limb], val); + } + } + } + + __device__ void generate_intermed_12(RowSlice local_inner, RowSlice next_inner) const { + Fp w_bits[2 * V::ROUNDS_PER_ROW][V::WORD_BITS]; + Fp w_limbs[2 * V::ROUNDS_PER_ROW][V::WORD_U16S]; +#pragma unroll + for (uint32_t j = 0; j < V::ROUNDS_PER_ROW; j++) { + read_w_bits(local_inner, j, w_bits[j]); + read_w_bits(next_inner, j, w_bits[j + V::ROUNDS_PER_ROW]); + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + w_limbs[j][limb] = compose_u16_limb(w_bits[j], limb); + w_limbs[j + V::ROUNDS_PER_ROW][limb] = + compose_u16_limb(w_bits[j + V::ROUNDS_PER_ROW], limb); + } + } + + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + Fp sig_bits[V::WORD_BITS]; + Fp sig_limbs[V::WORD_U16S]; + + small_sig1_bits(w_bits[i + 2], sig_bits); + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + sig_limbs[limb] = compose_u16_limb(sig_bits, limb); + } + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + Fp carry = read_carry_fp(next_inner, i, limb); + Fp prev_carry = (limb > 0) ? read_carry_fp(next_inner, i, limb - 1) : Fp::zero(); + Fp w7_limb = (i < 3) + ? local_inner[SHA2_COL_INDEX( + V, Sha2RoundCols, schedule_helper.w_3[i][limb] + )] + : w_limbs[i - 3][limb]; + Fp w_cur = w_limbs[i + 4][limb]; + Fp sum = sig_limbs[limb] + w7_limb - carry * Fp(1u << 16) - w_cur + prev_carry; + Fp intermed = -sum; + SHA2INNER_WRITE_ROUND( + V, local_inner, schedule_helper.intermed_12[i][limb], intermed + ); + } + } + } + + __device__ void generate_default_row( + RowSlice inner_row, + const typename V::Word *first_block_prev_hash, + Fp *carry_a, + Fp *carry_e, + size_t trace_height + ) const { + RowSlice row_idx_flags = + inner_row.slice_from(SHA2_COL_INDEX(V, Sha2RoundCols, flags.row_idx)); + row_idx_encoder.write_flag_pt(row_idx_flags, V::ROWS_PER_BLOCK); + + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + uint32_t a_idx = V::ROUNDS_PER_ROW - i - 1; + uint32_t e_idx = V::ROUNDS_PER_ROW - i + 3; + SHA2INNER_WRITE_BITS_ROUND(V, inner_row, work_vars.a[i], first_block_prev_hash[a_idx]); + SHA2INNER_WRITE_BITS_ROUND(V, inner_row, work_vars.e[i], first_block_prev_hash[e_idx]); + } + + if (carry_a && carry_e) { + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + SHA2INNER_WRITE_ROUND( + V, + inner_row, + work_vars.carry_a[i][limb], + carry_a[(i * V::WORD_U16S + limb) * trace_height] + ); + SHA2INNER_WRITE_ROUND( + V, + inner_row, + work_vars.carry_e[i][limb], + carry_e[(i * V::WORD_U16S + limb) * trace_height] + ); + } + } + } + } + + __device__ void generate_missing_cells( + Fp *trace, + size_t trace_height, + uint32_t block_idx + ) const { + trace += 1; // skip the first row of the trace + uint32_t block_row_base = block_idx * V::ROWS_PER_BLOCK; + uint32_t last_round_row = block_row_base + (V::ROUND_ROWS - 2); + uint32_t digest_row = block_row_base + (V::ROUND_ROWS - 1); + uint32_t next_block_row_base = block_row_base + V::ROUND_ROWS; + + if (last_round_row >= trace_height || digest_row >= trace_height || + next_block_row_base >= trace_height) { + return; + } + + RowSlice last_round_row_slice(trace + last_round_row, trace_height); + RowSlice last_round_inner = + last_round_row_slice.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + RowSlice digest_row_slice(trace + digest_row, trace_height); + RowSlice digest_inner = digest_row_slice.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + RowSlice next_row_slice(trace + next_block_row_base, trace_height); + RowSlice next_inner = next_row_slice.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + generate_intermed_12(last_round_inner, digest_inner); + generate_intermed_12(digest_inner, next_inner); + generate_intermed_4(digest_inner, next_inner); + } +}; + +// ===== BLOCK HASHER KERNELS ===== +template +__global__ void sha2_hash_computation( + uint8_t *records, + size_t num_records, + size_t *record_offsets, + typename V::Word *prev_hashes, + uint32_t total_num_blocks +) { + uint32_t block_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (block_idx >= total_num_blocks) { + return; + } + + if (block_idx >= num_records) { + return; + } + + uint32_t offset = record_offsets[block_idx]; + Sha2BlockRecordMut record(records + offset); +#pragma unroll + for (int i = 0; i < static_cast(V::HASH_WORDS); i++) { + const uint8_t *ptr = record.prev_state + i * V::WORD_U8S; + prev_hashes[block_idx * V::HASH_WORDS + i] = word_from_bytes_le(ptr); + } +} + +template +__global__ void sha2_first_pass_tracegen( + Fp *trace, + size_t trace_height, + uint8_t *records, + size_t num_records, + size_t *record_offsets, + uint32_t total_num_blocks, + typename V::Word *prev_hashes, + uint32_t /*ptr_max_bits*/, + uint32_t * /*range_checker_ptr*/, + uint32_t /*range_checker_num_bins*/, + uint32_t *bitwise_lookup_ptr, + uint32_t bitwise_num_bits, + uint32_t /*timestamp_max_bits*/ +) { + uint32_t global_block_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (global_block_idx >= total_num_blocks) { + return; + } + + uint32_t record_idx = global_block_idx; + if (record_idx >= num_records) { + return; + } + + uint32_t trace_start_row = global_block_idx * V::ROWS_PER_BLOCK; + if (trace_start_row + V::ROWS_PER_BLOCK > trace_height) { + return; + } + + Sha2TraceHelper helper; + Sha2BlockRecordMut record(records + record_offsets[record_idx]); + const typename V::Word *prev_hash = prev_hashes + global_block_idx * V::HASH_WORDS; + const typename V::Word *next_block_prev_hash = + prev_hashes + ((global_block_idx + 1) % total_num_blocks) * V::HASH_WORDS; + + BitwiseOperationLookup bitwise_lookup(bitwise_lookup_ptr, bitwise_num_bits); + + typename V::Word w_schedule[V::ROUNDS_PER_BLOCK] = {}; +#pragma unroll + for (int i = 0; i < static_cast(V::BLOCK_WORDS); i++) { + w_schedule[i] = word_from_bytes_be(record.message_bytes + i * V::WORD_U8S); + } + + typename V::Word a = prev_hash[0]; + typename V::Word b = prev_hash[1]; + typename V::Word c = prev_hash[2]; + typename V::Word d = prev_hash[3]; + typename V::Word e = prev_hash[4]; + typename V::Word f = prev_hash[5]; + typename V::Word g = prev_hash[6]; + typename V::Word h = prev_hash[7]; + + for (uint32_t row_in_block = 0; row_in_block < V::ROWS_PER_BLOCK; row_in_block++) { + uint32_t absolute_row = trace_start_row + row_in_block; + if (absolute_row >= trace_height) { + return; + } + + RowSlice row(trace + absolute_row, trace_height); + row.fill_zero(0, Sha2Layout::WIDTH); + + if (row_in_block < V::ROUND_ROWS) { + SHA2_WRITE_ROUND(V, row, request_id, Fp(record_idx)); + RowSlice inner_row = row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + helper.write_flags_round(inner_row, row_in_block, global_block_idx + 1); + + for (uint32_t j = 0; j < V::ROUNDS_PER_ROW; j++) { + uint32_t t = row_in_block * V::ROUNDS_PER_ROW + j; + typename V::Word w_val; + if (t < V::BLOCK_WORDS) { + w_val = w_schedule[t]; + } else { + typename V::Word nums[4] = { + sha2::small_sig1(w_schedule[t - 2]), + w_schedule[t - 7], + sha2::small_sig0(w_schedule[t - 15]), + w_schedule[t - 16], + }; + w_val = nums[0] + nums[1] + nums[2] + nums[3]; + w_schedule[t] = w_val; + +#pragma unroll + for (int limb = 0; limb < static_cast(V::WORD_U16S); limb++) { + uint32_t sum = 0; +#pragma unroll + for (auto num : nums) { + sum += word_to_u16_limb(num, limb); + } + if (limb > 0) { + size_t carry_base = SHA2_COL_INDEX( + V, Sha2RoundCols, message_schedule.carry_or_buffer[j] + ); + sum += inner_row[carry_base + limb * 2 - 2].asUInt32() + + (inner_row[carry_base + limb * 2 - 1].asUInt32() << 1); + } + uint32_t carry = (sum - word_to_u16_limb(w_val, limb)) >> 16; + SHA2INNER_WRITE_ROUND( + V, + inner_row, + message_schedule.carry_or_buffer[j][limb * 2], + Fp(carry & 1) + ); + SHA2INNER_WRITE_ROUND( + V, + inner_row, + message_schedule.carry_or_buffer[j][limb * 2 + 1], + Fp((carry >> 1) & 1) + ); + } + } + + SHA2_WRITE_BITS(V, inner_row, Sha2RoundCols, message_schedule.w[j], w_val); + + typename V::Word t1 = + h + sha2::big_sig1(e) + sha2::ch(e, f, g) + V::K(t) + w_val; + typename V::Word t2 = sha2::big_sig0(a) + sha2::maj(a, b, c); + + typename V::Word new_e = d + t1; + typename V::Word new_a = t1 + t2; + + SHA2_WRITE_BITS(V, inner_row, Sha2RoundCols, work_vars.e[j], new_e); + SHA2_WRITE_BITS(V, inner_row, Sha2RoundCols, work_vars.a[j], new_a); + +#pragma unroll + for (int limb = 0; limb < static_cast(V::WORD_U16S); limb++) { + uint32_t t1_limb = word_to_u16_limb(h, limb) + + word_to_u16_limb(sha2::big_sig1(e), limb) + + word_to_u16_limb(sha2::ch(e, f, g), limb) + + word_to_u16_limb(V::K(t), limb) + + word_to_u16_limb(w_val, limb); + uint32_t t2_limb = word_to_u16_limb(sha2::big_sig0(a), limb) + + word_to_u16_limb(sha2::maj(a, b, c), limb); + + uint32_t prev_carry_e = + (limb > 0) ? inner_row[SHA2_COL_INDEX( + V, Sha2RoundCols, work_vars.carry_e[j][limb - 1] + )] + .asUInt32() + : 0; + uint32_t prev_carry_a = + (limb > 0) ? inner_row[SHA2_COL_INDEX( + V, Sha2RoundCols, work_vars.carry_a[j][limb - 1] + )] + .asUInt32() + : 0; + uint32_t e_sum = t1_limb + word_to_u16_limb(d, limb) + prev_carry_e; + uint32_t a_sum = t1_limb + t2_limb + prev_carry_a; + uint32_t c_e = (e_sum - word_to_u16_limb(new_e, limb)) >> 16; + uint32_t c_a = (a_sum - word_to_u16_limb(new_a, limb)) >> 16; + SHA2INNER_WRITE_ROUND(V, inner_row, work_vars.carry_e[j][limb], Fp(c_e)); + SHA2INNER_WRITE_ROUND(V, inner_row, work_vars.carry_a[j][limb], Fp(c_a)); + bitwise_lookup.add_range(c_a, c_e); + } + + if (row_in_block > 0) { + typename V::Word w_4 = w_schedule[t - 4]; + typename V::Word sig0_w3 = sha2::small_sig0(w_schedule[t - 3]); +#pragma unroll + for (int limb = 0; limb < static_cast(V::WORD_U16S); limb++) { + uint32_t val = + word_to_u16_limb(w_4, limb) + word_to_u16_limb(sig0_w3, limb); + SHA2INNER_WRITE_ROUND( + V, inner_row, schedule_helper.intermed_4[j][limb], Fp(val) + ); + } + if (j < V::ROUNDS_PER_ROW - 1) { + typename V::Word w3 = w_schedule[t - 3]; +#pragma unroll + for (int limb = 0; limb < static_cast(V::WORD_U16S); limb++) { + SHA2INNER_WRITE_ROUND( + V, + inner_row, + schedule_helper.w_3[j][limb], + Fp(word_to_u16_limb(w3, limb)) + ); + } + } + } + + h = g; + g = f; + f = e; + e = new_e; + d = c; + c = b; + b = a; + a = new_a; + } + } else { + uint32_t digest_row_idx = V::ROUND_ROWS; + SHA2_WRITE_DIGEST(V, row, request_id, Fp(record_idx)); + + RowSlice inner_row = row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + helper.write_flags_digest(inner_row, digest_row_idx, global_block_idx + 1); + + for (uint32_t j = 0; j < V::ROUNDS_PER_ROW - 1; j++) { + typename V::Word val = w_schedule[row_in_block * V::ROUNDS_PER_ROW + j - 3]; + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + SHA2INNER_WRITE_DIGEST( + V, + inner_row, + schedule_helper.w_3[j][limb], + Fp(word_to_u16_limb(val, limb)) + ); + } + } + + typename V::Word final_hash[V::HASH_WORDS]; + for (int i = 0; i < static_cast(V::HASH_WORDS); i++) { + typename V::Word work_val = + (i == 0) + ? a + : (i == 1 + ? b + : (i == 2 + ? c + : (i == 3 ? d + : (i == 4 ? e : (i == 5 ? f : (i == 6 ? g : h)))))); + final_hash[i] = prev_hash[i] + work_val; + for (uint32_t limb = 0; limb < V::WORD_U8S; limb++) { + SHA2INNER_WRITE_DIGEST( + V, + inner_row, + final_hash[i][limb], + Fp(word_to_u8_limb(final_hash[i], limb)) + ); + } + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + SHA2INNER_WRITE_DIGEST( + V, + inner_row, + prev_hash[i][limb], + Fp(word_to_u16_limb(prev_hash[i], limb)) + ); + } + + for (uint32_t limb = 0; limb < V::WORD_U8S; limb += 2) { + uint32_t b0 = word_to_u8_limb(final_hash[i], limb); + uint32_t b1 = word_to_u8_limb(final_hash[i], limb + 1); + bitwise_lookup.add_range(b0, b1); + } + } + + for (uint32_t i = 0; i < V::ROUNDS_PER_ROW; i++) { + uint32_t a_idx = V::ROUNDS_PER_ROW - i - 1; + uint32_t e_idx = V::ROUNDS_PER_ROW - i + 3; + SHA2_WRITE_BITS( + V, inner_row, Sha2DigestCols, hash.a[i], next_block_prev_hash[a_idx] + ); + SHA2_WRITE_BITS( + V, inner_row, Sha2DigestCols, hash.e[i], next_block_prev_hash[e_idx] + ); + } + } + } + + for (uint32_t row_in_block = 0; row_in_block < V::ROWS_PER_BLOCK - 1; row_in_block++) { + uint32_t absolute_row = trace_start_row + row_in_block; + RowSlice local_row(trace + absolute_row, trace_height); + RowSlice next_row(trace + absolute_row + 1, trace_height); + RowSlice local_inner = local_row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + RowSlice next_inner = next_row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + if (row_in_block > 0) { + for (uint32_t j = 0; j < V::ROUNDS_PER_ROW; j++) { + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + Fp intermed_4_val = local_inner[SHA2_COL_INDEX( + V, Sha2RoundCols, schedule_helper.intermed_4[j][limb] + )]; + if (row_in_block + 1 == V::ROWS_PER_BLOCK - 1) { + SHA2INNER_WRITE_DIGEST( + V, next_inner, schedule_helper.intermed_8[j][limb], intermed_4_val + ); + } else { + SHA2INNER_WRITE_ROUND( + V, next_inner, schedule_helper.intermed_8[j][limb], intermed_4_val + ); + } + + if (row_in_block >= 2 && row_in_block < V::ROWS_PER_BLOCK - 3) { + Fp intermed_8_val = local_inner[SHA2_COL_INDEX( + V, Sha2RoundCols, schedule_helper.intermed_8[j][limb] + )]; + SHA2INNER_WRITE_ROUND( + V, next_inner, schedule_helper.intermed_12[j][limb], intermed_8_val + ); + } + } + } + } + + if (row_in_block == V::ROWS_PER_BLOCK - 2) { + helper.generate_carry_ae(local_inner, next_inner); + helper.generate_intermed_4(local_inner, next_inner); + } + + if (row_in_block < V::MESSAGE_ROWS - 1) { + helper.generate_intermed_12(local_inner, next_inner); + } + } +} + +template +__global__ void sha2_fill_first_dummy_row(Fp *trace, size_t trace_height, size_t rows_used) { + uint32_t row_idx = rows_used; + + uint32_t digest_row = V::ROUND_ROWS; + if (digest_row >= trace_height) { + return; + } + + RowSlice digest(trace + digest_row, trace_height); + RowSlice digest_inner = digest.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + typename V::Word prev_hash[V::HASH_WORDS]; + for (uint32_t i = 0; i < V::HASH_WORDS; i++) { + typename V::Word acc = 0; + for (uint32_t limb = 0; limb < V::WORD_U16S; limb++) { + size_t base = SHA2_COL_INDEX(V, Sha2DigestCols, prev_hash[i][limb]); + uint32_t limb_val = digest_inner[base].asUInt32(); + acc |= static_cast(limb_val) << (16 * limb); + } + prev_hash[i] = acc; + } + + RowSlice row(trace + row_idx, trace_height); + uint32_t intermed_4_offset = + SHA2_COL_INDEX(V, Sha2BlockHasherRoundCols, inner.schedule_helper.intermed_4); + uint32_t intermed_8_offset = + SHA2_COL_INDEX(V, Sha2BlockHasherRoundCols, inner.schedule_helper.intermed_8); + row.fill_zero(0, intermed_4_offset); + row.fill_zero(intermed_8_offset, Sha2Layout::WIDTH - intermed_8_offset); + SHA2_WRITE_ROUND(V, row, request_id, Fp::zero()); + RowSlice inner_row = row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + Sha2TraceHelper helper; + helper.generate_default_row(inner_row, prev_hash, nullptr, nullptr, trace_height); + + helper.generate_carry_ae(inner_row, inner_row); +} + +template +__global__ void sha2_second_pass_dependencies(Fp *trace, size_t trace_height, size_t total_blocks) { + uint32_t block_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (block_idx >= total_blocks) { + return; + } + + Sha2TraceHelper helper; + helper.generate_missing_cells(trace, trace_height, block_idx); +} + +template +__global__ void sha2_fill_invalid_rows( + Fp *d_trace, + size_t trace_height, + size_t rows_used, + typename V::Word *d_prev_hashes +) { + uint32_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t first_dummy_row_idx = rows_used; + // skip the first dummy row, since it is already filled + uint32_t row_idx = first_dummy_row_idx + thread_idx + 1; + if (row_idx >= trace_height) { + return; + } + + RowSlice first_dummy_row(d_trace + first_dummy_row_idx, trace_height); + RowSlice first_dummy_row_inner = first_dummy_row.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + Fp *first_dummy_row_carry_a = + &first_dummy_row_inner[SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.carry_a)]; + Fp *first_dummy_row_carry_e = + &first_dummy_row_inner[SHA2_COL_INDEX(V, Sha2RoundCols, work_vars.carry_e)]; + + RowSlice dst(d_trace + row_idx, trace_height); + dst.fill_zero(0, Sha2Layout::WIDTH); + RowSlice dst_inner = dst.slice_from(Sha2Layout::INNER_COLUMN_OFFSET); + + Sha2TraceHelper helper; + helper.generate_default_row( + dst_inner, &d_prev_hashes[0], first_dummy_row_carry_a, first_dummy_row_carry_e, trace_height + ); +} + +// ===== HOST LAUNCHER FUNCTIONS ===== + +template +int launch_sha2_hash_computation( + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + typename V::Word *d_prev_hashes, + uint32_t total_num_blocks +) { + auto [grid_size, block_size] = kernel_launch_params(num_records, 256); + + sha2_hash_computation<<>>( + d_records, num_records, d_record_offsets, d_prev_hashes, total_num_blocks + ); + + return CHECK_KERNEL(); +} + +template +int launch_sha2_first_pass_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t total_num_blocks, + typename V::Word *d_prev_hashes, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + auto [grid_size, block_size] = kernel_launch_params(total_num_blocks, 256); + + sha2_first_pass_tracegen<<>>( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + total_num_blocks, + d_prev_hashes, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); + + return CHECK_KERNEL(); +} + +template +int launch_sha2_second_pass_dependencies(Fp *d_trace, size_t trace_height, size_t rows_used) { + size_t total_blocks = rows_used / V::ROWS_PER_BLOCK; + auto [grid_size, block_size] = kernel_launch_params(total_blocks, 256); + sha2_second_pass_dependencies + <<>>(d_trace, trace_height, total_blocks); + return CHECK_KERNEL(); +} + +template +int launch_sha2_fill_invalid_rows( + Fp *d_trace, + size_t trace_height, + size_t rows_used, + typename V::Word *d_prev_hashes +) { + sha2_fill_first_dummy_row<<<1, 1>>>(d_trace, trace_height, rows_used); + if (CHECK_KERNEL() != 0) { + return -1; + } + + auto [grid_size, block_size] = kernel_launch_params(trace_height - rows_used, 256); + sha2_fill_invalid_rows + <<>>(d_trace, trace_height, rows_used, d_prev_hashes); + return CHECK_KERNEL(); +} + +// Explicit instantiations for SHA-256 and SHA-512 +extern "C" { +int launch_sha256_hash_computation( + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t *d_prev_hashes, + uint32_t total_num_blocks +) { + return launch_sha2_hash_computation( + d_records, + num_records, + d_record_offsets, + reinterpret_cast(d_prev_hashes), + total_num_blocks + ); +} + +int launch_sha512_hash_computation( + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint64_t *d_prev_hashes, + uint32_t total_num_blocks +) { + return launch_sha2_hash_computation( + d_records, num_records, d_record_offsets, d_prev_hashes, total_num_blocks + ); +} + +int launch_sha256_first_pass_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t total_num_blocks, + uint32_t *d_prev_hashes, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + return launch_sha2_first_pass_tracegen( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + total_num_blocks, + d_prev_hashes, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); +} + +int launch_sha512_first_pass_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t total_num_blocks, + uint64_t *d_prev_hashes, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + return launch_sha2_first_pass_tracegen( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + total_num_blocks, + d_prev_hashes, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); +} + +int launch_sha256_second_pass_dependencies(Fp *d_trace, size_t trace_height, size_t rows_used) { + return launch_sha2_second_pass_dependencies(d_trace, trace_height, rows_used); +} +int launch_sha512_second_pass_dependencies(Fp *d_trace, size_t trace_height, size_t rows_used) { + return launch_sha2_second_pass_dependencies(d_trace, trace_height, rows_used); +} +int launch_sha256_fill_invalid_rows( + Fp *d_trace, + size_t trace_height, + size_t rows_used, + uint32_t *d_prev_hashes +) { + return launch_sha2_fill_invalid_rows( + d_trace, trace_height, rows_used, d_prev_hashes + ); +} +int launch_sha512_fill_invalid_rows( + Fp *d_trace, + size_t trace_height, + size_t rows_used, + uint64_t *d_prev_hashes +) { + return launch_sha2_fill_invalid_rows( + d_trace, trace_height, rows_used, d_prev_hashes + ); +} +} diff --git a/extensions/sha2/circuit/cuda/src/sha2_main.cu b/extensions/sha2/circuit/cuda/src/sha2_main.cu new file mode 100644 index 0000000000..dde74c5826 --- /dev/null +++ b/extensions/sha2/circuit/cuda/src/sha2_main.cu @@ -0,0 +1,219 @@ +#include "block_hasher/variant.cuh" +#include "fp.h" +#include "launcher.cuh" +#include "main/columns.cuh" +#include "main/record.cuh" +#include "primitives/constants.h" +#include "primitives/histogram.cuh" +#include "primitives/trace_access.h" +#include "system/memory/controller.cuh" +#include "system/memory/offline_checker.cuh" +#include +#include +#include + +using namespace riscv; +using namespace sha2; + +// ===== MAIN CHIP KERNEL ===== +template +__global__ void sha2_main_tracegen( + Fp *trace, + size_t trace_height, + uint8_t *records, + size_t num_records, + size_t *record_offsets, + uint32_t ptr_max_bits, + uint32_t *range_checker_ptr, + uint32_t range_checker_num_bins, + uint32_t *bitwise_lookup_ptr, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + uint32_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx >= trace_height) { + return; + } + + RowSlice row(trace + row_idx, trace_height); + row.fill_zero(0, sha2::Sha2MainLayout::WIDTH); + + if (row_idx >= num_records) { + return; + } + + Sha2RecordMut record(records + record_offsets[row_idx]); + Sha2RecordHeader *header = record.header; + + BitwiseOperationLookup bitwise_lookup(bitwise_lookup_ptr, bitwise_num_bits); + MemoryAuxColsFactory mem_helper( + VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits + ); + + // Block cols + SHA2_MAIN_WRITE_BLOCK(V, row, request_id, Fp(row_idx)); + SHA2_MAIN_WRITE_ARRAY_BLOCK(V, row, message_bytes, record.message_bytes); + SHA2_MAIN_WRITE_ARRAY_BLOCK(V, row, prev_state, record.prev_state); + SHA2_MAIN_WRITE_ARRAY_BLOCK(V, row, new_state, record.new_state); + + // Instruction cols + SHA2_MAIN_WRITE_INSTR(V, row, is_enabled, Fp::one()); + SHA2_MAIN_WRITE_INSTR(V, row, from_state.timestamp, header->timestamp); + SHA2_MAIN_WRITE_INSTR(V, row, from_state.pc, header->from_pc); + SHA2_MAIN_WRITE_INSTR(V, row, dst_reg_ptr, header->dst_reg_ptr); + SHA2_MAIN_WRITE_INSTR(V, row, state_reg_ptr, header->state_reg_ptr); + SHA2_MAIN_WRITE_INSTR(V, row, input_reg_ptr, header->input_reg_ptr); + + uint8_t dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; + uint8_t state_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; + uint8_t input_ptr_bytes[RV32_REGISTER_NUM_LIMBS]; + memcpy(dst_ptr_bytes, &header->dst_ptr, sizeof(uint32_t)); + memcpy(state_ptr_bytes, &header->state_ptr, sizeof(uint32_t)); + memcpy(input_ptr_bytes, &header->input_ptr, sizeof(uint32_t)); + + SHA2_MAIN_WRITE_ARRAY_INSTR(V, row, dst_ptr_limbs, dst_ptr_bytes); + SHA2_MAIN_WRITE_ARRAY_INSTR(V, row, state_ptr_limbs, state_ptr_bytes); + SHA2_MAIN_WRITE_ARRAY_INSTR(V, row, input_ptr_limbs, input_ptr_bytes); + + // Range checks on top limbs + uint8_t needs_range_check[4] = { + dst_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], + state_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_bytes[RV32_REGISTER_NUM_LIMBS - 1], + }; + uint32_t shift = 1u << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits); + for (int i = 0; i < 4; i += 2) { + bitwise_lookup.add_range( + static_cast(needs_range_check[i]) * shift, + static_cast(needs_range_check[i + 1]) * shift + ); + } + + // Memory aux + uint32_t timestamp = header->timestamp; + for (int i = 0; i < static_cast(sha2::SHA2_REGISTER_READS); i++) { + RowSlice reg_aux = SHA2_MAIN_SLICE_MEM(V, row, register_aux[i]); + RowSlice base_slice = reg_aux.slice_from(COL_INDEX(MemoryReadAuxCols, base)); + mem_helper.fill(base_slice, header->register_reads_aux[i].prev_timestamp, timestamp); + timestamp += 1; + } + + for (int i = 0; i < static_cast(V::BLOCK_READS); i++) { + RowSlice input_aux = SHA2_MAIN_SLICE_MEM(V, row, input_reads[i]); + RowSlice base_slice = input_aux.slice_from(COL_INDEX(MemoryReadAuxCols, base)); + mem_helper.fill(base_slice, record.input_reads_aux[i].prev_timestamp, timestamp); + timestamp += 1; + } + + for (int i = 0; i < static_cast(V::STATE_READS); i++) { + RowSlice state_aux = SHA2_MAIN_SLICE_MEM(V, row, state_reads[i]); + RowSlice base_slice = state_aux.slice_from(COL_INDEX(MemoryReadAuxCols, base)); + mem_helper.fill(base_slice, record.state_reads_aux[i].prev_timestamp, timestamp); + timestamp += 1; + } + + for (int i = 0; i < static_cast(V::STATE_WRITES); i++) { + RowSlice write_aux = SHA2_MAIN_SLICE_MEM(V, row, write_aux[i]); + write_aux.write_array( + COL_INDEX(MemoryWriteAuxCols, prev_data), + sha2::SHA2_WRITE_SIZE, + record.write_aux[i].prev_data + ); + RowSlice base_slice = write_aux.slice_from(COL_INDEX(MemoryWriteAuxCols, base)); + mem_helper.fill(base_slice, record.write_aux[i].prev_timestamp, timestamp); + timestamp += 1; + } +} + +// ===== HOST LAUNCHER FUNCTIONS ===== + +template +int launch_sha2_main_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + auto [grid_size, block_size] = kernel_launch_params(trace_height, 256); + sha2_main_tracegen<<>>( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); + return CHECK_KERNEL(); +} + +// Explicit instantiations for SHA-256 and SHA-512 +extern "C" { +int launch_sha256_main_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + return launch_sha2_main_tracegen( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); +} + +int launch_sha512_main_tracegen( + Fp *d_trace, + size_t trace_height, + uint8_t *d_records, + size_t num_records, + size_t *d_record_offsets, + uint32_t ptr_max_bits, + uint32_t *d_range_checker, + uint32_t range_checker_num_bins, + uint32_t *d_bitwise_lookup, + uint32_t bitwise_num_bits, + uint32_t timestamp_max_bits +) { + return launch_sha2_main_tracegen( + d_trace, + trace_height, + d_records, + num_records, + d_record_offsets, + ptr_max_bits, + d_range_checker, + range_checker_num_bins, + d_bitwise_lookup, + bitwise_num_bits, + timestamp_max_bits + ); +} +} diff --git a/extensions/sha2/circuit/src/cuda/cuda_abi.rs b/extensions/sha2/circuit/src/cuda/cuda_abi.rs new file mode 100644 index 0000000000..a7ba91bc5e --- /dev/null +++ b/extensions/sha2/circuit/src/cuda/cuda_abi.rs @@ -0,0 +1,326 @@ +#![allow(clippy::missing_safety_doc)] + +use openvm_cuda_backend::prelude::F; +use openvm_cuda_common::{d_buffer::DeviceBuffer, error::CudaError}; + +pub mod sha256 { + use super::*; + + extern "C" { + fn launch_sha256_main_tracegen( + d_trace: *mut F, + trace_height: usize, + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + ptr_max_bits: u32, + d_range_checker: *mut u32, + range_checker_num_bins: u32, + d_bitwise_lookup: *mut u32, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> i32; + + fn launch_sha256_hash_computation( + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + d_prev_hashes: *mut u32, + total_num_blocks: u32, + ) -> i32; + + fn launch_sha256_first_pass_tracegen( + d_trace: *mut F, + trace_height: usize, + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + total_num_blocks: u32, + d_prev_hashes: *const u32, + ptr_max_bits: u32, + d_range_checker: *mut u32, + range_checker_num_bins: u32, + d_bitwise_lookup: *mut u32, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> i32; + + fn launch_sha256_second_pass_dependencies( + d_trace: *mut F, + trace_height: usize, + rows_used: usize, + ) -> i32; + + fn launch_sha256_fill_invalid_rows( + d_trace: *mut F, + trace_height: usize, + rows_used: usize, + d_prev_hashes: *const u32, + ) -> i32; + } + + #[allow(clippy::too_many_arguments)] + pub unsafe fn sha256_main_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + ptr_max_bits: u32, + d_range_checker: &DeviceBuffer, + d_bitwise_lookup: &DeviceBuffer, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + let result = launch_sha256_main_tracegen( + d_trace.as_mut_ptr(), + height, + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + ptr_max_bits, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + d_bitwise_lookup.as_mut_ptr() as *mut u32, + bitwise_num_bits, + timestamp_max_bits, + ); + CudaError::from_result(result) + } + + pub unsafe fn sha256_hash_computation( + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + d_prev_hashes: &DeviceBuffer, + num_blocks: u32, + ) -> Result<(), CudaError> { + let result = launch_sha256_hash_computation( + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + d_prev_hashes.as_mut_ptr(), + num_blocks, + ); + CudaError::from_result(result) + } + + #[allow(clippy::too_many_arguments)] + pub unsafe fn sha256_first_pass_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + total_num_blocks: u32, + d_prev_hashes: &DeviceBuffer, + ptr_max_bits: u32, + d_range_checker: &DeviceBuffer, + d_bitwise_lookup: &DeviceBuffer, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + let result = launch_sha256_first_pass_tracegen( + d_trace.as_mut_ptr(), + height, + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + total_num_blocks, + d_prev_hashes.as_ptr(), + ptr_max_bits, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + d_bitwise_lookup.as_mut_ptr() as *mut u32, + bitwise_num_bits, + timestamp_max_bits, + ); + CudaError::from_result(result) + } + + pub unsafe fn sha256_second_pass_dependencies( + d_trace: &DeviceBuffer, + height: usize, + rows_used: usize, + ) -> Result<(), CudaError> { + let result = + launch_sha256_second_pass_dependencies(d_trace.as_mut_ptr(), height, rows_used); + CudaError::from_result(result) + } + + pub unsafe fn sha256_fill_invalid_rows( + d_trace: &DeviceBuffer, + height: usize, + rows_used: usize, + d_prev_hashes: &DeviceBuffer, + ) -> Result<(), CudaError> { + let result = launch_sha256_fill_invalid_rows( + d_trace.as_mut_ptr(), + height, + rows_used, + d_prev_hashes.as_ptr(), + ); + CudaError::from_result(result) + } +} + +pub mod sha512 { + use super::*; + + extern "C" { + fn launch_sha512_main_tracegen( + d_trace: *mut F, + trace_height: usize, + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + ptr_max_bits: u32, + d_range_checker: *mut u32, + range_checker_num_bins: u32, + d_bitwise_lookup: *mut u32, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> i32; + + fn launch_sha512_hash_computation( + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + d_prev_hashes: *mut u64, + total_num_blocks: u32, + ) -> i32; + + fn launch_sha512_first_pass_tracegen( + d_trace: *mut F, + trace_height: usize, + d_records: *const u8, + num_records: usize, + d_record_offsets: *const usize, + total_num_blocks: u32, + d_prev_hashes: *const u64, + ptr_max_bits: u32, + d_range_checker: *mut u32, + range_checker_num_bins: u32, + d_bitwise_lookup: *mut u32, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> i32; + + fn launch_sha512_second_pass_dependencies( + d_trace: *mut F, + trace_height: usize, + rows_used: usize, + ) -> i32; + + fn launch_sha512_fill_invalid_rows( + d_trace: *mut F, + trace_height: usize, + rows_used: usize, + d_prev_hashes: *const u64, + ) -> i32; + } + + #[allow(clippy::too_many_arguments)] + pub unsafe fn sha512_main_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + ptr_max_bits: u32, + d_range_checker: &DeviceBuffer, + d_bitwise_lookup: &DeviceBuffer, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + let result = launch_sha512_main_tracegen( + d_trace.as_mut_ptr(), + height, + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + ptr_max_bits, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + d_bitwise_lookup.as_mut_ptr() as *mut u32, + bitwise_num_bits, + timestamp_max_bits, + ); + CudaError::from_result(result) + } + + pub unsafe fn sha512_hash_computation( + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + d_prev_hashes: &DeviceBuffer, + num_blocks: u32, + ) -> Result<(), CudaError> { + let result = launch_sha512_hash_computation( + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + d_prev_hashes.as_mut_ptr(), + num_blocks, + ); + CudaError::from_result(result) + } + + #[allow(clippy::too_many_arguments)] + pub unsafe fn sha512_first_pass_tracegen( + d_trace: &DeviceBuffer, + height: usize, + d_records: &DeviceBuffer, + num_records: usize, + d_record_offsets: &DeviceBuffer, + total_num_blocks: u32, + d_prev_hashes: &DeviceBuffer, + ptr_max_bits: u32, + d_range_checker: &DeviceBuffer, + d_bitwise_lookup: &DeviceBuffer, + bitwise_num_bits: u32, + timestamp_max_bits: u32, + ) -> Result<(), CudaError> { + let result = launch_sha512_first_pass_tracegen( + d_trace.as_mut_ptr(), + height, + d_records.as_ptr(), + num_records, + d_record_offsets.as_ptr(), + total_num_blocks, + d_prev_hashes.as_ptr(), + ptr_max_bits, + d_range_checker.as_mut_ptr() as *mut u32, + d_range_checker.len() as u32, + d_bitwise_lookup.as_mut_ptr() as *mut u32, + bitwise_num_bits, + timestamp_max_bits, + ); + CudaError::from_result(result) + } + + pub unsafe fn sha512_second_pass_dependencies( + d_trace: &DeviceBuffer, + height: usize, + rows_used: usize, + ) -> Result<(), CudaError> { + let result = + launch_sha512_second_pass_dependencies(d_trace.as_mut_ptr(), height, rows_used); + CudaError::from_result(result) + } + + pub unsafe fn sha512_fill_invalid_rows( + d_trace: &DeviceBuffer, + height: usize, + rows_used: usize, + d_prev_hashes: &DeviceBuffer, + ) -> Result<(), CudaError> { + let result = launch_sha512_fill_invalid_rows( + d_trace.as_mut_ptr(), + height, + rows_used, + d_prev_hashes.as_ptr(), + ); + CudaError::from_result(result) + } +} diff --git a/extensions/sha2/circuit/src/cuda/mod.rs b/extensions/sha2/circuit/src/cuda/mod.rs new file mode 100644 index 0000000000..98380e446c --- /dev/null +++ b/extensions/sha2/circuit/src/cuda/mod.rs @@ -0,0 +1,289 @@ +use std::{ + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +use openvm_circuit::{ + arch::{DenseRecordArena, RecordSeeker}, + utils::next_power_of_two_or_zero, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU, +}; +use openvm_cuda_backend::{ + base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend, +}; +use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer}; +use openvm_sha2_air::{Sha256Config, Sha2Variant, Sha512Config}; +use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + +use crate::{Sha2Config, Sha2RecordLayout, Sha2RecordMut}; + +mod cuda_abi; + +pub struct Sha2SharedRecordsGpu { + d_records: DeviceBuffer, + d_record_offsets: DeviceBuffer, + num_records: usize, +} + +pub struct Sha2MainChipGpu { + records: Arc>>, + range_checker: Arc, + bitwise_lookup: Arc>, + pointer_max_bits: u32, + timestamp_max_bits: u32, + _marker: PhantomData, +} + +impl Sha2MainChipGpu { + pub fn new( + records: Arc>>, + range_checker: Arc, + bitwise_lookup: Arc>, + pointer_max_bits: u32, + timestamp_max_bits: u32, + ) -> Self { + Self { + records, + range_checker, + bitwise_lookup, + pointer_max_bits, + timestamp_max_bits, + _marker: PhantomData, + } + } +} + +impl Chip for Sha2MainChipGpu +where + C: Sha2Config, +{ + fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext { + let records = arena.allocated_mut(); + if records.is_empty() { + return get_empty_air_proving_ctx::(); + } + + let mut record_offsets = Vec::::new(); + let mut offset = 0usize; + while offset < records.len() { + record_offsets.push(offset); + let _record = + RecordSeeker::::get_record_at( + &mut offset, + records, + ); + } + + let num_records = record_offsets.len(); + let trace_height = next_power_of_two_or_zero(num_records); + let trace = DeviceMatrix::::with_capacity(trace_height, C::MAIN_CHIP_WIDTH); + + let d_records = records.to_device().unwrap(); + let d_record_offsets = record_offsets.to_device().unwrap(); + + unsafe { + match C::VARIANT { + Sha2Variant::Sha256 => { + cuda_abi::sha256::sha256_main_tracegen( + trace.buffer(), + trace_height, + &d_records, + num_records, + &d_record_offsets, + self.pointer_max_bits, + &self.range_checker.count, + &self.bitwise_lookup.count, + 8, + self.timestamp_max_bits, + ) + .unwrap(); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + cuda_abi::sha512::sha512_main_tracegen( + trace.buffer(), + trace_height, + &d_records, + num_records, + &d_record_offsets, + self.pointer_max_bits, + &self.range_checker.count, + &self.bitwise_lookup.count, + 8, + self.timestamp_max_bits, + ) + .unwrap(); + } + } + } + + // Pass the records to Sha2BlockHasherChip + *self.records.lock().unwrap() = Some(Sha2SharedRecordsGpu { + d_records, + d_record_offsets, + num_records, + }); + + AirProvingContext::simple_no_pis(trace) + } +} + +/// Generic hybrid GPU wrapper that reuses CPU block-hasher tracegen. +pub struct Sha2BlockHasherChipGpu { + records: Arc>>, + range_checker: Arc, + bitwise_lookup: Arc>, + pointer_max_bits: u32, + timestamp_max_bits: u32, + _marker: PhantomData, +} + +impl Chip for Sha2BlockHasherChipGpu +where + C: Sha2Config, +{ + /// We don't use the record arena associated with this chip. Instead, we will use the record + /// arena provided by the main chip, which will be passed to this chip after the main chip's + /// tracegen is done. + fn generate_proving_ctx(&self, _: R) -> AirProvingContext { + let mut records = self.records.lock().unwrap(); + if records.is_none() { + return get_empty_air_proving_ctx::(); + } + + let Sha2SharedRecordsGpu { + d_records, + d_record_offsets, + num_records, + } = records.take().unwrap(); + + if num_records == 0 { + return get_empty_air_proving_ctx::(); + } + + let rows_used = num_records * C::ROWS_PER_BLOCK; + let trace_height = next_power_of_two_or_zero(rows_used); + let trace = DeviceMatrix::::with_capacity(trace_height, C::BLOCK_HASHER_WIDTH); + + // one record per block, right now + let num_blocks: u32 = num_records as u32; + + // prev_hashes + unsafe { + match C::VARIANT { + Sha2Variant::Sha256 => { + let d_prev_hashes = + DeviceBuffer::::with_capacity(num_blocks as usize * C::HASH_WORDS); + cuda_abi::sha256::sha256_hash_computation( + &d_records, + num_records, + &d_record_offsets, + &d_prev_hashes, + num_blocks, + ) + .unwrap(); + + cuda_abi::sha256::sha256_first_pass_tracegen( + trace.buffer(), + trace_height, + &d_records, + num_records, + &d_record_offsets, + num_blocks, + &d_prev_hashes, + self.pointer_max_bits, + &self.range_checker.count, + &self.bitwise_lookup.count, + 8, + self.timestamp_max_bits, + ) + .unwrap(); + + cuda_abi::sha256::sha256_fill_invalid_rows( + trace.buffer(), + trace_height, + rows_used, + &d_prev_hashes, + ) + .unwrap(); + cuda_abi::sha256::sha256_second_pass_dependencies( + trace.buffer(), + trace_height, + rows_used, + ) + .unwrap(); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let d_prev_hashes = + DeviceBuffer::::with_capacity(num_blocks as usize * C::HASH_WORDS); + cuda_abi::sha512::sha512_hash_computation( + &d_records, + num_records, + &d_record_offsets, + &d_prev_hashes, + num_blocks, + ) + .unwrap(); + + cuda_abi::sha512::sha512_first_pass_tracegen( + trace.buffer(), + trace_height, + &d_records, + num_records, + &d_record_offsets, + num_blocks, + &d_prev_hashes, + self.pointer_max_bits, + &self.range_checker.count, + &self.bitwise_lookup.count, + 8, + self.timestamp_max_bits, + ) + .unwrap(); + + cuda_abi::sha512::sha512_fill_invalid_rows( + trace.buffer(), + trace_height, + rows_used, + &d_prev_hashes, + ) + .unwrap(); + cuda_abi::sha512::sha512_second_pass_dependencies( + trace.buffer(), + trace_height, + rows_used, + ) + .unwrap(); + } + } + } + + AirProvingContext::simple_no_pis(trace) + } +} + +impl Sha2BlockHasherChipGpu { + pub fn new( + records: Arc>>, + range_checker: Arc, + bitwise_lookup: Arc>, + pointer_max_bits: u32, + timestamp_max_bits: u32, + ) -> Self { + Self { + records, + range_checker, + bitwise_lookup, + pointer_max_bits, + timestamp_max_bits, + _marker: PhantomData, + } + } +} + +// Convenience aliases for the common SHA-2 variants. +pub type Sha256VmChipGpu = Sha2MainChipGpu; +pub type Sha256BlockHasherChipGpu = Sha2BlockHasherChipGpu; +pub type Sha512VmChipGpu = Sha2MainChipGpu; +pub type Sha512BlockHasherChipGpu = Sha2BlockHasherChipGpu; diff --git a/extensions/sha2/circuit/src/extension/cuda.rs b/extensions/sha2/circuit/src/extension/cuda.rs new file mode 100644 index 0000000000..972bbf4c2c --- /dev/null +++ b/extensions/sha2/circuit/src/extension/cuda.rs @@ -0,0 +1,117 @@ +use openvm_circuit::{ + arch::{ + AirInventory, ChipInventory, ChipInventoryError, DenseRecordArena, VmBuilder, + VmChipComplex, VmProverExtension, + }, + system::cuda::{ + extensions::{ + get_inventory_range_checker, get_or_create_bitwise_op_lookup, SystemGpuBuilder, + }, + SystemChipInventoryGPU, + }, +}; +use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend}; +use openvm_rv32im_circuit::Rv32ImGpuProverExt; +use openvm_sha2_air::{Sha256Config, Sha512Config}; +use openvm_stark_sdk::{config::baby_bear_poseidon2::BabyBearPoseidon2Config, engine::StarkEngine}; + +use super::*; +use crate::{ + cuda::{Sha2BlockHasherChipGpu, Sha2MainChipGpu}, + Sha2BlockHasherVmAir, Sha2MainAir, +}; + +pub struct Sha2GpuProverExt; + +impl VmProverExtension for Sha2GpuProverExt { + fn extend_prover( + &self, + _: &Sha2, + inventory: &mut ChipInventory, + ) -> Result<(), ChipInventoryError> { + let pointer_max_bits = inventory.airs().pointer_max_bits(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + + let range_checker_gpu = get_inventory_range_checker(inventory); + let bitwise_gpu = get_or_create_bitwise_op_lookup(inventory)?; + + // SHA-256 + inventory.next_air::>()?; + let sha256_shared_records = Arc::new(Mutex::new(None)); + let sha256_block_gpu = Sha2BlockHasherChipGpu::::new( + sha256_shared_records.clone(), + range_checker_gpu.clone(), + bitwise_gpu.clone(), + pointer_max_bits as u32, + timestamp_max_bits as u32, + ); + inventory.add_periphery_chip(sha256_block_gpu); + + inventory.next_air::>()?; + let sha256_main_gpu = Sha2MainChipGpu::::new( + sha256_shared_records, + range_checker_gpu.clone(), + bitwise_gpu.clone(), + pointer_max_bits as u32, + timestamp_max_bits as u32, + ); + inventory.add_executor_chip(sha256_main_gpu); + + // SHA-512 (also covers SHA-384 constraints) + inventory.next_air::>()?; + let sha512_shared_records = Arc::new(Mutex::new(None)); + let sha512_block_gpu = Sha2BlockHasherChipGpu::::new( + sha512_shared_records.clone(), + range_checker_gpu.clone(), + bitwise_gpu.clone(), + pointer_max_bits as u32, + timestamp_max_bits as u32, + ); + inventory.add_periphery_chip(sha512_block_gpu); + + inventory.next_air::>()?; + let sha512_main_gpu = Sha2MainChipGpu::::new( + sha512_shared_records, + range_checker_gpu, + bitwise_gpu, + pointer_max_bits as u32, + timestamp_max_bits as u32, + ); + inventory.add_executor_chip(sha512_main_gpu); + + Ok(()) + } +} + +pub struct Sha2Rv32GpuBuilder; + +type E = GpuBabyBearPoseidon2Engine; + +impl VmBuilder for Sha2Rv32GpuBuilder { + type VmConfig = Sha2Rv32Config; + type SystemChipInventory = SystemChipInventoryGPU; + type RecordArena = DenseRecordArena; + + fn create_chip_complex( + &self, + config: &Sha2Rv32Config, + circuit: AirInventory<::SC>, + ) -> Result< + VmChipComplex< + ::SC, + Self::RecordArena, + ::PB, + Self::SystemChipInventory, + >, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemGpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImGpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImGpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImGpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&Sha2GpuProverExt, &config.sha2, inventory)?; + Ok(chip_complex) + } +} diff --git a/extensions/sha2/circuit/src/extension/mod.rs b/extensions/sha2/circuit/src/extension/mod.rs new file mode 100644 index 0000000000..2429ea74b4 --- /dev/null +++ b/extensions/sha2/circuit/src/extension/mod.rs @@ -0,0 +1,281 @@ +use std::{ + result::Result, + sync::{Arc, Mutex}, +}; + +use derive_more::derive::From; +use openvm_circuit::{ + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, + ExecutorInventoryBuilder, ExecutorInventoryError, InitFileGenerator, MatrixRecordArena, + RowMajorMatrixArena, SystemConfig, VmBuilder, VmChipComplex, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{memory::SharedMemoryHelper, SystemChipInventory, SystemCpuBuilder, SystemExecutor}, +}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor, VmConfig}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_sha2_air::{Sha256Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; +use serde::{Deserialize, Serialize}; + +use crate::{Sha2BlockHasherChip, Sha2BlockHasherVmAir, Sha2MainAir, Sha2MainChip, Sha2VmExecutor}; + +cfg_if::cfg_if! { + if #[cfg(feature = "cuda")] { + mod cuda; + pub use self::cuda::*; + pub use self::cuda::Sha2GpuProverExt as Sha2ProverExt; + pub use self::cuda::Sha2Rv32GpuBuilder as Sha2Rv32Builder; + } else { + pub use self::Sha2CpuProverExt as Sha2ProverExt; + pub use self::Sha2Rv32CpuBuilder as Sha2Rv32Builder; + } +} + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Sha2Rv32Config { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub sha2: Sha2, +} + +impl Default for Sha2Rv32Config { + fn default() -> Self { + Self { + system: SystemConfig::default(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + sha2: Sha2, + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for Sha2Rv32Config {} + +#[derive(Clone)] +pub struct Sha2Rv32CpuBuilder; + +impl VmBuilder for Sha2Rv32CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Sha2Rv32Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Sha2Rv32Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha2, inventory)?; + Ok(chip_complex) + } +} + +// =================================== VM Extension Implementation ================================= +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Sha2; + +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +#[cfg_attr( + feature = "aot", + derive( + openvm_circuit_derive::AotExecutor, + openvm_circuit_derive::AotMeteredExecutor + ) +)] +pub enum Sha2Executor { + Sha256(Sha2VmExecutor), + Sha512(Sha2VmExecutor), +} + +impl VmExecutionExtension for Sha2 { + type Executor = Sha2Executor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + + let sha256_executor = + Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor(sha256_executor, [Rv32Sha2Opcode::SHA256.global_opcode()])?; + + let sha512_executor = + Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor(sha512_executor, [Rv32Sha2Opcode::SHA512.global_opcode()])?; + + Ok(()) + } +} + +impl VmCircuitExtension for Sha2 { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + + // this bus will be used for communication between the block hasher chip and the main chip + let sha2_bus_index = inventory.new_bus_idx(); + // the sha2 subair needs its own bus for self-interactions + let subair_bus_index = inventory.new_bus_idx(); + + // SHA-256 + let sha256_block_hasher_air = + Sha2BlockHasherVmAir::::new(bitwise_lu, subair_bus_index, sha2_bus_index); + inventory.add_air(sha256_block_hasher_air); + + let sha256_main_air = Sha2MainAir::::new( + inventory.system().port(), + bitwise_lu, + inventory.pointer_max_bits(), + sha2_bus_index, + Rv32Sha2Opcode::CLASS_OFFSET, + ); + inventory.add_air(sha256_main_air); + + // SHA-512 + let sha512_block_hasher_air = + Sha2BlockHasherVmAir::::new(bitwise_lu, subair_bus_index, sha2_bus_index); + inventory.add_air(sha512_block_hasher_air); + + let sha512_main_air = Sha2MainAir::::new( + inventory.system().port(), + bitwise_lu, + inventory.pointer_max_bits(), + sha2_bus_index, + Rv32Sha2Opcode::CLASS_OFFSET, + ); + inventory.add_air(sha512_main_air); + + Ok(()) + } +} + +pub struct Sha2CpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Sha2CpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena> + Send + Sync + 'static, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Sha2, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + // We must add each block hasher chip before the main chip to ensure that main chip does its + // tracegen first, because the main chip will pass the records to the block hasher chip + // after its tracegen is done. + + // SHA-256 + inventory.next_air::>()?; + // shared records between the main chip and the block hasher chip + let records = Arc::new(Mutex::new(None)); + let sha256_block_hasher_chip = Sha2BlockHasherChip::, Sha256Config>::new( + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + records.clone(), + ); + inventory.add_periphery_chip(sha256_block_hasher_chip); + + inventory.next_air::>()?; + let sha256_main_chip = Sha2MainChip::, Sha256Config>::new( + records, + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + ); + inventory.add_executor_chip(sha256_main_chip); + + // SHA-512 + inventory.next_air::>()?; + // shared records between the main chip and the block hasher chip + let records = Arc::new(Mutex::new(None)); + let sha512_block_hasher_chip = Sha2BlockHasherChip::, Sha512Config>::new( + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + records.clone(), + ); + inventory.add_periphery_chip(sha512_block_hasher_chip); + + inventory.next_air::>()?; + let sha512_main_chip = Sha2MainChip::, Sha512Config>::new( + records, + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + ); + inventory.add_executor_chip(sha512_main_chip); + + Ok(()) + } +} diff --git a/extensions/sha2/circuit/src/lib.rs b/extensions/sha2/circuit/src/lib.rs new file mode 100644 index 0000000000..226cde11e7 --- /dev/null +++ b/extensions/sha2/circuit/src/lib.rs @@ -0,0 +1,15 @@ +#![cfg_attr(feature = "tco", allow(incomplete_features))] +#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] +#![cfg_attr(feature = "tco", allow(internal_features))] +#![cfg_attr(feature = "tco", feature(core_intrinsics))] + +mod sha2_chips; +pub use sha2_chips::*; + +mod extension; +pub use extension::*; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs new file mode 100644 index 0000000000..84ae359dfa --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs @@ -0,0 +1,156 @@ +use openvm_circuit_primitives::{bitwise_op_lookup::BitwiseOperationLookupBus, SubAir}; +use openvm_sha2_air::{compose, Sha2BlockHasherSubAir}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + MessageType, Sha2BlockHasherDigestColsRef, Sha2BlockHasherRoundColsRef, + Sha2BlockHasherVmConfig, INNER_OFFSET, +}; + +pub struct Sha2BlockHasherVmAir { + pub inner: Sha2BlockHasherSubAir, + pub sha2_bus: PermutationCheckBus, +} + +impl Sha2BlockHasherVmAir { + pub fn new( + bitwise_lookup_bus: BitwiseOperationLookupBus, + inner_bus_idx: BusIndex, + sha2_bus_idx: BusIndex, + ) -> Self { + Self { + inner: Sha2BlockHasherSubAir::new(bitwise_lookup_bus, inner_bus_idx), + sha2_bus: PermutationCheckBus::new(sha2_bus_idx), + } + } +} + +impl BaseAirWithPublicValues for Sha2BlockHasherVmAir {} +impl PartitionedBaseAir for Sha2BlockHasherVmAir {} +impl BaseAir for Sha2BlockHasherVmAir { + fn width(&self) -> usize { + C::BLOCK_HASHER_WIDTH + } +} + +impl Air for Sha2BlockHasherVmAir { + fn eval(&self, builder: &mut AB) { + self.inner.eval(builder, INNER_OFFSET); + self.eval_interactions(builder); + self.eval_request_id(builder); + } +} + +impl Sha2BlockHasherVmAir { + fn eval_interactions(&self, builder: &mut AB) { + let main = builder.main(); + let local_slice = main.row_slice(0); + let next_slice = main.row_slice(1); + + let local = Sha2BlockHasherDigestColsRef::::from::( + &local_slice[..C::BLOCK_HASHER_DIGEST_WIDTH], + ); + + // Receive (STATE, request_id, prev_state_as_u16s, new_state) on the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::State as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local.inner.prev_hash.flatten().map(|x| (*x).into())) + .chain(local.inner.final_hash.flatten().map(|x| (*x).into())), + *local.inner.flags.is_digest_row, + ); + + let local = Sha2BlockHasherRoundColsRef::::from::( + &local_slice[..C::BLOCK_HASHER_ROUND_WIDTH], + ); + let next = Sha2BlockHasherRoundColsRef::::from::( + &next_slice[..C::BLOCK_HASHER_ROUND_WIDTH], + ); + + let is_local_first_row = self + .inner + .row_idx_encoder + .contains_flag::(local.inner.flags.row_idx.to_slice().unwrap(), &[0]); + + // Taken from old Sha256VmChip: + // https://github.com/openvm-org/openvm/blob/c2e376e6059c8bbf206736cf01d04cda43dfc42d/extensions/sha256/circuit/src/sha256_chip/air.rs#L310C1-L318C1 + let get_ith_byte = |i: usize, cols: &Sha2BlockHasherRoundColsRef| { + debug_assert!(i < C::WORD_U8S * C::ROUNDS_PER_ROW); + let row_idx = i / C::WORD_U8S; + let word: Vec = cols + .inner + .message_schedule + .w + .row(row_idx) + .into_iter() + .copied() + .collect::>(); + // Need to reverse the byte order to match the endianness of the memory + let byte_idx = C::WORD_U8S - i % C::WORD_U8S - 1; + compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) + }; + + let local_message = (0..C::WORD_U8S * C::ROUNDS_PER_ROW).map(|i| get_ith_byte(i, &local)); + let next_message = (0..C::WORD_U8S * C::ROUNDS_PER_ROW).map(|i| get_ith_byte(i, &next)); + + // Receive (MESSAGE_1, request_id, first_half_of_message) on the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message1 as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local_message.clone()) + .chain(next_message.clone()), + is_local_first_row * local.inner.flags.is_not_padding_row(), + ); + + let is_local_third_row = self + .inner + .row_idx_encoder + .contains_flag::(local.inner.flags.row_idx.to_slice().unwrap(), &[2]); + + // Send (MESSAGE_2, request_id, second_half_of_message) to the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message2 as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local_message) + .chain(next_message), + is_local_third_row * local.inner.flags.is_not_padding_row(), + ); + } + + fn eval_request_id(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // doesn't matter if we use round or digest cols here, since we only access + // request_id and inner.flags.is_last block, which are common to both + // field + let local = + Sha2BlockHasherRoundColsRef::::from::(&local[..C::BLOCK_HASHER_WIDTH]); + let next = + Sha2BlockHasherRoundColsRef::::from::(&next[..C::BLOCK_HASHER_WIDTH]); + + builder + .when_transition() + .when(*local.inner.flags.is_round_row) + .assert_eq(*next.request_id, *local.request_id); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs new file mode 100644 index 0000000000..13f7a70a76 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs @@ -0,0 +1,59 @@ +use openvm_circuit_primitives_derive::ColsRef; +use openvm_sha2_air::{ + Sha2BlockHasherSubairConfig, Sha2DigestCols, Sha2DigestColsRef, Sha2DigestColsRefMut, + Sha2RoundCols, Sha2RoundColsRef, Sha2RoundColsRefMut, +}; + +// offset in the columns struct where the inner column start +pub const INNER_OFFSET: usize = 1; + +// Just adding request_id to both columns structs +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2BlockHasherRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub request_id: T, + pub inner: Sha2RoundCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2BlockHasherDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub request_id: T, + pub inner: Sha2DigestCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + HASH_WORDS, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs new file mode 100644 index 0000000000..14523a06f4 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs @@ -0,0 +1,52 @@ +use openvm_sha2_air::{Sha256Config, Sha2BlockHasherSubairConfig, Sha384Config, Sha512Config}; + +use crate::{Sha2BlockHasherDigestColsRef, Sha2BlockHasherRoundColsRef}; + +pub trait Sha2BlockHasherVmConfig: Sha2BlockHasherSubairConfig { + /// Width of the Sha2VmRoundCols + const BLOCK_HASHER_ROUND_WIDTH: usize; + /// Width of the Sha2DigestCols + const BLOCK_HASHER_DIGEST_WIDTH: usize; + /// Width of the Sha2BlockHasherCols + const BLOCK_HASHER_WIDTH: usize; +} + +impl Sha2BlockHasherVmConfig for Sha256Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = + if Self::BLOCK_HASHER_ROUND_WIDTH > Self::BLOCK_HASHER_DIGEST_WIDTH { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} + +impl Sha2BlockHasherVmConfig for Sha512Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = if ::BLOCK_HASHER_ROUND_WIDTH + > Self::BLOCK_HASHER_DIGEST_WIDTH + { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} + +impl Sha2BlockHasherVmConfig for Sha384Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = + if Self::BLOCK_HASHER_ROUND_WIDTH > Self::BLOCK_HASHER_DIGEST_WIDTH { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs new file mode 100644 index 0000000000..0465a2fdac --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs @@ -0,0 +1,55 @@ +mod air; +mod columns; + +mod config; +mod trace; + +use std::{ + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +pub use air::*; +pub use columns::*; +pub use config::*; +use openvm_circuit::system::memory::SharedMemoryHelper; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha2_air::{Sha2BlockHasherFillerHelper, Sha2BlockHasherSubairConfig}; + +pub use super::{config::*, Sha2SharedRecords}; + +pub struct Sha2BlockHasherChip { + pub inner: Sha2BlockHasherFillerHelper, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub pointer_max_bits: usize, + pub mem_helper: SharedMemoryHelper, + // This Arc>> is shared with the main chip (Sha2MainChip). + // When the main chip's tracegen is done, it will set the value of the mutex to Some(records) + // and then the block hasher chip can see the records and use it to generate its trace. + // The arc mutex is not strictly necessary (we could just use a Cell) because tracegen is done + // sequentially over the list of chips (although it is parallelized within each chip), but the + // overhead of using a thread-safe type is negligible since we only access the 'records' field + // twice (once to set the value and once to get the value). + // So, we will just use an arc mutex to avoid overcomplicating things. + pub records: Arc>>>, + _phantom: PhantomData, +} + +impl Sha2BlockHasherChip { + pub fn new( + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, + mem_helper: SharedMemoryHelper, + records: Arc>>>, + ) -> Self { + Self { + inner: Sha2BlockHasherFillerHelper::new(), + bitwise_lookup_chip, + pointer_max_bits, + mem_helper, + records, + _phantom: PhantomData, + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs new file mode 100644 index 0000000000..a9b801731e --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs @@ -0,0 +1,271 @@ +use std::{slice, sync::Arc}; + +use openvm_circuit::arch::get_record_from_slice; +use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_sha2_air::{ + be_limbs_into_word, le_limbs_into_word, Sha2BlockHasherFillerHelper, Sha2RoundColsRef, + Sha2RoundColsRefMut, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::*, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, +}; + +use crate::{ + Sha2BlockHasherChip, Sha2BlockHasherRoundColsRefMut, Sha2BlockHasherVmConfig, Sha2Config, + Sha2Metadata, Sha2RecordLayout, Sha2RecordMut, Sha2SharedRecords, INNER_OFFSET, +}; + +// We don't use the record arena associated with this chip. Instead, we will use the record arena +// provided by the main chip, which will be passed to this chip after the main chip's tracegen is +// done. +impl Chip> for Sha2BlockHasherChip, C> +where + Val: PrimeField32, + SC: StarkGenericConfig, +{ + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + // SAFETY: the tracegen for Sha2MainChip must be done before this chip's tracegen + let mut records = self.records.lock().unwrap(); + let mut records = records.take().unwrap(); + let rows_used = records.num_records * C::ROWS_PER_BLOCK; + + let height = next_power_of_two_or_zero(rows_used); + let trace = Val::::zero_vec(height * C::BLOCK_HASHER_WIDTH); + let mut trace_matrix = RowMajorMatrix::new(trace, C::BLOCK_HASHER_WIDTH); + + self.fill_trace(&mut trace_matrix, &mut records, rows_used); + + AirProvingContext::simple(Arc::new(trace_matrix), vec![]) + } +} + +impl Sha2BlockHasherChip +where + F: PrimeField32, + C: Sha2BlockHasherVmConfig, +{ + fn fill_trace( + &self, + trace_matrix: &mut RowMajorMatrix, + records: &mut Sha2SharedRecords, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let trace = &mut trace_matrix.values[..]; + + // grab all the records + // we need to do this first, so we can pass (this_block.prev_hash, next_block.prev_hash) to + // each block (in the call to fill_block_trace) + let (records, prev_hashes): (Vec<_>, Vec<_>) = records + .matrix + .par_rows_mut() + .take(records.num_records) + .map(|mut record| { + // SAFETY: + // - caller ensures `records` contains a valid record representation that was + // previously written by the executor + // - records contains a valid Sha2RecordMut with the exact layout specified + // - get_record_from_slice will correctly split the buffer into header, input, and + // aux components based on this layout + let record: Sha2RecordMut = unsafe { + get_record_from_slice( + &mut record, + Sha2RecordLayout { + metadata: Sha2Metadata { + variant: C::VARIANT, + }, + }, + ) + }; + + let prev_hash = (0..C::HASH_WORDS) + .map(|i| { + le_limbs_into_word::( + &record.prev_state[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + (record, prev_hash) + }) + .unzip(); + + // zip the prev_hashes with the next block's prev_hash + let prev_hashes_and_next_block_prev_hashes = prev_hashes.par_iter().zip( + prev_hashes[1..] + .par_iter() + .chain(prev_hashes[..1].par_iter()), + ); + + // fill in used rows + trace[..rows_used * C::BLOCK_HASHER_WIDTH] + .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH * C::ROWS_PER_BLOCK) + .zip( + records + .par_iter() + .zip(prev_hashes_and_next_block_prev_hashes), + ) + .enumerate() + .for_each( + |(block_idx, (block_slice, (record, (prev_hash, next_block_prev_hash))))| { + self.fill_block_trace( + block_slice, + record.message_bytes, + block_idx + 1, // 1-indexed + prev_hash, + next_block_prev_hash, + block_idx, + ); + }, + ); + + // fill in the first dummy row. + // we need to do this first, so we can compute the carries that make the + // constraint_word_addition constraints hold on dummy rows (or more precisely, on rows such + // that the next row is a dummy row). + let first_dummy_row_cols_const = self.fill_first_dummy_row( + &mut trace[rows_used * C::BLOCK_HASHER_WIDTH..(rows_used + 1) * C::BLOCK_HASHER_WIDTH], + &prev_hashes[0], + ); + + // fill in the rest of the dummy rows + trace[(rows_used + 1) * C::BLOCK_HASHER_WIDTH..] + .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .for_each(|row| { + // copy the carries from the first dummy row into the current dummy row + self.inner.generate_default_row( + &mut Sha2RoundColsRefMut::from::( + &mut row[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ), + &prev_hashes[0], + Some( + first_dummy_row_cols_const + .work_vars + .carry_a + .as_slice() + .unwrap(), + ), + Some( + first_dummy_row_cols_const + .work_vars + .carry_e + .as_slice() + .unwrap(), + ), + ); + }); + + // Do a second pass over the trace to fill in the missing values + // Note, we need to skip the very first row + trace[C::BLOCK_HASHER_WIDTH..] + .par_chunks_mut(C::BLOCK_HASHER_WIDTH * C::ROWS_PER_BLOCK) + .take(rows_used / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + self.inner + .generate_missing_cells(chunk, C::BLOCK_HASHER_WIDTH, INNER_OFFSET); + }); + } + + fn fill_first_dummy_row( + &self, + first_dummy_row_mut: &mut [F], + first_block_prev_hash: &[C::Word], + ) -> Sha2RoundColsRef<'_, F> { + let first_dummy_row_const = + unsafe { slice::from_raw_parts(first_dummy_row_mut.as_ptr(), C::BLOCK_HASHER_WIDTH) }; + let first_dummy_row_cols_const = Sha2RoundColsRef::from::( + &first_dummy_row_const[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + + let first_dummy_row_mut = unsafe { + slice::from_raw_parts_mut(first_dummy_row_mut.as_mut_ptr(), C::BLOCK_HASHER_WIDTH) + }; + let mut first_dummy_row_cols_mut: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut first_dummy_row_mut[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + + // first, fill in everything but the carries into the first dummy row (i.e. fill in the + // work vars and row_idx) + self.inner.generate_default_row( + &mut first_dummy_row_cols_mut, + first_block_prev_hash, + None, + None, + ); + + // Now, this will fill in the first dummy row with the correct carries. + // This works because we already filled in the work vars into the first dummy row, and + // generate_carry_ae only looks at the work vars. + // Note that these carries will work for any pair of dummy rows, since all dummy rows + // have the same work vars (the first block's prev_hash). + Sha2BlockHasherFillerHelper::::generate_carry_ae( + first_dummy_row_cols_const.clone(), + &mut first_dummy_row_cols_mut, + ); + + first_dummy_row_cols_const + } +} + +impl Sha2BlockHasherChip { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + input: &[u8], + global_block_idx: usize, // 1-indexed + prev_hash: &[C::Word], + next_block_prev_hash: &[C::Word], + request_id: usize, + ) where + F: PrimeField32, + { + debug_assert_eq!(input.len(), C::BLOCK_U8S); + debug_assert_eq!(prev_hash.len(), C::HASH_WORDS); + + // Set request_id + block_slice + .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .for_each(|row_slice| { + // Set request_id + let cols = Sha2BlockHasherRoundColsRefMut::::from::( + &mut row_slice[..C::BLOCK_HASHER_WIDTH], + ); + *cols.request_id = F::from_canonical_usize(request_id); + }); + + let input_words = (0..C::BLOCK_WORDS) + .map(|i| { + be_limbs_into_word::( + &input[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + // Fill in the inner trace + self.inner.generate_block_trace( + block_slice, + C::BLOCK_HASHER_WIDTH, + INNER_OFFSET, + &input_words, + self.bitwise_lookup_chip.clone(), + prev_hash, + next_block_prev_hash, + global_block_idx as u32, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/config.rs b/extensions/sha2/circuit/src/sha2_chips/config.rs new file mode 100644 index 0000000000..437852315e --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/config.rs @@ -0,0 +1,97 @@ +use openvm_sha2_air::{Sha256Config, Sha384Config, Sha512Config}; +use sha2::{ + compress256, compress512, digest::generic_array::GenericArray, Digest, Sha256, Sha384, Sha512, +}; + +use crate::{Sha2BlockHasherVmConfig, Sha2MainChipConfig}; + +pub const SHA2_REGISTER_READS: usize = 3; +pub const SHA2_READ_SIZE: usize = 4; +pub const SHA2_WRITE_SIZE: usize = 4; + +pub trait Sha2Config: Sha2MainChipConfig + Sha2BlockHasherVmConfig { + /// Number of bits used to store the message length (part of the message padding) + const MESSAGE_LENGTH_BITS: usize; + + /// Number of bytes in the digest + const DIGEST_BYTES: usize; + + // Preconditions: + // - state.len() >= Self::STATE_BYTES + // - input.len() == Self::BLOCK_BYTES + fn compress(state: &mut [u8], input: &[u8]); + + // returns the digest as big-endian words + fn hash(message: &[u8]) -> Vec; +} + +impl Sha2Config for Sha256Config { + const MESSAGE_LENGTH_BITS: usize = 64; + + // the digest is the whole state + const DIGEST_BYTES: usize = Sha256Config::STATE_BYTES; + + fn compress(state: &mut [u8], input: &[u8]) { + debug_assert!(state.len() >= Sha256Config::STATE_BYTES); + debug_assert!(input.len() == Sha256Config::BLOCK_BYTES); + + // SAFETY: + // This is safe because state points to a [u32; 8]. + // The only reason we have a &[u8] instead is that we read it from a record, where + // we store the state as bytes since we don't know the word size at compile time (u32 for + // Sha256, u64 for Sha512) + let state_u32s: &mut [u32; 8] = unsafe { &mut *(state.as_mut_ptr() as *mut [u32; 8]) }; + + let input_array = GenericArray::from_slice(input); + + compress256(state_u32s, &[*input_array]); + } + + // returns the digest as big-endian words + fn hash(message: &[u8]) -> Vec { + Sha256::digest(message).to_vec() + } +} + +impl Sha2Config for Sha512Config { + const MESSAGE_LENGTH_BITS: usize = 128; + + // the digest is the whole state + const DIGEST_BYTES: usize = Sha512Config::STATE_BYTES; + + fn compress(state: &mut [u8], input: &[u8]) { + debug_assert!(state.len() >= Sha512Config::STATE_BYTES); + debug_assert!(input.len() == Sha512Config::BLOCK_BYTES); + + // SAFETY: + // This is safe because state points to a [u64; 8]. + // The only reason we have a &[u8] instead is that we read it from a record, where + // we store the state as bytes since we don't know the word size at compile time (u32 for + // Sha256, u64 for Sha512) + let state_u64s: &mut [u64; 8] = unsafe { &mut *(state.as_mut_ptr() as *mut [u64; 8]) }; + + let input_array = GenericArray::from_slice(input); + + compress512(state_u64s, &[*input_array]); + } + + // returns the digest as big-endian words + fn hash(message: &[u8]) -> Vec { + Sha512::digest(message).to_vec() + } +} + +impl Sha2Config for Sha384Config { + const MESSAGE_LENGTH_BITS: usize = Sha512Config::MESSAGE_LENGTH_BITS; + + // SHA-384 truncates the output to 48 bytes + const DIGEST_BYTES: usize = 48; + + fn compress(state: &mut [u8], input: &[u8]) { + Sha512Config::compress(state, input); + } + + fn hash(message: &[u8]) -> Vec { + Sha384::digest(message).to_vec() + } +} diff --git a/extensions/sha256/circuit/src/sha256_chip/execution.rs b/extensions/sha2/circuit/src/sha2_chips/execution.rs similarity index 53% rename from extensions/sha256/circuit/src/sha256_chip/execution.rs rename to extensions/sha2/circuit/src/sha2_chips/execution.rs index 8fdc91f9d0..d0428ab33f 100644 --- a/extensions/sha256/circuit/src/sha256_chip/execution.rs +++ b/extensions/sha2/circuit/src/sha2_chips/execution.rs @@ -8,62 +8,58 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, LocalOpcode, }; -use openvm_sha256_air::{get_sha256_num_blocks, SHA256_ROWS_PER_BLOCK}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; use openvm_stark_backend::p3_field::PrimeField32; -use super::{sha256_solve, Sha256VmExecutor, SHA256_NUM_READ_ROWS, SHA256_READ_SIZE}; +use super::{Sha2Config, Sha2VmExecutor, SHA2_READ_SIZE}; +use crate::SHA2_WRITE_SIZE; #[derive(AlignedBytesBorrow, Clone)] #[repr(C)] -struct ShaPreCompute { +struct Sha2PreCompute { a: u8, b: u8, c: u8, } -impl InterpreterExecutor for Sha256VmExecutor { - #[cfg(feature = "tco")] - fn handler( +impl InterpreterExecutor for Sha2VmExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + fn pre_compute( &self, pc: u32, inst: &Instruction, data: &mut [u8], - ) -> Result, StaticProgramError> + ) -> Result, StaticProgramError> where Ctx: ExecutionCtxTrait, { - let data: &mut ShaPreCompute = data.borrow_mut(); + let data: &mut Sha2PreCompute = data.borrow_mut(); self.pre_compute_impl(pc, inst, data)?; - Ok(execute_e1_handler::<_, _>) - } - - fn pre_compute_size(&self) -> usize { - size_of::() + Ok(execute_e1_impl::<_, _, C>) } - #[cfg(not(feature = "tco"))] - fn pre_compute( + #[cfg(feature = "tco")] + fn handler( &self, pc: u32, inst: &Instruction, data: &mut [u8], - ) -> Result, StaticProgramError> + ) -> Result, StaticProgramError> where Ctx: ExecutionCtxTrait, { - let data: &mut ShaPreCompute = data.borrow_mut(); + let data: &mut Sha2PreCompute = data.borrow_mut(); self.pre_compute_impl(pc, inst, data)?; - Ok(execute_e1_impl::<_, _>) + Ok(execute_e1_handler::<_, _, C>) } } -#[cfg(feature = "aot")] -impl AotExecutor for Sha256VmExecutor {} - -impl InterpreterMeteredExecutor for Sha256VmExecutor { +impl InterpreterMeteredExecutor for Sha2VmExecutor { fn metered_pre_compute_size(&self) -> usize { - size_of::>() + size_of::>() } #[cfg(not(feature = "tco"))] @@ -77,10 +73,10 @@ impl InterpreterMeteredExecutor for Sha256VmExecutor { where Ctx: MeteredExecutionCtxTrait, { - let data: &mut E2PreCompute = data.borrow_mut(); + let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; self.pre_compute_impl(pc, inst, &mut data.data)?; - Ok(execute_e2_impl::<_, _>) + Ok(execute_e2_impl::<_, _, C>) } #[cfg(feature = "tco")] @@ -94,90 +90,101 @@ impl InterpreterMeteredExecutor for Sha256VmExecutor { where Ctx: MeteredExecutionCtxTrait, { - let data: &mut E2PreCompute = data.borrow_mut(); + let data: &mut E2PreCompute = data.borrow_mut(); data.chip_idx = chip_idx as u32; self.pre_compute_impl(pc, inst, &mut data.data)?; - Ok(execute_e2_handler::<_, _>) + Ok(execute_e2_handler::<_, _, C>) } } -#[cfg(feature = "aot")] -impl AotMeteredExecutor for Sha256VmExecutor {} - #[inline(always)] -unsafe fn execute_e12_impl( - pre_compute: &ShaPreCompute, +unsafe fn execute_e12_impl< + F: PrimeField32, + C: Sha2Config, + CTX: ExecutionCtxTrait, + const IS_E1: bool, +>( + pre_compute: &Sha2PreCompute, exec_state: &mut VmExecState, ) -> u32 { let dst = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); - let src = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); - let len = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let state = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let input = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); let dst_u32 = u32::from_le_bytes(dst); - let src_u32 = u32::from_le_bytes(src); - let len_u32 = u32::from_le_bytes(len); - - let (output, height) = if IS_E1 { - // SAFETY: RV32_MEMORY_AS is memory address space of type u8 - let message = exec_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); - let output = sha256_solve(message); - (output, 0) - } else { - let num_blocks = get_sha256_num_blocks(len_u32); - let mut message = Vec::with_capacity(len_u32 as usize); - for block_idx in 0..num_blocks as usize { - // Reads happen on the first 4 rows of each block - for row in 0..SHA256_NUM_READ_ROWS { - let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = exec_state.vm_read( - RV32_MEMORY_AS, - src_u32 + (read_idx * SHA256_READ_SIZE) as u32, - ); - message.extend_from_slice(&row_input); - } - } - let output = sha256_solve(&message[..len_u32 as usize]); - let height = num_blocks * SHA256_ROWS_PER_BLOCK as u32; - (output, height) - }; - exec_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + let state_u32 = u32::from_le_bytes(state); + let input_u32 = u32::from_le_bytes(input); + + // state is in 4-byte little-endian words + let mut state_data = Vec::with_capacity(C::STATE_BYTES); + for i in 0..C::STATE_READS { + state_data.extend_from_slice(&exec_state.vm_read::( + RV32_MEMORY_AS, + state_u32 + (i * SHA2_READ_SIZE) as u32, + )); + } + let mut input_block = Vec::with_capacity(C::BLOCK_BYTES); + for i in 0..C::BLOCK_READS { + input_block.extend_from_slice(&exec_state.vm_read::( + RV32_MEMORY_AS, + input_u32 + (i * SHA2_READ_SIZE) as u32, + )); + } + + C::compress(&mut state_data, &input_block); + + for i in 0..C::STATE_WRITES { + exec_state.vm_write::( + RV32_MEMORY_AS, + dst_u32 + (i * SHA2_WRITE_SIZE) as u32, + &state_data[i * SHA2_WRITE_SIZE..(i + 1) * SHA2_WRITE_SIZE] + .try_into() + .unwrap(), + ); + } let pc = exec_state.pc(); exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); - height + 1 // height delta } #[create_handler] #[inline(always)] -unsafe fn execute_e1_impl( +unsafe fn execute_e1_impl( pre_compute: *const u8, exec_state: &mut VmExecState, ) { - let pre_compute: &ShaPreCompute = - std::slice::from_raw_parts(pre_compute, size_of::()).borrow(); - execute_e12_impl::(pre_compute, exec_state); + let pre_compute: &Sha2PreCompute = + std::slice::from_raw_parts(pre_compute, size_of::()).borrow(); + execute_e12_impl::(pre_compute, exec_state); } #[create_handler] #[inline(always)] -unsafe fn execute_e2_impl( +unsafe fn execute_e2_impl( pre_compute: *const u8, exec_state: &mut VmExecState, ) { - let pre_compute: &E2PreCompute = - std::slice::from_raw_parts(pre_compute, size_of::>()).borrow(); - let height = execute_e12_impl::(&pre_compute.data, exec_state); + let pre_compute: &E2PreCompute = + std::slice::from_raw_parts(pre_compute, size_of::>()).borrow(); + let height = execute_e12_impl::(&pre_compute.data, exec_state); exec_state .ctx .on_height_change(pre_compute.chip_idx as usize, height); } -impl Sha256VmExecutor { +#[cfg(feature = "aot")] +impl AotExecutor for Sha2VmExecutor {} + +#[cfg(feature = "aot")] +impl AotMeteredExecutor for Sha2VmExecutor {} + +impl Sha2VmExecutor { fn pre_compute_impl( &self, pc: u32, inst: &Instruction, - data: &mut ShaPreCompute, + data: &mut Sha2PreCompute, ) -> Result<(), StaticProgramError> { let Instruction { opcode, @@ -192,12 +199,12 @@ impl Sha256VmExecutor { if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { return Err(StaticProgramError::InvalidInstruction(pc)); } - *data = ShaPreCompute { + *data = Sha2PreCompute { a: a.as_canonical_u32() as u8, b: b.as_canonical_u32() as u8, c: c.as_canonical_u32() as u8, }; - assert_eq!(&Rv32Sha256Opcode::SHA256.global_opcode(), opcode); + assert_eq!(&C::OPCODE.global_opcode(), opcode); Ok(()) } } diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs new file mode 100644 index 0000000000..faff9e49cd --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs @@ -0,0 +1,326 @@ +use std::marker::PhantomData; + +use itertools::izip; +use ndarray::s; +use openvm_circuit::{ + arch::ExecutionBridge, + system::{ + memory::{offline_checker::MemoryBridge, MemoryAddress}, + SystemPort, + }, +}; +use openvm_circuit_primitives::{bitwise_op_lookup::BitwiseOperationLookupBus, utils::compose}; +use openvm_instructions::riscv::{ + RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS, +}; +use openvm_sha2_air::Sha2BlockHasherSubairConfig; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::FieldAlgebra, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::config::Sha2MainChipConfig; +use crate::{MessageType, Sha2ColsRef, SHA2_READ_SIZE, SHA2_WRITE_SIZE}; + +#[derive(Clone, Debug)] +pub struct Sha2MainAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub sha2_bus: PermutationCheckBus, + /// Maximum number of bits allowed for an address pointer + /// Must be at least 24 + pub ptr_max_bits: usize, + pub offset: usize, + _phantom: PhantomData, +} + +impl Sha2MainAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + offset: usize, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + sha2_bus: PermutationCheckBus::new(self_bus_idx), + ptr_max_bits, + offset, + _phantom: PhantomData, + } + } +} + +impl BaseAirWithPublicValues for Sha2MainAir {} +impl PartitionedBaseAir for Sha2MainAir {} +impl BaseAir for Sha2MainAir { + fn width(&self) -> usize { + C::MAIN_CHIP_WIDTH + } +} + +impl Air + for Sha2MainAir +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + + let local: Sha2ColsRef = Sha2ColsRef::from::(&local[..C::MAIN_CHIP_WIDTH]); + let next: Sha2ColsRef = Sha2ColsRef::from::(&next[..C::MAIN_CHIP_WIDTH]); + + let mut timestamp_delta = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + local.instruction.from_state.timestamp + + AB::F::from_canonical_usize(timestamp_delta - 1) + }; + + self.eval_block(builder, &local, &next); + self.eval_instruction(builder, &local, &mut timestamp_pp); + self.eval_reads(builder, &local, &mut timestamp_pp); + self.eval_writes(builder, &local, &mut timestamp_pp); + } +} + +impl Sha2MainAir { + pub fn eval_block( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + next: &Sha2ColsRef, + ) { + builder + .when_first_row() + .when(*local.instruction.is_enabled) + .assert_zero(*local.block.request_id); + + builder + .when_transition() + .when(*next.instruction.is_enabled) + .assert_eq( + *next.block.request_id, + *local.block.request_id + AB::Expr::ONE, + ); + + let prev_state_as_u16s: Vec = local + .block + .prev_state + .exact_chunks(C::WORD_U8S) + .into_iter() + .flat_map(|word| { + word.as_slice() + .unwrap() + .chunks_exact(2) + .map(|x| x[1] * AB::F::from_canonical_u64(1 << 8) + x[0]) + .collect::>() + }) + .collect(); + + // Send (STATE, request_id, prev_state_as_u16s, new_state) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::State as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain(prev_state_as_u16s) + .chain(local.block.new_state.into_iter().copied().map(|x| x.into())), + *local.instruction.is_enabled, + ); + + // Send (MESSAGE_1, request_id, first_half_of_message) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message1 as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain( + local + .block + .message_bytes + .iter() + .take(C::BLOCK_BYTES / 2) + .map(|x| (*x).into()), + ), + *local.instruction.is_enabled, + ); + + // Send (MESSAGE_2, request_id, second_half_of_message) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message2 as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain( + local + .block + .message_bytes + .iter() + .skip(C::BLOCK_BYTES / 2) + .map(|x| (*x).into()), + ), + *local.instruction.is_enabled, + ); + } + + pub fn eval_instruction( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + for (&ptr, val, aux) in izip!( + [ + local.instruction.dst_reg_ptr, + local.instruction.state_reg_ptr, + local.instruction.input_reg_ptr + ], + [ + local.instruction.dst_ptr_limbs, + local.instruction.state_ptr_limbs, + local.instruction.input_ptr_limbs + ], + &local.mem.register_aux, + ) { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_REGISTER_AS), ptr), + val.to_vec().try_into().unwrap_or_else(|_| panic!()), // can't unwrap because AB::Var doesn't impl Debug + timestamp_pp(), + aux, + ) + .eval(builder, *local.instruction.is_enabled); + } + + // range check the memory pointers + let shift = AB::Expr::from_canonical_usize( + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), + ); + let needs_range_check = [ + local.instruction.dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.state_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + ]; + for pair in needs_range_check.chunks_exact(2) { + self.bitwise_lookup_bus + .send_range(pair[0] * shift.clone(), pair[1] * shift.clone()) + .eval(builder, *local.instruction.is_enabled); + } + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(C::OPCODE as usize + self.offset), + [ + (*local.instruction.dst_reg_ptr).into(), + (*local.instruction.state_reg_ptr).into(), + (*local.instruction.input_reg_ptr).into(), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + ], + *local.instruction.from_state, + AB::F::from_canonical_usize(C::TIMESTAMP_DELTA), + ) + .eval(builder, *local.instruction.is_enabled); + } + + pub fn eval_reads( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + let input_ptr_val = compose(&local.instruction.input_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::BLOCK_READS { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + input_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .message_bytes + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("message bytes is not the correct size"); + }), + timestamp_pp(), + &local.mem.input_reads[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + + let state_ptr_val = compose(&local.instruction.state_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::STATE_READS { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + state_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .prev_state + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("prev state is not the correct size"); + }), + timestamp_pp(), + &local.mem.state_reads[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + } + + pub fn eval_writes( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + let dst_ptr_val = compose(&local.instruction.dst_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::STATE_READS { + self.memory_bridge + .write::<_, _, SHA2_WRITE_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .new_state + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("new state is not the correct size"); + }), + timestamp_pp(), + &local.mem.write_aux[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs new file mode 100644 index 0000000000..9c2151580f --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs @@ -0,0 +1,74 @@ +use openvm_circuit::{ + arch::ExecutionState, + system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, +}; +use openvm_circuit_primitives::ColsRef; +use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; + +use crate::{Sha2MainChipConfig, SHA2_REGISTER_READS, SHA2_WRITE_SIZE}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2Cols { + pub block: Sha2BlockCols, + pub instruction: Sha2InstructionCols, + pub mem: Sha2MemoryCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2BlockCols { + /// Identifier of this row in the interactions between the two chips + pub request_id: T, + /// Input bytes for this block + pub message_bytes: [T; BLOCK_BYTES], + /// Previous state of the SHA-2 hasher object, as little-endian words + pub prev_state: [T; STATE_BYTES], + /// New state of the SHA-2 hasher object after processing this block, as little-endian words + pub new_state: [T; STATE_BYTES], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2InstructionCols { + /// True for all rows that are part of opcode execution. + /// False on dummy rows only used to pad the height. + pub is_enabled: T, + #[aligned_borrow] + pub from_state: ExecutionState, + /// Pointer to address space 1 `dst` register + pub dst_reg_ptr: T, + /// Pointer to address space 1 `state` register + pub state_reg_ptr: T, + /// Pointer to address space 1 `input` register + pub input_reg_ptr: T, + // Register values + /// dst_ptr_limbs <- \[dst_reg_ptr:4\]_1 + pub dst_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], + /// state_ptr_limbs <- \[state_reg_ptr:4\]_1 + pub state_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], + /// input_ptr_limbs <- \[input_reg_ptr:4\]_1 + pub input_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2MemoryCols< + T, + const BLOCK_READS: usize, + const STATE_READS: usize, + const STATE_WRITES: usize, +> { + #[aligned_borrow] + pub register_aux: [MemoryReadAuxCols; SHA2_REGISTER_READS], + #[aligned_borrow] + pub input_reads: [MemoryReadAuxCols; BLOCK_READS], + #[aligned_borrow] + pub state_reads: [MemoryReadAuxCols; STATE_READS], + #[aligned_borrow] + pub write_aux: [MemoryWriteAuxCols; STATE_WRITES], +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs new file mode 100644 index 0000000000..5aec638971 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs @@ -0,0 +1,42 @@ +use openvm_sha2_air::{Sha256Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; + +use crate::{Sha2ColsRef, SHA2_READ_SIZE, SHA2_REGISTER_READS, SHA2_WRITE_SIZE}; + +pub trait Sha2MainChipConfig: Send + Sync + Clone { + // --- Required --- + /// Number of bytes in a SHA block (sometimes referred to as message bytes in the code) + const BLOCK_BYTES: usize; + /// Number of bytes in a SHA state + const STATE_BYTES: usize; + /// OpenVM Opcode for the instruction + const OPCODE: Rv32Sha2Opcode; + + // --- Provided --- + const BLOCK_READS: usize = Self::BLOCK_BYTES / SHA2_READ_SIZE; + const STATE_READS: usize = Self::STATE_BYTES / SHA2_READ_SIZE; + const STATE_WRITES: usize = Self::STATE_BYTES / SHA2_WRITE_SIZE; + + const TIMESTAMP_DELTA: usize = + Self::BLOCK_READS + Self::STATE_READS + Self::STATE_WRITES + SHA2_REGISTER_READS; + + const MAIN_CHIP_WIDTH: usize = Sha2ColsRef::::width::(); +} + +impl Sha2MainChipConfig for Sha256Config { + const BLOCK_BYTES: usize = 64; + const STATE_BYTES: usize = 32; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA256; +} + +impl Sha2MainChipConfig for Sha512Config { + const BLOCK_BYTES: usize = 128; + const STATE_BYTES: usize = 64; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA512; +} + +impl Sha2MainChipConfig for Sha384Config { + const BLOCK_BYTES: usize = Sha512Config::BLOCK_BYTES; + const STATE_BYTES: usize = Sha512Config::STATE_BYTES; + const OPCODE: Rv32Sha2Opcode = Sha512Config::OPCODE; +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs new file mode 100644 index 0000000000..c72b48fe3e --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs @@ -0,0 +1,58 @@ +mod air; +mod columns; +mod config; +mod trace; + +use std::{ + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +pub use air::*; +pub use columns::*; +pub use config::*; +use openvm_circuit::system::memory::SharedMemoryHelper; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_stark_backend::p3_matrix::dense::RowMajorMatrix; + +use crate::Sha2Config; + +// Record struct for sharing between the main chip and the block hasher chip +pub struct Sha2SharedRecords { + // note that we can't just do matrix.height() because the height is padded to the next power of + // two by MatrixRecordArena::into_matrix() + pub num_records: usize, + pub matrix: RowMajorMatrix, +} +pub struct Sha2MainChip { + // This Arc>> is shared with the block hasher chip (Sha2BlockHasherChip). + // When the main chip's tracegen is done, it will set the value of the mutex to Some(records) + // and then the block hasher chip can see the records and use them to generate its trace. + // The arc mutex is not strictly necessary (we could just use a Cell) because tracegen is done + // sequentially over the list of chips (although it is parallelized within each chip), but the + // overhead of using a thread-safe type is negligible since we only access the 'records' field + // twice (once to set the value and once to get the value). + // So, we will just use an arc mutex to avoid overcomplicating things. + pub records: Arc>>>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub pointer_max_bits: usize, + pub mem_helper: SharedMemoryHelper, + _phantom: PhantomData, +} + +impl Sha2MainChip { + pub fn new( + records: Arc>>>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pointer_max_bits: usize, + mem_helper: SharedMemoryHelper, + ) -> Self { + Self { + records, + bitwise_lookup_chip, + pointer_max_bits, + mem_helper, + _phantom: PhantomData, + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs new file mode 100644 index 0000000000..33a1d2868d --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs @@ -0,0 +1,224 @@ +use std::sync::Arc; + +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + MemoryAuxColsFactory, + }, + utils::next_power_of_two_or_zero, +}; +use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use openvm_sha2_air::set_arrayview_from_u8_slice; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, + prover::{cpu::CpuBackend, types::AirProvingContext}, + Chip, +}; + +use crate::{ + Sha2ColsRefMut, Sha2Config, Sha2MainChip, Sha2Metadata, Sha2RecordLayout, Sha2RecordMut, + Sha2SharedRecords, SHA2_WRITE_SIZE, +}; + +// We will allocate a new trace matrix instead of using the record arena directly, +// because we want to pass the record arena to Sha2BlockHasherChip when we are done. +impl Chip> + for Sha2MainChip, C> +where + Val: PrimeField32, + SC: StarkGenericConfig, + RA: RowMajorMatrixArena> + Send + Sync, +{ + fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext> { + // Since Sha2Metadata::get_num_rows() = 1, the number of rows used is equal to the number of + // SHA-2 instructions executed. + let rows_used = arena.trace_offset() / arena.width(); + + // We will fill the trace into a separate buffer, because we want to pass the arena to the + // Sha2BlockHasherChip when we are done. + // Sha2MainChip uses 1 row per instruction, we allocate rows_used * arena.width() space for + // the trace. + let height = next_power_of_two_or_zero(rows_used); + let trace = Val::::zero_vec(height * arena.width()); + let mut trace_matrix = RowMajorMatrix::new(trace, arena.width()); + let mem_helper = self.mem_helper.as_borrowed(); + + let mut records = arena.into_matrix(); + + self.fill_trace(&mem_helper, &mut trace_matrix, rows_used, &mut records); + + // Pass the records to Sha2BlockHasherChip + *self.records.lock().unwrap() = Some(Sha2SharedRecords { + num_records: rows_used, + matrix: records, + }); + + AirProvingContext::simple(Arc::new(trace_matrix), vec![]) + } +} + +// Note: we would like to just impl TraceFiller here, but we can't because we need to pass the +// records and row_idx to the tracegen functions. +impl Sha2MainChip { + // Preconditions: + // - trace should be a matrix with width = Sha2MainAir::width() and height = rows_used + // - trace should be filled with all zeros + // - records should be a matrix with height = rows_used, where each row stores a record + pub fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + records: &mut RowMajorMatrix, + ) { + let width = trace.width(); + trace.values[..rows_used * width] + .par_chunks_exact_mut(width) + .zip(records.par_rows_mut()) + .enumerate() + .for_each(|(row_idx, (row_slice, record))| { + self.fill_trace_row_with_row_idx(mem_helper, row_slice, row_idx, record); + }); + } + + // Same as TraceFiller::fill_trace_row, except we also take the row index as a parameter. + // + // Note: the only reason the record parameter is mutable is that get_record_from_slice + // requires a &mut &mut [F] slice. This parameter type is useful in other places where + // get_record_from_slice is used, to circumvent the borrow checker. Here, we don't actually need + // this workaround (we could duplicate get_record_from_slice and modify it to take a &mut + // [F] slice), but we just use the existing function for simplicity. + fn fill_trace_row_with_row_idx( + &self, + mem_helper: &MemoryAuxColsFactory, + row_slice: &mut [F], + row_idx: usize, + mut record: &mut [F], + ) where + F: Clone, + { + // SAFETY: + // - caller ensures `record` contains a valid record representation that was previously + // written by the executor + // - record contains a valid Sha2RecordMut with the exact layout specified + // - get_record_from_slice will correctly split the buffer into header and other components + // based on this layout. + let record: Sha2RecordMut = unsafe { + get_record_from_slice( + &mut record, + Sha2RecordLayout::new(Sha2Metadata { + variant: C::VARIANT, + }), + ) + }; + + // save all the components of the record on the stack so that we don't overwrite them when + // filling in the trace matrix. + let vm_record = record.inner.clone(); + + let mut message_bytes = Vec::with_capacity(C::BLOCK_BYTES); + message_bytes.extend_from_slice(record.message_bytes); + + let mut prev_state = Vec::with_capacity(C::STATE_BYTES); + prev_state.extend_from_slice(record.prev_state); + + let mut new_state = Vec::with_capacity(C::STATE_BYTES); + new_state.extend_from_slice(record.new_state); + + let mut input_reads_aux = + Vec::with_capacity(C::BLOCK_READS * size_of::()); + input_reads_aux.extend_from_slice(record.input_reads_aux); + + let mut state_reads_aux = + Vec::with_capacity(C::STATE_READS * size_of::()); + state_reads_aux.extend_from_slice(record.state_reads_aux); + + let mut write_aux = Vec::with_capacity( + C::STATE_WRITES * size_of::>(), + ); + write_aux.extend_from_slice(record.write_aux); + + let mut cols = Sha2ColsRefMut::from::(row_slice); + + *cols.block.request_id = F::from_canonical_usize(row_idx); + set_arrayview_from_u8_slice(&mut cols.block.message_bytes, message_bytes); + set_arrayview_from_u8_slice(&mut cols.block.prev_state, prev_state); + set_arrayview_from_u8_slice(&mut cols.block.new_state, new_state); + + *cols.instruction.is_enabled = F::ONE; + cols.instruction.from_state.timestamp = F::from_canonical_u32(vm_record.timestamp); + cols.instruction.from_state.pc = F::from_canonical_u32(vm_record.from_pc); + *cols.instruction.dst_reg_ptr = F::from_canonical_u32(vm_record.dst_reg_ptr); + *cols.instruction.state_reg_ptr = F::from_canonical_u32(vm_record.state_reg_ptr); + *cols.instruction.input_reg_ptr = F::from_canonical_u32(vm_record.input_reg_ptr); + + let dst_ptr_limbs = vm_record.dst_ptr.to_le_bytes(); + let state_ptr_limbs = vm_record.state_ptr.to_le_bytes(); + let input_ptr_limbs = vm_record.input_ptr.to_le_bytes(); + set_arrayview_from_u8_slice(&mut cols.instruction.dst_ptr_limbs, dst_ptr_limbs); + set_arrayview_from_u8_slice(&mut cols.instruction.state_ptr_limbs, state_ptr_limbs); + set_arrayview_from_u8_slice(&mut cols.instruction.input_ptr_limbs, input_ptr_limbs); + let needs_range_check = [ + dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + state_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + ]; + let shift: u32 = 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits); + for pair in needs_range_check.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32 * shift, pair[1] as u32 * shift); + } + + // fill in the register reads aux + let mut timestamp = vm_record.timestamp; + for (cols, vm_record) in cols + .mem + .register_aux + .iter_mut() + .zip(vm_record.register_reads_aux.iter()) + { + mem_helper.fill(vm_record.prev_timestamp, timestamp, cols.as_mut()); + timestamp += 1; + } + + input_reads_aux.iter().zip(cols.mem.input_reads).for_each( + |(read_aux_record, read_aux_cols)| { + mem_helper.fill( + read_aux_record.prev_timestamp, + timestamp, + read_aux_cols.as_mut(), + ); + timestamp += 1; + }, + ); + + state_reads_aux.iter().zip(cols.mem.state_reads).for_each( + |(state_aux_record, state_aux_cols)| { + mem_helper.fill( + state_aux_record.prev_timestamp, + timestamp, + state_aux_cols.as_mut(), + ); + timestamp += 1; + }, + ); + + write_aux + .iter() + .zip(cols.mem.write_aux) + .for_each(|(write_aux_record, write_aux_cols)| { + write_aux_cols.set_prev_data(write_aux_record.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + write_aux_record.prev_timestamp, + timestamp, + write_aux_cols.as_mut(), + ); + timestamp += 1; + }); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/mod.rs b/extensions/sha2/circuit/src/sha2_chips/mod.rs new file mode 100644 index 0000000000..f1a6a9211d --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/mod.rs @@ -0,0 +1,33 @@ +mod block_hasher_chip; +mod config; +mod execution; +mod main_chip; +mod trace; + +use std::marker::PhantomData; + +pub use block_hasher_chip::*; +pub use main_chip::*; +pub use trace::*; + +#[cfg(test)] +mod test_utils; +#[cfg(test)] +mod tests; +#[cfg(test)] +pub use test_utils::*; + +#[derive(derive_new::new, Clone)] +pub struct Sha2VmExecutor { + pub offset: usize, + pub pointer_max_bits: usize, + _phantom: PhantomData, +} + +// Indicates the message type of the interactions on the sha bus +#[repr(u8)] +pub enum MessageType { + State, + Message1, + Message2, +} diff --git a/extensions/sha2/circuit/src/sha2_chips/test_utils.rs b/extensions/sha2/circuit/src/sha2_chips/test_utils.rs new file mode 100644 index 0000000000..af4e4dce2c --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/test_utils.rs @@ -0,0 +1,77 @@ +use itertools::Itertools; +use openvm_circuit::arch::testing::TestBuilder; +use openvm_instructions::riscv::RV32_MEMORY_AS; +use openvm_sha2_air::Sha2Variant; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{Sha2Config, SHA2_READ_SIZE, SHA2_WRITE_SIZE}; + +// See https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf for the padding algorithm +pub fn add_padding_to_message(mut message: Vec) -> Vec { + // Length of the message in bits + let message_len = message.len() * 8; + + // For SHA-256, + // l + 1 + k = 448 mod 512 + // <=> l + 1 + k + 8 = 0 mod 512 + // <=> k = -(l + 1 + 8) mod 512 + // <=> k = (512 - (l + 1 + 8)) mod 512 + // The other variants are similar. + let padding_len_bits = (C::BLOCK_BITS + - ((message_len + 1 + C::MESSAGE_LENGTH_BITS) % C::BLOCK_BITS)) + % C::BLOCK_BITS; + message.push(0x80); + + let padding_len_bytes = padding_len_bits / 8; + message.extend(std::iter::repeat_n(0x00, padding_len_bytes)); + + match C::VARIANT { + Sha2Variant::Sha256 => { + message.extend_from_slice(&((message_len as u64).to_be_bytes())); + } + Sha2Variant::Sha512 => { + message.extend_from_slice(&((message_len as u128).to_be_bytes())); + } + Sha2Variant::Sha384 => { + message.extend_from_slice(&((message_len as u128).to_be_bytes())); + } + }; + + assert_eq!(message.len() % C::BLOCK_BYTES, 0); + + message +} + +pub fn write_slice_to_memory( + tester: &mut impl TestBuilder, + data: &[u8], + ptr: usize, +) { + data.chunks_exact(4).enumerate().for_each(|(i, chunk)| { + tester.write::( + RV32_MEMORY_AS as usize, + ptr + i * 4, + chunk + .iter() + .cloned() + .map(F::from_canonical_u8) + .collect_vec() + .try_into() + .unwrap(), + ); + }); +} + +pub fn read_slice_from_memory( + tester: &mut impl TestBuilder, + ptr: usize, + len: usize, +) -> Vec { + let mut data = Vec::new(); + for i in 0..(len / SHA2_READ_SIZE) { + data.extend_from_slice( + &tester.read::(RV32_MEMORY_AS as usize, ptr + i * SHA2_READ_SIZE), + ); + } + data +} diff --git a/extensions/sha2/circuit/src/sha2_chips/tests.rs b/extensions/sha2/circuit/src/sha2_chips/tests.rs new file mode 100644 index 0000000000..00baa1b4d4 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/tests.rs @@ -0,0 +1,864 @@ +use std::sync::{Arc, Mutex}; + +use hex::FromHex; +use itertools::Itertools; +use openvm_circuit::{ + arch::{ + testing::{ + memory::gen_pointer, TestBuilder, TestChipHarness, VmChipTestBuilder, + BITWISE_OP_LOOKUP_BUS, + }, + Arena, MatrixRecordArena, PreflightExecutor, + }, + system::{memory::SharedMemoryHelper, SystemPort}, + utils::get_random_message, +}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_sha2_air::{word_into_u8_limbs, Sha256Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{ + interaction::BusIndex, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +#[cfg(feature = "cuda")] +use { + crate::{Sha2BlockHasherChipGpu, Sha2MainChipGpu, Sha2RecordMut}, + openvm_circuit::arch::testing::{ + default_bitwise_lookup_bus, GpuChipTestBuilder, GpuTestChipHarness, + }, + openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupChipGPU, +}; + +use crate::{ + add_padding_to_message, read_slice_from_memory, write_slice_to_memory, Sha2BlockHasherChip, + Sha2BlockHasherDigestColsRefMut, Sha2BlockHasherVmAir, Sha2Config, Sha2MainAir, Sha2MainChip, + Sha2VmExecutor, +}; + +const SHA2_BUS_IDX: BusIndex = 28; +const SUBAIR_BUS_IDX: BusIndex = 29; +type F = BabyBear; +const MAX_INS_CAPACITY: usize = 4096; +type Harness = TestChipHarness, Sha2MainAir, Sha2MainChip, RA>; + +fn create_harness_fields( + system_port: SystemPort, + bitwise_chip: SharedBitwiseOperationLookupChip, + memory_helper: SharedMemoryHelper, + pointer_max_bits: usize, +) -> (Sha2MainAir, Sha2VmExecutor, Sha2MainChip) { + let executor = Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + let empty_records = Arc::new(Mutex::new(None)); + let main_chip = Sha2MainChip::new( + empty_records.clone(), + bitwise_chip.clone(), + pointer_max_bits, + memory_helper, + ); + let main_air = Sha2MainAir::new( + system_port, + bitwise_chip.bus(), + pointer_max_bits, + SHA2_BUS_IDX, + Rv32Sha2Opcode::CLASS_OFFSET, + ); + (main_air, executor, main_chip) +} + +struct TestHarness { + harness: Harness, + bitwise: ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + block_hasher: (Sha2BlockHasherVmAir, Sha2BlockHasherChip), +} + +fn create_test_harness( + tester: &mut VmChipTestBuilder, +) -> TestHarness { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let (air, executor, main_chip) = create_harness_fields( + tester.system_port(), + bitwise_chip.clone(), + tester.memory_helper(), + tester.address_bits(), + ); + + let shared_records = main_chip.records.clone(); + + let harness = Harness::::with_capacity(executor, air, main_chip, MAX_INS_CAPACITY); + + let block_hasher_air = + Sha2BlockHasherVmAir::new(bitwise_chip.bus(), SUBAIR_BUS_IDX, SHA2_BUS_IDX); + let block_hasher_chip = Sha2BlockHasherChip::new( + bitwise_chip.clone(), + tester.address_bits(), + tester.memory_helper(), + shared_records, + ); + + TestHarness { + harness, + bitwise: (bitwise_chip.air, bitwise_chip), + block_hasher: (block_hasher_air, block_hasher_chip), + } +} + +// execute one SHA2_UPDATE instruction +#[allow(clippy::too_many_arguments)] +fn set_and_execute_single_block>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + prev_state: Option<&[u8]>, +) { + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let dst_ptr = gen_pointer(rng, 4); + let state_ptr = gen_pointer(rng, 4); + let input_ptr = gen_pointer(rng, 4); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs1, state_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, input_ptr.to_le_bytes().map(F::from_canonical_u8)); + + let default_message = get_random_message(rng, C::BLOCK_U8S); + let message = message.unwrap_or(&default_message); + assert!(message.len() == C::BLOCK_U8S); + write_slice_to_memory(tester, message, input_ptr); + + let default_prev_state = get_random_message(rng, C::STATE_BYTES); + let prev_state = prev_state.unwrap_or(&default_prev_state); + assert!(prev_state.len() == C::STATE_BYTES); + write_slice_to_memory(tester, prev_state, state_ptr); + + tester.execute( + executor, + arena, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let mut state = prev_state.to_vec(); + C::compress(&mut state, message); + let expected_output = state + .iter() + .cloned() + .map(F::from_canonical_u8) + .collect_vec(); + + assert_eq!( + expected_output, + read_slice_from_memory(tester, dst_ptr, C::STATE_BYTES) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS - Single Block Hash +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// +// Test a single block hash +fn rand_sha2_single_block_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let TestHarness { + mut harness, + bitwise, + block_hasher, + } = create_test_harness::, C>(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute_single_block::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(block_hasher) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_single_block_test() { + rand_sha2_single_block_test::(); +} + +#[test] +fn rand_sha512_single_block_test() { + rand_sha2_single_block_test::(); +} + +#[test] +fn rand_sha384_single_block_test() { + rand_sha2_single_block_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS - Multi Block Hash +/// +/// Execute multiple SHA2_UPDATE instructions to hash an entire message +/////////////////////////////////////////////////////////////////////////////////////// +#[allow(clippy::too_many_arguments)] +fn set_and_execute_full_message>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + len: Option, +) { + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let state_ptr = gen_pointer(rng, 4); + let dst_ptr = state_ptr; + let input_ptr = gen_pointer(rng, 4); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs1, state_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, input_ptr.to_le_bytes().map(F::from_canonical_u8)); + + // initial state as little-endian words + let initial_state: Vec = C::get_h() + .iter() + .cloned() + .flat_map(|x| word_into_u8_limbs::(x).into_iter()) + .map(|x| x.try_into().unwrap()) + .collect_vec(); + + assert!(initial_state.len() == C::STATE_BYTES); + write_slice_to_memory(tester, &initial_state, state_ptr); + + let len = len.unwrap_or(rng.gen_range(1..3000)); + let default_message = get_random_message(rng, len); + let message = message.map(|x| x.to_vec()).unwrap_or(default_message); + + // C::hash() returns big-endian words. + // We want little-endian words so we can compare to our final state (which is in little-endian + // words) + let expected_output = C::hash(&message) + .chunks_exact(C::WORD_U8S) + .flat_map(|word| word.iter().rev().copied()) + .collect_vec(); + + let padded_message = add_padding_to_message::(message); + + // run SHA2_UPDATE as many times as needed to hash the entire message + padded_message + .chunks_exact(C::BLOCK_BYTES) + .for_each(|block| { + write_slice_to_memory(tester, block, input_ptr); + + tester.execute( + executor, + arena, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + }); + + let output = read_slice_from_memory(tester, dst_ptr, C::DIGEST_BYTES) + .into_iter() + .map(|x| x.as_canonical_u32() as u8) + .collect_vec(); + + assert_eq!(expected_output, output); +} + +// Test a single block hash +fn rand_sha2_multi_block_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let TestHarness { + mut harness, + bitwise, + block_hasher, + } = create_test_harness::<_, C>(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(block_hasher) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +#[test] +fn rand_sha512_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +// Note that this test is distinct from rand_sha512_multi_block_test() because this one uses the +// initial hash state for SHA384 instead of SHA512. +#[test] +fn rand_sha384_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// EDGE TESTS - Edge Case Input Lengths +/// +/// Test the hash function with various input lengths. +/////////////////////////////////////////////////////////////////////////////////////// +fn sha2_edge_test_lengths() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let TestHarness { + mut harness, + bitwise, + block_hasher, + } = create_test_harness::<_, C>(&mut tester); + + // inputs of various number of blocks + const TEST_VECTORS: [&str; 4] = [ + "", + "98c1c0bdb7d5fea9a88859f06c6c439f", + "5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", + "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + ]; + + for input in TEST_VECTORS.iter() { + let input = Vec::from_hex(input).unwrap(); + + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + Some(&input), + None, + ); + } + + // check every possible input length modulo block size + for i in (C::BLOCK_BYTES + 1)..=(2 * C::BLOCK_BYTES) { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + Some(i), + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(block_hasher) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn sha256_edge_test_lengths() { + sha2_edge_test_lengths::(); +} + +#[test] +fn sha512_edge_test_lengths() { + sha2_edge_test_lengths::(); +} + +#[test] +fn sha384_edge_test_lengths() { + sha2_edge_test_lengths::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let TestHarness { mut harness, .. } = + create_test_harness::, C>(&mut tester); + + // let num_tests: usize = 10; + let num_tests: usize = 1; + for _ in 0..num_tests { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + // None, + Some(0), + ); + } +} + +#[test] +fn execute_roundtrip_sanity_test_sha256() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn execute_roundtrip_sanity_test_sha512() { + execute_roundtrip_sanity_test::(); +} +#[test] +fn execute_roundtrip_sanity_test_sha384() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha256_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = Sha256Config::hash(input); + let expected: [u8; 32] = [ + 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, + 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha512_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = Sha512Config::hash(input); + let expected: [u8; 64] = [ + 0, 8, 195, 142, 70, 71, 97, 208, 132, 132, 243, 53, 179, 186, 8, 162, 71, 75, 126, 21, 130, + 203, 245, 126, 207, 65, 119, 60, 64, 79, 200, 2, 194, 17, 189, 137, 164, 213, 107, 197, + 152, 11, 242, 165, 146, 80, 96, 105, 249, 27, 139, 14, 244, 21, 118, 31, 94, 87, 32, 145, + 149, 98, 235, 75, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha384_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = Sha384Config::hash(input); + let expected: [u8; 48] = [ + 134, 227, 167, 229, 35, 110, 115, 174, 10, 27, 197, 116, 56, 144, 150, 36, 152, 190, 212, + 120, 26, 243, 125, 4, 2, 60, 164, 195, 218, 219, 255, 143, 240, 75, 158, 126, 102, 105, 8, + 202, 142, 240, 230, 161, 162, 152, 111, 71, + ]; + assert_eq!(output, expected); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// NEGATIVE TESTS +/// +/// This tests a soundness bug that was found at one point in our implementation. +/////////////////////////////////////////////////////////////////////////////////////// +fn negative_sha2_test_bad_final_hash() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let TestHarness { + mut harness, + bitwise, + block_hasher, + } = create_test_harness::, C>(&mut tester); + + let num_ops: usize = 1; + for _ in 0..num_ops { + set_and_execute_single_block::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + // Set the final_hash to all zeros + let modify_trace = |trace: &mut RowMajorMatrix| { + trace.row_chunks_exact_mut(1).for_each(|row| { + let mut row_slice = row.row_slice(0).to_vec(); + let mut cols = Sha2BlockHasherDigestColsRefMut::from::( + &mut row_slice[..C::BLOCK_HASHER_DIGEST_WIDTH], + ); + if cols.inner.flags.is_digest_row.is_one() { + for i in 0..C::HASH_WORDS { + for j in 0..C::WORD_U8S { + cols.inner.final_hash[[i, j]] = F::ZERO; + } + } + row.values.copy_from_slice(&row_slice); + } + }); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load(harness) + .load_periphery_and_prank_trace(block_hasher, modify_trace) + .load_periphery(bitwise) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); +} + +#[test] +fn negative_sha256_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha512_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha384_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +// //////////////////////////////////////////////////////////////////////////////////// +// CUDA TESTS +// +// Ensure GPU tracegen is equivalent to CPU tracegen +// //////////////////////////////////////////////////////////////////////////////////// + +#[cfg(feature = "cuda")] +type Sha2GpuTestChip = GpuTestChipHarness< + F, + Sha2VmExecutor, + Sha2MainAir, + Sha2MainChipGpu, + Sha2MainChip, +>; + +#[cfg(feature = "cuda")] +struct GpuHarness { + pub main: Sha2GpuTestChip, + block_air: Sha2BlockHasherVmAir, + block_gpu: Sha2BlockHasherChipGpu, + block_cpu: Sha2BlockHasherChip, + bitwise_air: BitwiseOperationLookupAir, + bitwise_gpu: Arc>, +} + +#[cfg(feature = "cuda")] +fn create_cuda_harness(tester: &GpuChipTestBuilder) -> GpuHarness { + const GPU_MAX_INS_CAPACITY: usize = 8192; + let bitwise_bus = default_bitwise_lookup_bus(); + let dummy_bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let (main_air, main_executor, main_chip) = create_harness_fields( + tester.system_port(), + dummy_bitwise_chip.clone(), + tester.dummy_memory_helper(), + tester.address_bits(), + ); + + let block_hasher_air = Sha2BlockHasherVmAir::new(bitwise_bus, SUBAIR_BUS_IDX, SHA2_BUS_IDX); + let block_hasher_chip = Sha2BlockHasherChip::new( + dummy_bitwise_chip.clone(), + tester.address_bits(), + tester.dummy_memory_helper(), + main_chip.records.clone(), + ); + + let shared_records_gpu = Arc::new(Mutex::new(None)); + let main_gpu_chip = Sha2MainChipGpu::new( + shared_records_gpu.clone(), + tester.range_checker(), + tester.bitwise_op_lookup(), + tester.address_bits() as u32, + tester.timestamp_max_bits() as u32, + ); + + let block_gpu_chip = Sha2BlockHasherChipGpu::new( + shared_records_gpu.clone(), + tester.range_checker(), + tester.bitwise_op_lookup(), + tester.address_bits() as u32, + tester.timestamp_max_bits() as u32, + ); + + let bitwise_gpu = tester.bitwise_op_lookup(); + let bitwise_air = BitwiseOperationLookupAir::new(bitwise_bus); + + GpuHarness { + main: GpuTestChipHarness::with_capacity( + main_executor, + main_air, + main_gpu_chip, + main_chip, + GPU_MAX_INS_CAPACITY, + ), + block_air: block_hasher_air, + block_gpu: block_gpu_chip, + block_cpu: block_hasher_chip, + bitwise_air, + bitwise_gpu, + } +} + +#[cfg(feature = "cuda")] +fn test_cuda_rand_sha2_multi_block() { + let mut rng = create_seeded_rng(); + let mut tester = + GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness::(&tester); + + let num_ops = 70; + for _ in 1..=num_ops { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.main.executor, + &mut harness.main.dense_arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + harness + .main + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.main.matrix_arena); + + let mut tester = tester.build(); + tester = tester.load_gpu_harness(harness.main); + tester = tester.load_and_compare( + harness.block_air, + harness.block_gpu, + (), + harness.block_cpu, + (), + ); + tester = tester.load_periphery(harness.bitwise_air, harness.bitwise_gpu); + tester.finalize().simple_test().unwrap(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_rand_sha256_multi_block() { + test_cuda_rand_sha2_multi_block::(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_rand_sha512_multi_block() { + test_cuda_rand_sha2_multi_block::(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_rand_sha384_multi_block() { + test_cuda_rand_sha2_multi_block::(); +} + +#[cfg(feature = "cuda")] +fn test_cuda_sha2_known_vectors(test_vectors: &[(&str, &str)]) { + let mut rng = create_seeded_rng(); + let mut tester = + GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness::(&tester); + + for (input, expected_hex) in test_vectors.iter() { + let input = Vec::from_hex(input).unwrap(); + let expected = Vec::from_hex(expected_hex).unwrap(); + // Sanity-check the expected digest matches the config’s hash. + assert_eq!(C::hash(&input).as_slice(), expected.as_slice()); + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.main.executor, + &mut harness.main.dense_arena, + &mut rng, + C::OPCODE, + Some(&input), + Some(input.len()), + ); + } + // No block-hasher arena needed; GPU block chip ignores the arena input. + harness + .main + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.main.matrix_arena); + + let mut tester = tester.build(); + tester = tester.load_gpu_harness(harness.main); + tester = tester.load_and_compare( + harness.block_air, + harness.block_gpu, + (), + harness.block_cpu, + (), + ); + tester = tester.load_periphery(harness.bitwise_air, harness.bitwise_gpu); + tester.finalize().simple_test().unwrap(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha256_known_vectors() { + let test_vectors = [ + ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), + ( + "98c1c0bdb7d5fea9a88859f06c6c439f", + "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + ), + ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), + ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") + ]; + test_cuda_sha2_known_vectors::(&test_vectors); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha512_known_vectors() { + let test_vectors = [ + ( + "", + "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + ), + ( + "98c1c0bdb7d5fea9a88859f06c6c439f", + "eb576959c531f116842c0cc915a29c8f71d7a285c894c349b83469002ef093d51f9f14ce4248488bff143025e47ed27c12badb9cd43779cb147408eea062d583" + ), + ( + "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + "8d215ee6dc26757c210db0dd00c1c6ed16cc34dbd4bb0fa10c1edb6b62d5ab16aea88c881001b173d270676daf2d6381b5eab8711fa2f5589c477c1d4b84774f" + ), + ]; + test_cuda_sha2_known_vectors::(&test_vectors); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha384_known_vectors() { + let test_vectors = [ + ( + "", + "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + ), + ( + "98c1c0bdb7d5fea9a88859f06c6c439f", + "63e3061aab01f335ea3a4e617b9d14af9b63a5240229164ee962f6d5335ff25f0f0bf8e46723e83c41b9d17413b6a3c7", + ), + ( + "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + "904a90010d772a904a35572fdd4bdf1dd253742e47872c8a18e2255f66fa889e44781e65487a043f435daa53c496a53e", + ), + ]; + test_cuda_sha2_known_vectors::(&test_vectors); +} + +// GPU edge-case length tests mirroring the CPU suite. +#[cfg(feature = "cuda")] +fn cuda_sha2_edge_test_lengths() { + let mut rng = create_seeded_rng(); + let mut tester = + GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness::(&tester); + + // check every possible input length modulo block size + for i in (C::BLOCK_BYTES + 1)..=(2 * C::BLOCK_BYTES) { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.main.executor, + &mut harness.main.dense_arena, + &mut rng, + C::OPCODE, + None, + Some(i), + ); + } + + harness + .main + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.main.matrix_arena); + + let mut tester = tester.build(); + tester = tester.load_gpu_harness(harness.main); + tester = tester.load_and_compare( + harness.block_air, + harness.block_gpu, + (), + harness.block_cpu, + (), + ); + tester = tester.load_periphery(harness.bitwise_air, harness.bitwise_gpu); + tester.finalize().simple_test().unwrap(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha256_edge_test_lengths() { + cuda_sha2_edge_test_lengths::(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha512_edge_test_lengths() { + cuda_sha2_edge_test_lengths::(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha384_edge_test_lengths() { + cuda_sha2_edge_test_lengths::(); +} diff --git a/extensions/sha2/circuit/src/sha2_chips/trace.rs b/extensions/sha2/circuit/src/sha2_chips/trace.rs new file mode 100644 index 0000000000..4749a839ef --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/trace.rs @@ -0,0 +1,320 @@ +use std::{ + borrow::BorrowMut, + mem::transmute, + slice::{from_raw_parts, from_raw_parts_mut}, +}; + +use openvm_circuit::{ + arch::{ + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, SizedRecord, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{tracing_read, tracing_write}; +use openvm_sha2_air::{Sha256Config, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + Sha2Config, Sha2MainChipConfig, Sha2VmExecutor, SHA2_READ_SIZE, SHA2_REGISTER_READS, + SHA2_WRITE_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct Sha2Metadata { + pub variant: Sha2Variant, +} + +impl MultiRowMetadata for Sha2Metadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + // The size of the record arena will be height * Sha2MainAir::width() * num_rows. + // We will not use the record arena's buffer for either chip's trace, so we just + // need to ensure that the record arena is large enough to store all the records. + // The size of Sha2RecordMut (in bytes) is less than Sha2MainAir::width() * size_of::(), + // for all SHA-2 variants. Therefore, we can set num_rows = 1. + 1 + } +} + +pub(crate) type Sha2RecordLayout = MultiRowLayout; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha2RecordHeader { + pub variant: Sha2Variant, + pub from_pc: u32, + pub timestamp: u32, + pub dst_reg_ptr: u32, + pub state_reg_ptr: u32, + pub input_reg_ptr: u32, + pub dst_ptr: u32, + pub state_ptr: u32, + pub input_ptr: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA2_REGISTER_READS], +} + +pub struct Sha2RecordMut<'a> { + pub inner: &'a mut Sha2RecordHeader, + + pub message_bytes: &'a mut [u8], + pub prev_state: &'a mut [u8], // little-endian words + pub new_state: &'a mut [u8], // little-endian words + + pub input_reads_aux: &'a mut [MemoryReadAuxRecord], + pub state_reads_aux: &'a mut [MemoryReadAuxRecord], + pub write_aux: &'a mut [MemoryWriteBytesAuxRecord], +} + +impl<'a> CustomBorrow<'a, Sha2RecordMut<'a>, Sha2RecordLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: Sha2RecordLayout) -> Sha2RecordMut<'a> { + // SAFETY: + // - Caller guarantees through the layout that self has sufficient length for all splits and + // constants are guaranteed <= self.len() by layout precondition + + let (header_slice, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let record_header: &mut Sha2RecordHeader = header_slice.borrow_mut(); + + let dims = Sha2PreComputeDims::new(layout.metadata.variant); + + let (message_bytes, rest) = unsafe { rest.split_at_mut_unchecked(dims.input_size) }; + let (prev_state, rest) = unsafe { rest.split_at_mut_unchecked(dims.state_size) }; + let (new_state, rest) = unsafe { rest.split_at_mut_unchecked(dims.state_size) }; + + let (input_reads_aux, rest) = unsafe { align_to_mut_at(rest, dims.input_reads) }; + let (state_reads_aux, rest) = unsafe { align_to_mut_at(rest, dims.state_reads) }; + let (write_aux, _) = unsafe { align_to_mut_at(rest, dims.state_writes) }; + + Sha2RecordMut { + inner: record_header, + message_bytes, + prev_state, + new_state, + input_reads_aux, + state_reads_aux, + write_aux, + } + } + + unsafe fn extract_layout(&self) -> Sha2RecordLayout { + let (variant, _) = unsafe { align_to_at(self, 1) }; + let variant = variant[0]; + Sha2RecordLayout { + metadata: Sha2Metadata { variant }, + } + } +} + +unsafe fn align_to_mut_at(slice: &mut [u8], offset: usize) -> (&mut [T], &mut [u8]) { + let (_, items, rest) = unsafe { slice.align_to_mut::() }; + let (items, items_rest) = unsafe { items.split_at_mut_unchecked(offset) }; + let rest = unsafe { + let items_rest: &mut [u8] = transmute(items_rest); + from_raw_parts_mut( + items_rest.as_mut_ptr(), + items_rest.len() * size_of::() + rest.len(), + ) + }; + (items, rest) +} + +unsafe fn align_to_at(slice: &[u8], offset: usize) -> (&[T], &[u8]) { + let (_, items, rest) = unsafe { slice.align_to::() }; + let (items, items_rest) = unsafe { items.split_at_unchecked(offset) }; + let rest = unsafe { + let items_rest: &[u8] = transmute(items_rest); + from_raw_parts( + items_rest.as_ptr(), + items_rest.len() * size_of::() + rest.len(), + ) + }; + (items, rest) +} + +impl SizedRecord for Sha2RecordMut<'_> { + fn size(layout: &Sha2RecordLayout) -> usize { + let header_size = size_of::(); + let dims = Sha2PreComputeDims::new(layout.metadata.variant); + let mut total_len = header_size + + dims.input_size // input + + dims.state_size // prev_state + + dims.state_size; // new_state + + total_len = total_len.next_multiple_of(align_of::()); + total_len += dims.input_reads * size_of::(); + + total_len = total_len.next_multiple_of(align_of::()); + total_len += dims.state_reads * size_of::(); + + total_len = + total_len.next_multiple_of(align_of::>()); + total_len += dims.state_writes * size_of::>(); + + total_len + } + + fn alignment(_layout: &Sha2RecordLayout) -> usize { + align_of::() // 4-byte alignment + } +} + +// This is needed in CustomBorrow trait to convert the Sha2Variant that we read from the buffer +// into appropriate dimensions for the record. +struct Sha2PreComputeDims { + state_size: usize, + input_size: usize, + input_reads: usize, + state_reads: usize, + state_writes: usize, +} + +impl Sha2PreComputeDims { + fn new(variant: Sha2Variant) -> Self { + match variant { + Sha2Variant::Sha256 => Self { + state_size: Sha256Config::STATE_BYTES, + input_size: Sha256Config::BLOCK_BYTES, + input_reads: Sha256Config::BLOCK_READS, + state_reads: Sha256Config::STATE_READS, + state_writes: Sha256Config::STATE_WRITES, + }, + Sha2Variant::Sha512 => Self { + state_size: Sha512Config::STATE_BYTES, + input_size: Sha512Config::BLOCK_BYTES, + input_reads: Sha512Config::BLOCK_READS, + state_reads: Sha512Config::STATE_READS, + state_writes: Sha512Config::STATE_WRITES, + }, + Sha2Variant::Sha384 => Self { + state_size: Sha384Config::STATE_BYTES, + input_size: Sha384Config::BLOCK_BYTES, + input_reads: Sha384Config::BLOCK_READS, + state_reads: Sha384Config::STATE_READS, + state_writes: Sha384Config::STATE_WRITES, + }, + } + } +} + +impl PreflightExecutor for Sha2VmExecutor +where + F: PrimeField32, + // for<'buf> RA: RecordArena<'buf, Sha2RecordLayout, Sha2RecordMut<'buf>>, + for<'buf> RA: RecordArena<'buf, Sha2RecordLayout, Sha2RecordMut<'buf>>, +{ + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", C::OPCODE) + } + + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(opcode, C::OPCODE.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + let record = state.ctx.alloc(Sha2RecordLayout::new(Sha2Metadata { + variant: C::VARIANT, + })); + + record.inner.variant = C::VARIANT; + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.dst_reg_ptr = a.as_canonical_u32(); + record.inner.state_reg_ptr = b.as_canonical_u32(); + record.inner.input_reg_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.dst_reg_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.state_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.state_reg_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.input_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.input_reg_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + debug_assert!( + record.inner.dst_ptr as usize + C::STATE_BYTES <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.state_ptr as usize + C::STATE_BYTES <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.input_ptr as usize + C::BLOCK_BYTES <= (1 << self.pointer_max_bits) + ); + + for idx in 0..C::BLOCK_READS { + let read = tracing_read::( + state.memory, + RV32_MEMORY_AS, + record.inner.input_ptr + (idx * SHA2_READ_SIZE) as u32, + &mut record.input_reads_aux[idx].prev_timestamp, + ); + record.message_bytes[idx * SHA2_READ_SIZE..(idx + 1) * SHA2_READ_SIZE] + .copy_from_slice(&read); + } + + for idx in 0..C::STATE_READS { + let read = tracing_read::( + state.memory, + RV32_MEMORY_AS, + record.inner.state_ptr + (idx * SHA2_READ_SIZE) as u32, + &mut record.state_reads_aux[idx].prev_timestamp, + ); + record.prev_state[idx * SHA2_READ_SIZE..(idx + 1) * SHA2_READ_SIZE] + .copy_from_slice(&read); + } + + record.new_state.copy_from_slice(record.prev_state); + C::compress(record.new_state, record.message_bytes); + + for idx in 0..C::STATE_WRITES { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (idx * SHA2_WRITE_SIZE) as u32, + record.new_state[idx * SHA2_WRITE_SIZE..(idx + 1) * SHA2_WRITE_SIZE] + .try_into() + .unwrap(), + &mut record.write_aux[idx].prev_timestamp, + &mut record.write_aux[idx].prev_data, + ); + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} diff --git a/extensions/sha256/guest/Cargo.toml b/extensions/sha2/guest/Cargo.toml similarity index 69% rename from extensions/sha256/guest/Cargo.toml rename to extensions/sha2/guest/Cargo.toml index e9d28292b8..1c6503002e 100644 --- a/extensions/sha256/guest/Cargo.toml +++ b/extensions/sha2/guest/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version.workspace = true authors.workspace = true edition.workspace = true -description = "Guest extension for Sha256" +description = "Guest extension for SHA-2" [dependencies] openvm-platform = { workspace = true } diff --git a/extensions/sha2/guest/src/lib.rs b/extensions/sha2/guest/src/lib.rs new file mode 100644 index 0000000000..d2b9c1f585 --- /dev/null +++ b/extensions/sha2/guest/src/lib.rs @@ -0,0 +1,155 @@ +#![no_std] + +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + +/// This is custom-0 defined in RISC-V spec document +pub const OPCODE: u8 = 0x0b; +pub const SHA2_FUNCT3: u8 = 0b100; + +// There is no Sha384 enum variant because the SHA-384 compression function is +// the same as the SHA-512 compression function. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum Sha2BaseFunct7 { + Sha256 = 0x1, + Sha512 = 0x2, +} + +/// zkvm native implementation of sha256 compression function +/// # Safety +/// +/// The VM accepts the previous hash state and the next block of input, and writes the +/// new hash state. +/// - `state` must point to a buffer of at least 32 bytes, storing the previous hash state as 8 +/// 32-bit words in little-endian order +/// - `input` must point to a buffer of at least 64 bytes +/// - `output` must point to a buffer of at least 32 bytes. It will be filled with the new hash +/// state as 8 32-bit words in little-endian order +/// +/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha256_impl(state: *const u8, input: *const u8, output: *mut u8) { + // SAFETY: we handle all cases where `prev_state`, `input`, or `output` are not aligned to 4 + // bytes. + + // The minimum alignment required for the buffers + const MIN_ALIGN: usize = 4; + unsafe { + let state_is_aligned = state as usize % MIN_ALIGN == 0; + let input_is_aligned = input as usize % MIN_ALIGN == 0; + let output_is_aligned = output as usize % MIN_ALIGN == 0; + + let state_ptr = if state_is_aligned { + state + } else { + AlignedBuf::new(state, 32, MIN_ALIGN).ptr + }; + + let input_ptr = if input_is_aligned { + input + } else { + AlignedBuf::new(input, 64, MIN_ALIGN).ptr + }; + + let output_ptr = if output_is_aligned { + output + } else { + AlignedBuf::uninit(32, MIN_ALIGN).ptr + }; + + __native_sha256_compress(state_ptr, input_ptr, output_ptr); + + if !output_is_aligned { + core::ptr::copy_nonoverlapping(output_ptr, output, 32); + } + } +} + +/// zkvm native implementation of sha512 compression function +/// # Safety +/// +/// The VM accepts the previous hash state and the next block of input, and writes the +/// new hash state. +/// - `state` must point to a buffer of at least 64 bytes, storing the previous hash state as 8 +/// 64-bit words in little-endian order +/// - `input` must point to a buffer of at least 128 bytes +/// - `output` must point to a buffer of at least 64 bytes. It will be filled with the new hash +/// state as 8 64-bit words in little-endian order +/// +/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha512_impl(state: *const u8, input: *const u8, output: *mut u8) { + // SAFETY: we handle all cases where `prev_state`, `input`, or `output` are not aligned to 4 + // bytes. + + // The minimum alignment required for the buffers + const MIN_ALIGN: usize = 4; + unsafe { + let state_is_aligned = state as usize % MIN_ALIGN == 0; + let input_is_aligned = input as usize % MIN_ALIGN == 0; + let output_is_aligned = output as usize % MIN_ALIGN == 0; + + let state_ptr = if state_is_aligned { + state + } else { + AlignedBuf::new(state, 64, MIN_ALIGN).ptr + }; + + let input_ptr = if input_is_aligned { + input + } else { + AlignedBuf::new(input, 128, MIN_ALIGN).ptr + }; + + let output_ptr = if output_is_aligned { + output + } else { + AlignedBuf::uninit(64, MIN_ALIGN).ptr + }; + + __native_sha512_compress(state_ptr, input_ptr, output_ptr); + + if !output_is_aligned { + core::ptr::copy_nonoverlapping(output_ptr, output, 64); + } + } +} + +/// sha256 compression function intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the previous hash state and the next block of input, and writes the +/// 32-byte hash. +/// - `prev_state` must point to a buffer of at least 32 bytes, storing the previous hash state as 8 +/// 32-bit words in little-endian order +/// - `input` must point to a buffer of at least 64 bytes +/// - `output` must point to a buffer of at least 32 bytes. It will be filled with the new hash +/// state as 8 32-bit words in little-endian order +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha256_compress(prev_state: *const u8, input: *const u8, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha256 as u8, rd = In output, rs1 = In prev_state, rs2 = In input); +} + +/// sha512 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the previous hash state and the next block of input, and writes the +/// 64-byte hash. +/// - `prev_state` must point to a buffer of at least 32 bytes, storing the previous hash state as 8 +/// 64-bit words in little-endian order +/// - `input` must point to a buffer of at least 128 bytes +/// - `output` must point to a buffer of at least 64 bytes. It will be filled with the new hash +/// state as 8 64-bit words in little-endian order +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha512_compress(prev_state: *const u8, input: *const u8, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha512 as u8, rd = In output, rs1 = In prev_state, rs2 = In input); +} diff --git a/extensions/sha256/transpiler/Cargo.toml b/extensions/sha2/transpiler/Cargo.toml similarity index 73% rename from extensions/sha256/transpiler/Cargo.toml rename to extensions/sha2/transpiler/Cargo.toml index 933859f3a8..9eff76a3db 100644 --- a/extensions/sha256/transpiler/Cargo.toml +++ b/extensions/sha2/transpiler/Cargo.toml @@ -1,15 +1,15 @@ [package] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version.workspace = true authors.workspace = true edition.workspace = true -description = "Transpiler extension for sha256" +description = "Transpiler extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } openvm-instructions = { workspace = true } openvm-transpiler = { workspace = true } rrs-lib = { workspace = true } -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } openvm-instructions-derive = { workspace = true } strum = { workspace = true } diff --git a/extensions/sha2/transpiler/src/lib.rs b/extensions/sha2/transpiler/src/lib.rs new file mode 100644 index 0000000000..0110d869c3 --- /dev/null +++ b/extensions/sha2/transpiler/src/lib.rs @@ -0,0 +1,58 @@ +use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; +use openvm_instructions_derive::LocalOpcode; +use openvm_sha2_guest::{Sha2BaseFunct7, OPCODE, SHA2_FUNCT3}; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::RType; +use strum::{EnumCount, EnumIter, FromRepr}; + +// There is no SHA384 opcode because the SHA-384 compression function is +// the same as the SHA-512 compression function. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x320] +#[repr(usize)] +pub enum Rv32Sha2Opcode { + SHA256, + SHA512, +} + +#[derive(Default)] +pub struct Sha2TranspilerExtension; + +impl TranspilerExtension for Sha2TranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if (opcode, funct3) != (OPCODE, SHA2_FUNCT3) { + return None; + } + let dec_insn = RType::new(instruction_u32); + + if dec_insn.funct7 == Sha2BaseFunct7::Sha256 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA256.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha512 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA512.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else { + None + } + } +} diff --git a/extensions/sha256/circuit/README.md b/extensions/sha256/circuit/README.md deleted file mode 100644 index 1e794cd35c..0000000000 --- a/extensions/sha256/circuit/README.md +++ /dev/null @@ -1,93 +0,0 @@ -# SHA256 VM Extension - -This crate contains the circuit for the SHA256 VM extension. - -## SHA-256 Algorithm Summary - -See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), in particular, section 6.2 for reference. - -In short the SHA-256 algorithm works as follows. -1. Pad the message to 512 bits and split it into 512-bit 'blocks'. -2. Initialize a hash state consisting of eight 32-bit words. -3. For each block, - 1. split the message into 16 32-bit words and produce 48 more 'message schedule' words based on them. - 2. apply 64 'rounds' to update the hash state based on the message schedule. - 3. add the previous block's final hash state to the current hash state (modulo `2^32`). -4. The output is the final hash state - -## Design Overview - -This chip produces an AIR that consists of 17 rows for each block (512 bits) in the message, and no more rows. -The first 16 rows of each block are called 'round rows', and each of them represents four rounds of the SHA-256 algorithm. -Each row constrains updates to the working variables on each round, and it also constrains the message schedule words based on previous rounds. -The final row is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. - -Note that this chip only supports messages of length less than `2^29` bytes. - -### Storing working variables - -One optimization is that we only keep track of the `a` and `e` working variables. -It turns out that if we have their values over four consecutive rounds, we can reconstruct all eight variables at the end of the four rounds. -This is because there is overlap between the values of the working variables in adjacent rounds. -If the state is visualized as an array, `s_0 = [a, b, c, d, e, f, g, h]`, then the new state, `s_1`, after one round is produced by a right-shift and an addition. -More formally, -``` -s_1 = (s_0 >> 1) + [T_1 + T_2, 0, 0, 0, T_1, 0, 0, 0] - = [0, a, b, c, d, e, f, g] + [T_1 + T_2, 0, 0, 0, T_1, 0, 0, 0] - = [T_1 + T_2, a, b, c, d + T_1, e, f, g] -``` -where `T_1` and `T_2` are certain functions of the working variables and message data (see the FIPS spec). -So if `a_i` and `e_i` denote the values of `a` and `e` after the `i`th round, for `0 <= i < 4`, then the state `s_3` after the fourth round can be written as `s_3 = [a_3, a_2, a_1, a_0, e_3, e_2, e_1, e_0]`. - -### Message schedule constraints - -The algorithm for computing the message schedule involves message schedule words from 16 rounds ago. -Since we can only constrain two rows at a time, we cannot access data from more than four rounds ago for the first round in each row. -So, we maintain intermediate values that we call `intermed_4`, `intermed_8` and `intermed_12`, where `intermed_i = w_i + sig_0(w_{i+1})` where `w_i` is the value of `w` from `i` rounds ago and `sig_0` denotes the `sigma_0` function from the FIPS spec. -Since we can reliably constrain values from four rounds ago, we can build up `intermed_16` from these values, which is needed for computing the message schedule. - -### Note about `is_last_block` - -The last block of every message should have the `is_last_block` flag set to `1`. -Note that `is_last_block` is not constrained to be true for the last block of every message, instead it *defines* what the last block of a message is. -For instance, if we produce an air with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. -If, however, we set `is_last_block` to true for the 6th block, the trace will be interpreted as hashing two messages, each of length 5 blocks. - -Note that we do constrain, however, that the very last block of the trace has `is_last_block = 1`. - -### Dummy values - -Some constraints have degree three, and so we cannot restrict them to particular rows due to the limitation of the maximum constraint degree. -We must enforce them on all rows, and in order to ensure they hold on the remaining rows we must fill in some cells with appropriate dummy values. -We use this trick in several places in this chip. - -### Block index counter variables - -There are two "block index" counter variables in each row of the air named `global_block_idx` and `local_block_idx`. -Both of these variables take on the same value on all 17 rows in a block. - -The `global_block_idx` is the index of the block in the entire trace. -The very first 17 rows in the trace will have `global_block_idx = 1` and the counter will increment by 1 between blocks. -The padding rows will all have `global_block_idx = 0`. -The `global_block_idx` is used in interaction constraints to constrain the value of `hash` between blocks. - -The `local_block_idx` is the index of the block in the current message. -It starts at 0 for the first block of each message and increments by 1 for each block. -The `local_block_idx` is reset to 0 after each message. -The padding rows will all have `local_block_idx = 0`. -The `local_block_idx` is used to calculate the length of the message processed so far when the first padding row is encountered. - -### VM air vs SubAir - -The SHA-256 VM extension chip uses the `Sha256Air` SubAir to help constrain the SHA-256 hash. -The VM extension air constrains the correctness of the SHA message padding, while the SubAir adds all other constraints related to the hash algorithm. -The VM extension air also constrains memory reads and writes. - -### A gotcha about padding rows - -There are two senses of the word padding used in the context of this chip and this can be confusing. -First, we use padding to refer to the extra bits added to the message that is input to the SHA-256 algorithm in order to make the input's length a multiple of 512 bits. -So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha256VmAir::eval_padding_row`). -Second, the dummy rows that are added to the trace to make the trace height a power of 2 are also called padding rows (see the `is_padding_row` flag). -In the SubAir, padding row probably means dummy row. -In the VM air, it probably refers to SHA-256 padding. \ No newline at end of file diff --git a/extensions/sha256/circuit/cuda/src/sha256.cu b/extensions/sha256/circuit/cuda/src/sha256.cu deleted file mode 100644 index d1939a8c9a..0000000000 --- a/extensions/sha256/circuit/cuda/src/sha256.cu +++ /dev/null @@ -1,447 +0,0 @@ -#include "launcher.cuh" -#include "primitives/constants.h" -#include "primitives/trace_access.h" -#include "sha256-air/columns.cuh" -#include "sha256-air/tracegen.cuh" -#include "sha256-air/utils.cuh" -#include "system/memory/controller.cuh" -#include "system/memory/offline_checker.cuh" -#include - -using namespace riscv; -using namespace sha256; - -__device__ inline void write_round_padding_flags_encoder( - RowSlice row, - const Encoder &padding_encoder, - uint32_t flag_idx -) { - RowSlice pad_flags = row.slice_from(COL_INDEX(Sha256VmRoundCols, control.pad_flags)); - padding_encoder.write_flag_pt(pad_flags, flag_idx); -} - -__device__ inline void write_digest_padding_flags_encoder( - RowSlice row, - const Encoder &padding_encoder, - uint32_t flag_idx -) { - RowSlice pad_flags = row.slice_from(COL_INDEX(Sha256VmDigestCols, control.pad_flags)); - padding_encoder.write_flag_pt(pad_flags, flag_idx); -} - -// ===== MAIN KERNEL FUNCTIONS ===== -__global__ void sha256_hash_computation( - uint8_t *records, - size_t num_records, - size_t *record_offsets, - uint32_t *block_offsets, - uint32_t *prev_hashes, - uint32_t total_num_blocks -) { - uint32_t record_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (record_idx >= num_records) { - return; - } - - uint32_t offset = record_offsets[record_idx]; - Sha256VmRecordMut record(records + offset); - - uint32_t len = record.header->len; - uint8_t *input = record.input; - - uint32_t start_block = block_offsets[record_idx]; - uint32_t num_blocks = get_sha256_num_blocks(len); - - uint32_t current_hash[SHA256_HASH_WORDS] = { - 0x6a09e667, - 0xbb67ae85, - 0x3c6ef372, - 0xa54ff53a, - 0x510e527f, - 0x9b05688c, - 0x1f83d9ab, - 0x5be0cd19 - }; - - uint8_t block_input[SHA256_BLOCK_U8S]; - for (uint32_t local_block = 0; local_block < num_blocks; local_block++) { - { - uint32_t offset = SHA256_HASH_WORDS * (start_block + local_block); - memcpy(prev_hashes + offset, current_hash, SHA256_HASH_WORDS * sizeof(uint32_t)); - } - if (local_block < num_blocks - 1) { - // Since local_block < num_blocks - 1, we know that block_offset < len by the definition of num_blocks - uint32_t block_offset = local_block * SHA256_BLOCK_U8S; - if (block_offset + SHA256_BLOCK_U8S > len) { - uint32_t remaining_bytes = len - block_offset; - memcpy(block_input, input + block_offset, remaining_bytes); - block_input[remaining_bytes] = 0x80; - memset( - block_input + remaining_bytes + 1, 0, SHA256_BLOCK_U8S - remaining_bytes - 1 - ); - } else { - memcpy(block_input, input + block_offset, SHA256_BLOCK_U8S); - } - - get_block_hash(current_hash, block_input); - } - } -} - -__global__ __noinline__ void sha256_first_pass_tracegen( - Fp *trace, - size_t trace_height, - uint8_t *records, - size_t num_records, - size_t *record_offsets, - uint32_t *block_offsets, - uint32_t *block_to_record_idx, - uint32_t total_num_blocks, - uint32_t *prev_hashes, - uint32_t ptr_max_bits, - uint32_t *range_checker_ptr, - uint32_t range_checker_num_bins, - uint32_t *bitwise_lookup_ptr, - uint32_t bitwise_num_bits, - uint32_t timestamp_max_bits -) { - uint32_t global_block_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (global_block_idx >= total_num_blocks) { - return; - } - - uint32_t record_idx = block_to_record_idx[global_block_idx]; - if (record_idx >= num_records) { - return; - } - - uint32_t offset = record_offsets[record_idx]; - Sha256VmRecordMut record(records + offset); - Sha256VmRecordHeader &vm_record = *record.header; - - auto len = vm_record.len; - auto input = record.input; - - auto local_block_idx = global_block_idx - block_offsets[record_idx]; - auto prev_hash = &prev_hashes[global_block_idx * SHA256_HASH_WORDS]; - auto block_offset = local_block_idx * SHA256_BLOCK_U8S; - - uint32_t num_blocks_for_record = get_sha256_num_blocks(len); - bool is_last_block = (local_block_idx == num_blocks_for_record - 1); - - uint32_t input_words[SHA256_BLOCK_WORDS]; - { - uint8_t padded_input[SHA256_BLOCK_U8S]; - if (block_offset <= len) { - if (block_offset + SHA256_BLOCK_U8S > len) { - uint32_t remaining_bytes = len - block_offset; - memcpy(padded_input, input + block_offset, remaining_bytes); - padded_input[remaining_bytes] = 0x80; - memset( - padded_input + remaining_bytes + 1, 0, SHA256_BLOCK_U8S - remaining_bytes - 1 - ); - } else { - memcpy(padded_input, input + block_offset, SHA256_BLOCK_U8S); - } - } else { - memset(padded_input, 0, SHA256_BLOCK_U8S); - } - - for (uint32_t i = 0; i < SHA256_BLOCK_WORDS; i++) { - input_words[i] = u32_from_bytes_be(padded_input + i * 4); - } - - if (is_last_block) { - input_words[SHA256_BLOCK_WORDS - 1] = len << 3; - } - } - - uint32_t trace_start_row = global_block_idx * SHA256_ROWS_PER_BLOCK; - - uint32_t read_cells = (SHA256_BLOCK_U8S * local_block_idx); - uint32_t block_start_read_ptr = vm_record.src_ptr + read_cells; - uint32_t message_left = len - read_cells; - - MemoryAuxColsFactory mem_helper( - VariableRangeChecker(range_checker_ptr, range_checker_num_bins), timestamp_max_bits - ); - - BitwiseOperationLookup bitwise_lookup(bitwise_lookup_ptr, bitwise_num_bits); - - Encoder padding_encoder(static_cast(PaddingFlags::COUNT), 2, false); - - int32_t first_padding_row; - if (len < read_cells) { - first_padding_row = -1; - } else if (message_left < SHA256_BLOCK_U8S) { - first_padding_row = message_left / SHA256_READ_SIZE; - } else { - first_padding_row = 18; - } - - auto start_timestamp = - vm_record.timestamp + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx); - - for (uint32_t row_in_block = 0; row_in_block < SHA256_ROWS_PER_BLOCK; row_in_block++) { - uint32_t absolute_row = trace_start_row + row_in_block; - - if (absolute_row >= trace_height) { - return; - } - - RowSlice row(trace + absolute_row, trace_height); - - if (row_in_block == SHA256_ROWS_PER_BLOCK - 1) { - SHA256_WRITE_DIGEST(row, from_state.timestamp, vm_record.timestamp); - SHA256_WRITE_DIGEST(row, from_state.pc, vm_record.from_pc); - SHA256_WRITE_DIGEST(row, rd_ptr, vm_record.rd_ptr); - SHA256_WRITE_DIGEST(row, rs1_ptr, vm_record.rs1_ptr); - SHA256_WRITE_DIGEST(row, rs2_ptr, vm_record.rs2_ptr); - - { - uint8_t *dst_bytes = reinterpret_cast(&vm_record.dst_ptr); - uint8_t *src_bytes = reinterpret_cast(&vm_record.src_ptr); - uint8_t *len_bytes = reinterpret_cast(&len); - - SHA256_WRITE_ARRAY_DIGEST(row, dst_ptr, dst_bytes); - SHA256_WRITE_ARRAY_DIGEST(row, src_ptr, src_bytes); - SHA256_WRITE_ARRAY_DIGEST(row, len_data, len_bytes); - } - - if (is_last_block) { - for (int i = 0; i < SHA256_REGISTER_READS; i++) { - mem_helper.fill( - SHA256_SLICE_DIGEST(row, register_reads_aux[i]), - vm_record.register_reads_aux[i].prev_timestamp, - vm_record.timestamp + i - ); - } - - SHA256_WRITE_ARRAY_DIGEST(row, writes_aux.prev_data, vm_record.write_aux.prev_data); - - mem_helper.fill( - SHA256_SLICE_DIGEST(row, writes_aux), - vm_record.write_aux.prev_timestamp, - start_timestamp + SHA256_NUM_READ_ROWS - ); - - uint32_t msl_rshift = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - uint32_t msl_lshift = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - ptr_max_bits); - bitwise_lookup.add_range( - (vm_record.dst_ptr >> msl_rshift) << msl_lshift, - (vm_record.src_ptr >> msl_rshift) << msl_lshift - ); - } else { - for (int i = 0; i < SHA256_REGISTER_READS; i++) { - mem_helper.fill_zero(SHA256_SLICE_DIGEST(row, register_reads_aux[i])); - } - - row.fill_zero( - COL_INDEX(Sha256VmDigestCols, writes_aux.prev_data), SHA256_WRITE_SIZE - ); - mem_helper.fill_zero(SHA256_SLICE_DIGEST(row, writes_aux)); - } - SHA256_WRITE_DIGEST(row, inner.flags.is_last_block, is_last_block); - SHA256_WRITE_DIGEST(row, inner.flags.is_digest_row, Fp::one()); - row.fill_zero(SHA256VM_DIGEST_WIDTH, SHA256VM_WIDTH - SHA256VM_DIGEST_WIDTH); - } else { - if (row_in_block < SHA256_NUM_READ_ROWS) { - - uint32_t data_offset = block_offset + row_in_block * SHA256_READ_SIZE; - SHA256_WRITE_ARRAY_ROUND( - row, inner.message_schedule.carry_or_buffer, input + data_offset - ); - - MemoryReadAuxRecord *read_aux_record = - get_read_aux_record(&record, local_block_idx, row_in_block); - mem_helper.fill( - SHA256_SLICE_ROUND(row, read_aux), - read_aux_record->prev_timestamp, - start_timestamp + row_in_block - ); - } else { - mem_helper.fill_zero(SHA256_SLICE_ROUND(row, read_aux)); - } - } - - SHA256_WRITE_ROUND(row, control.len, len); - - { - uint32_t control_timestamp = - start_timestamp + min(row_in_block, (uint32_t)SHA256_NUM_READ_ROWS); - uint32_t control_read_ptr = - block_start_read_ptr + - (SHA256_READ_SIZE * min(row_in_block, (uint32_t)SHA256_NUM_READ_ROWS)); - - SHA256_WRITE_ROUND(row, control.cur_timestamp, control_timestamp); - SHA256_WRITE_ROUND(row, control.read_ptr, control_read_ptr); - } - - if (row_in_block < SHA256_NUM_READ_ROWS) { - if ((int32_t)row_in_block < first_padding_row) { - write_round_padding_flags_encoder( - row, padding_encoder, static_cast(PaddingFlags::NotPadding) - ); - } else if ((int32_t)row_in_block == first_padding_row) { - { - uint32_t len = message_left - row_in_block * SHA256_READ_SIZE; - uint32_t flag_idx; - if (row_in_block == 3 && is_last_block) { - flag_idx = static_cast(PaddingFlags::FirstPadding0_LastRow) + len; - if (flag_idx >= static_cast(PaddingFlags::COUNT)) { - flag_idx = static_cast(PaddingFlags::EntirePaddingLastRow); - } - } else { - flag_idx = static_cast(PaddingFlags::FirstPadding0) + len; - if (flag_idx >= static_cast(PaddingFlags::COUNT)) { - flag_idx = static_cast(PaddingFlags::EntirePadding); - } - } - write_round_padding_flags_encoder(row, padding_encoder, flag_idx); - } - } else { - { - uint32_t flag_idx; - if (row_in_block == 3 && is_last_block) { - flag_idx = static_cast(PaddingFlags::EntirePaddingLastRow); - } else { - flag_idx = static_cast(PaddingFlags::EntirePadding); - } - write_round_padding_flags_encoder(row, padding_encoder, flag_idx); - } - } - } else { - write_round_padding_flags_encoder( - row, padding_encoder, static_cast(PaddingFlags::NotConsidered) - ); - } - - if (is_last_block && row_in_block == SHA256_ROWS_PER_BLOCK - 1) { - SHA256_WRITE_ROUND(row, control.padding_occurred, Fp::zero()); - } else { - SHA256_WRITE_ROUND( - row, control.padding_occurred, (int32_t)row_in_block >= first_padding_row - ); - } - } - - Fp *inner_trace_start = trace + (SHA256_INNER_COLUMN_OFFSET * trace_height) + trace_start_row; - generate_block_trace( - inner_trace_start, - trace_height, - input_words, - bitwise_lookup_ptr, - bitwise_num_bits, - prev_hash, - is_last_block, - global_block_idx + 1, - local_block_idx - ); -} - -__global__ void sha256_fill_invalid_rows(Fp *d_trace, size_t trace_height, size_t rows_used) { - uint32_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - uint32_t row_idx = rows_used + thread_idx; - if (row_idx >= trace_height) { - return; - } - - RowSlice row(d_trace + row_idx, trace_height); - row.fill_zero(0, SHA256VM_WIDTH); - - RowSlice inner_row = row.slice_from(SHA256_INNER_COLUMN_OFFSET); - generate_default_row(inner_row); -} - -// ===== HOST LAUNCHER FUNCTIONS ===== -extern "C" int launch_sha256_hash_computation( - uint8_t *d_records, - size_t num_records, - size_t *d_record_offsets, - uint32_t *d_block_offsets, - uint32_t *d_prev_hashes, - uint32_t total_num_blocks -) { - auto [grid_size, block_size] = kernel_launch_params(num_records, 256); - - sha256_hash_computation<<>>( - d_records, num_records, d_record_offsets, d_block_offsets, d_prev_hashes, total_num_blocks - ); - - return CHECK_KERNEL(); -} - -extern "C" int launch_sha256_first_pass_tracegen( - Fp *d_trace, - size_t trace_height, - uint8_t *d_records, - size_t num_records, - size_t *d_record_offsets, - uint32_t *d_block_offsets, - uint32_t *d_block_to_record_idx, - uint32_t total_num_blocks, - uint32_t *d_prev_hashes, - uint32_t ptr_max_bits, - uint32_t *d_range_checker, - uint32_t range_checker_num_bins, - uint32_t *d_bitwise_lookup, - uint32_t bitwise_num_bits, - uint32_t timestamp_max_bits -) { - // Validate that trace_height is a power of two - assert((trace_height & (trace_height - 1)) == 0); - assert(trace_height >= total_num_blocks * SHA256_ROWS_PER_BLOCK); - - auto [grid_size, block_size] = kernel_launch_params(total_num_blocks, 256); - - sha256_first_pass_tracegen<<>>( - d_trace, - trace_height, - d_records, - num_records, - d_record_offsets, - d_block_offsets, - d_block_to_record_idx, - total_num_blocks, - d_prev_hashes, - ptr_max_bits, - d_range_checker, - range_checker_num_bins, - d_bitwise_lookup, - bitwise_num_bits, - timestamp_max_bits - ); - - return CHECK_KERNEL(); -} - -extern "C" int launch_sha256_second_pass_dependencies( - Fp *d_trace, - size_t trace_height, - size_t rows_used -) { - Fp *inner_trace_start = d_trace + (SHA256_INNER_COLUMN_OFFSET * trace_height); - uint32_t total_sha256_blocks = rows_used / SHA256_ROWS_PER_BLOCK; - - auto [grid_size, block_size] = kernel_launch_params(total_sha256_blocks, 256); - - sha256_second_pass_dependencies<<>>( - inner_trace_start, trace_height, total_sha256_blocks - ); - - return CHECK_KERNEL(); -} - -extern "C" int launch_sha256_fill_invalid_rows( - Fp *d_trace, - size_t trace_height, - size_t rows_used -) { - uint32_t invalid_rows = trace_height - rows_used; - auto [grid_size, block_size] = kernel_launch_params(invalid_rows, 256); - sha256_fill_invalid_rows<<>>(d_trace, trace_height, rows_used); - - return CHECK_KERNEL(); -} \ No newline at end of file diff --git a/extensions/sha256/circuit/src/cuda_abi.rs b/extensions/sha256/circuit/src/cuda_abi.rs deleted file mode 100644 index 77df3a782e..0000000000 --- a/extensions/sha256/circuit/src/cuda_abi.rs +++ /dev/null @@ -1,124 +0,0 @@ -#![allow(clippy::missing_safety_doc)] - -use openvm_cuda_backend::prelude::F; -use openvm_cuda_common::{d_buffer::DeviceBuffer, error::CudaError}; - -pub mod sha256 { - use super::*; - - extern "C" { - fn launch_sha256_hash_computation( - d_records: *const u8, - num_records: usize, - d_record_offsets: *const usize, - d_block_offsets: *const u32, - d_prev_hashes: *mut u32, - total_num_blocks: u32, - ) -> i32; - - fn launch_sha256_first_pass_tracegen( - d_trace: *mut F, - trace_height: usize, - d_records: *const u8, - num_records: usize, - d_record_offsets: *const usize, - d_block_offsets: *const u32, - d_block_to_record_idx: *const u32, - total_num_blocks: u32, - d_prev_hashes: *const u32, - ptr_max_bits: u32, - d_range_checker: *mut u32, - range_checker_num_bins: u32, - d_bitwise_lookup: *mut u32, - bitwise_num_bits: u32, - timestamp_max_bits: u32, - ) -> i32; - - fn launch_sha256_second_pass_dependencies( - d_trace: *mut F, - trace_height: usize, - rows_used: usize, - ) -> i32; - - fn launch_sha256_fill_invalid_rows( - d_trace: *mut F, - trace_height: usize, - rows_used: usize, - ) -> i32; - } - - pub unsafe fn sha256_hash_computation( - d_records: &DeviceBuffer, - num_records: usize, - d_record_offsets: &DeviceBuffer, - d_block_offsets: &DeviceBuffer, - d_prev_hashes: &DeviceBuffer, - num_blocks: u32, - ) -> Result<(), CudaError> { - let result = launch_sha256_hash_computation( - d_records.as_ptr(), - num_records, - d_record_offsets.as_ptr(), - d_block_offsets.as_ptr(), - d_prev_hashes.as_mut_ptr(), - num_blocks, - ); - CudaError::from_result(result) - } - - #[allow(clippy::too_many_arguments)] - pub unsafe fn sha256_first_pass_tracegen( - d_trace: &DeviceBuffer, - height: usize, - d_records: &DeviceBuffer, - num_records: usize, - d_record_offsets: &DeviceBuffer, - d_block_offsets: &DeviceBuffer, - d_block_to_record_idx: &DeviceBuffer, - total_num_blocks: u32, - d_prev_hashes: &DeviceBuffer, - ptr_max_bits: u32, - d_range_checker: &DeviceBuffer, - d_bitwise_lookup: &DeviceBuffer, - bitwise_num_bits: u32, - timestamp_max_bits: u32, - ) -> Result<(), CudaError> { - let result = launch_sha256_first_pass_tracegen( - d_trace.as_mut_ptr(), - height, - d_records.as_ptr(), - num_records, - d_record_offsets.as_ptr(), - d_block_offsets.as_ptr(), - d_block_to_record_idx.as_ptr(), - total_num_blocks, - d_prev_hashes.as_ptr(), - ptr_max_bits, - d_range_checker.as_mut_ptr() as *mut u32, - d_range_checker.len() as u32, - d_bitwise_lookup.as_mut_ptr() as *mut u32, - bitwise_num_bits, - timestamp_max_bits, - ); - CudaError::from_result(result) - } - - pub unsafe fn sha256_second_pass_dependencies( - d_trace: &DeviceBuffer, - height: usize, - rows_used: usize, - ) -> Result<(), CudaError> { - let result = - launch_sha256_second_pass_dependencies(d_trace.as_mut_ptr(), height, rows_used); - CudaError::from_result(result) - } - - pub unsafe fn sha256_fill_invalid_rows( - d_trace: &DeviceBuffer, - height: usize, - rows_used: usize, - ) -> Result<(), CudaError> { - let result = launch_sha256_fill_invalid_rows(d_trace.as_mut_ptr(), height, rows_used); - CudaError::from_result(result) - } -} diff --git a/extensions/sha256/circuit/src/extension/cuda.rs b/extensions/sha256/circuit/src/extension/cuda.rs deleted file mode 100644 index d646f71d50..0000000000 --- a/extensions/sha256/circuit/src/extension/cuda.rs +++ /dev/null @@ -1,39 +0,0 @@ -use openvm_circuit::{ - arch::DenseRecordArena, - system::cuda::extensions::{get_inventory_range_checker, get_or_create_bitwise_op_lookup}, -}; -use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend}; -use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; - -use super::*; - -pub struct Sha256GpuProverExt; - -impl VmProverExtension - for Sha256GpuProverExt -{ - fn extend_prover( - &self, - _: &Sha256, - inventory: &mut ChipInventory, - ) -> Result<(), ChipInventoryError> { - let pointer_max_bits = inventory.airs().pointer_max_bits(); - let timestamp_max_bits = inventory.timestamp_max_bits(); - - let range_checker = get_inventory_range_checker(inventory); - let bitwise_lu = get_or_create_bitwise_op_lookup(inventory)?; - - // These calls to next_air are not strictly necessary to construct the chips, but provide a - // safeguard to ensure that chip construction matches the circuit definition - inventory.next_air::()?; - let sha256 = Sha256VmChipGpu::new( - range_checker.clone(), - bitwise_lu, - pointer_max_bits as u32, - timestamp_max_bits as u32, - ); - inventory.add_executor_chip(sha256); - - Ok(()) - } -} diff --git a/extensions/sha256/circuit/src/extension/mod.rs b/extensions/sha256/circuit/src/extension/mod.rs deleted file mode 100644 index d8f18028a6..0000000000 --- a/extensions/sha256/circuit/src/extension/mod.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::{result::Result, sync::Arc}; - -use derive_more::derive::From; -use openvm_circuit::{ - arch::{ - AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, - ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, - VmExecutionExtension, VmProverExtension, - }, - system::memory::SharedMemoryHelper, -}; -use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::*; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::PrimeField32, - prover::cpu::{CpuBackend, CpuDevice}, -}; -use openvm_stark_sdk::engine::StarkEngine; -use serde::{Deserialize, Serialize}; -use strum::IntoEnumIterator; - -use crate::*; - -cfg_if::cfg_if! { - if #[cfg(feature = "cuda")] { - mod cuda; - pub use self::cuda::*; - pub use self::cuda::Sha256GpuProverExt as Sha256ProverExt; - } else { - pub use self::Sha2CpuProverExt as Sha256ProverExt; - } -} - -// =================================== VM Extension Implementation ================================= -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Sha256; - -#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] -#[cfg_attr( - feature = "aot", - derive( - openvm_circuit_derive::AotExecutor, - openvm_circuit_derive::AotMeteredExecutor - ) -)] -pub enum Sha256Executor { - Sha256(Sha256VmExecutor), -} - -impl VmExecutionExtension for Sha256 { - type Executor = Sha256Executor; - - fn extend_execution( - &self, - inventory: &mut ExecutorInventoryBuilder, - ) -> Result<(), ExecutorInventoryError> { - let pointer_max_bits = inventory.pointer_max_bits(); - let sha256_step = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, pointer_max_bits); - inventory.add_executor( - sha256_step, - Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), - )?; - - Ok(()) - } -} - -impl VmCircuitExtension for Sha256 { - fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { - let pointer_max_bits = inventory.pointer_max_bits(); - - let bitwise_lu = { - let existing_air = inventory.find_air::>().next(); - if let Some(air) = existing_air { - air.bus - } else { - let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); - let air = BitwiseOperationLookupAir::<8>::new(bus); - inventory.add_air(air); - air.bus - } - }; - - let sha256 = Sha256VmAir::new( - inventory.system().port(), - bitwise_lu, - pointer_max_bits, - inventory.new_bus_idx(), - ); - inventory.add_air(sha256); - - Ok(()) - } -} - -pub struct Sha2CpuProverExt; -// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, -// BitwiseOperationLookupChip) are specific to CpuBackend. -impl VmProverExtension for Sha2CpuProverExt -where - SC: StarkGenericConfig, - E: StarkEngine, PD = CpuDevice>, - RA: RowMajorMatrixArena>, - Val: PrimeField32, -{ - fn extend_prover( - &self, - _: &Sha256, - inventory: &mut ChipInventory>, - ) -> Result<(), ChipInventoryError> { - let range_checker = inventory.range_checker()?.clone(); - let timestamp_max_bits = inventory.timestamp_max_bits(); - let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); - let pointer_max_bits = inventory.airs().pointer_max_bits(); - - let bitwise_lu = { - let existing_chip = inventory - .find_chip::>() - .next(); - if let Some(chip) = existing_chip { - chip.clone() - } else { - let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; - let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); - inventory.add_periphery_chip(chip.clone()); - chip - } - }; - - inventory.next_air::()?; - let sha256 = Sha256VmChip::new( - Sha256VmFiller::new(bitwise_lu, pointer_max_bits), - mem_helper, - ); - inventory.add_executor_chip(sha256); - - Ok(()) - } -} diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs deleted file mode 100644 index c726ba9ef7..0000000000 --- a/extensions/sha256/circuit/src/lib.rs +++ /dev/null @@ -1,159 +0,0 @@ -#![cfg_attr(feature = "tco", allow(incomplete_features))] -#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] -#![cfg_attr(feature = "tco", allow(internal_features))] -#![cfg_attr(feature = "tco", feature(core_intrinsics))] - -use std::result::Result; - -use openvm_circuit::{ - arch::{ - AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, - VmBuilder, VmChipComplex, VmProverExtension, - }, - system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, -}; -use openvm_circuit_derive::VmConfig; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::PrimeField32, - prover::cpu::{CpuBackend, CpuDevice}, -}; -use openvm_stark_sdk::engine::StarkEngine; -use serde::{Deserialize, Serialize}; - -mod sha256_chip; -pub use sha256_chip::*; - -mod extension; -pub use extension::*; - -cfg_if::cfg_if! { - if #[cfg(feature = "cuda")] { - use openvm_circuit::arch::DenseRecordArena; - use openvm_circuit::system::cuda::{extensions::SystemGpuBuilder, SystemChipInventoryGPU}; - use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend}; - use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; - use openvm_rv32im_circuit::Rv32ImGpuProverExt; - pub(crate) mod cuda_abi; - pub use Sha256Rv32GpuBuilder as Sha256Rv32Builder; - } else { - pub use Sha256Rv32CpuBuilder as Sha256Rv32Builder; - } -} - -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Sha256Rv32Config { - #[config(executor = "SystemExecutor")] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub sha256: Sha256, -} - -impl Default for Sha256Rv32Config { - fn default() -> Self { - Self { - system: SystemConfig::default(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - sha256: Sha256, - } - } -} - -// Default implementation uses no init file -impl InitFileGenerator for Sha256Rv32Config {} - -#[derive(Clone)] -pub struct Sha256Rv32CpuBuilder; - -impl VmBuilder for Sha256Rv32CpuBuilder -where - SC: StarkGenericConfig, - E: StarkEngine, PD = CpuDevice>, - Val: PrimeField32, -{ - type VmConfig = Sha256Rv32Config; - type SystemChipInventory = SystemChipInventory; - type RecordArena = MatrixRecordArena>; - - fn create_chip_complex( - &self, - config: &Sha256Rv32Config, - circuit: AirInventory, - ) -> Result< - VmChipComplex, - ChipInventoryError, - > { - let mut chip_complex = - VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; - let inventory = &mut chip_complex.inventory; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; - VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; - Ok(chip_complex) - } -} - -#[cfg(feature = "cuda")] -#[derive(Clone)] -pub struct Sha256Rv32GpuBuilder; - -#[cfg(feature = "cuda")] -impl VmBuilder for Sha256Rv32GpuBuilder { - type VmConfig = Sha256Rv32Config; - type SystemChipInventory = SystemChipInventoryGPU; - type RecordArena = DenseRecordArena; - - fn create_chip_complex( - &self, - config: &Sha256Rv32Config, - circuit: AirInventory, - ) -> Result< - VmChipComplex< - BabyBearPoseidon2Config, - Self::RecordArena, - GpuBackend, - Self::SystemChipInventory, - >, - ChipInventoryError, - > { - let mut chip_complex = VmBuilder::::create_chip_complex( - &SystemGpuBuilder, - &config.system, - circuit, - )?; - let inventory = &mut chip_complex.inventory; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.rv32i, - inventory, - )?; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.rv32m, - inventory, - )?; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.io, - inventory, - )?; - VmProverExtension::::extend_prover( - &Sha256GpuProverExt, - &config.sha256, - inventory, - )?; - Ok(chip_complex) - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs deleted file mode 100644 index 2fe1cb26c0..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ /dev/null @@ -1,624 +0,0 @@ -use std::{array, borrow::Borrow, cmp::min}; - -use openvm_circuit::{ - arch::ExecutionBridge, - system::{ - memory::{offline_checker::MemoryBridge, MemoryAddress}, - SystemPort, - }, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, -}; -use openvm_instructions::{ - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_sha256_air::{ - compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH, - SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE, -}; - -/// Sha256VmAir does all constraints related to message padding and -/// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug)] -pub struct Sha256VmAir { - pub execution_bridge: ExecutionBridge, - pub memory_bridge: MemoryBridge, - /// Bus to send byte checks to - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - /// Maximum number of bits allowed for an address pointer - /// Must be at least 24 - pub ptr_max_bits: usize, - pub(super) sha256_subair: Sha256Air, - pub(super) padding_encoder: Encoder, -} - -impl Sha256VmAir { - pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - bitwise_lookup_bus: BitwiseOperationLookupBus, - ptr_max_bits: usize, - self_bus_idx: BusIndex, - ) -> Self { - Self { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus, - ptr_max_bits, - sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx), - padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), - } - } -} - -impl BaseAirWithPublicValues for Sha256VmAir {} -impl PartitionedBaseAir for Sha256VmAir {} -impl BaseAir for Sha256VmAir { - fn width(&self) -> usize { - SHA256VM_WIDTH - } -} - -impl Air for Sha256VmAir { - fn eval(&self, builder: &mut AB) { - self.eval_padding(builder); - self.eval_transitions(builder); - self.eval_reads(builder); - self.eval_last_row(builder); - - self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH); - } -} - -#[allow(dead_code, non_camel_case_types)] -pub(super) enum PaddingFlags { - /// Not considered for padding - W's are not constrained - NotConsidered, - /// Not padding - W's should be equal to the message - NotPadding, - /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding - FirstPadding0, - FirstPadding1, - FirstPadding2, - FirstPadding3, - FirstPadding4, - FirstPadding5, - FirstPadding6, - FirstPadding7, - FirstPadding8, - FirstPadding9, - FirstPadding10, - FirstPadding11, - FirstPadding12, - FirstPadding13, - FirstPadding14, - FirstPadding15, - /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of - /// non-padding AND it is the last reading row of the message - /// NOTE: if the Last row has padding it has to be at least 9 cells since the last 8 cells are - /// padded with the message length - FirstPadding0_LastRow, - FirstPadding1_LastRow, - FirstPadding2_LastRow, - FirstPadding3_LastRow, - FirstPadding4_LastRow, - FirstPadding5_LastRow, - FirstPadding6_LastRow, - FirstPadding7_LastRow, - /// The entire row is padding AND it is not the first row with padding - /// AND it is the 4th row of the last block of the message - EntirePaddingLastRow, - /// The entire row is padding AND it is not the first row with padding - EntirePadding, -} - -impl PaddingFlags { - /// The number of padding flags (including NotConsidered) - pub const COUNT: usize = EntirePadding as usize + 1; -} - -use PaddingFlags::*; -impl Sha256VmAir { - /// Implement all necessary constraints for the padding - fn eval_padding(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - // Constrain the sanity of the padding flags - self.padding_encoder - .eval(builder, &local_cols.control.pad_flags); - - builder.assert_one(self.padding_encoder.contains_flag_range::( - &local_cols.control.pad_flags, - NotConsidered as usize..=EntirePadding as usize, - )); - - Self::eval_padding_transitions(self, builder, local_cols, next_cols); - Self::eval_padding_row(self, builder, local_cols); - } - - fn eval_padding_transitions( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - next: &Sha256VmRoundCols, - ) { - let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block; - - // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the - // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the - // first 4 rows of some block. - - builder.assert_bool(local.control.padding_occurred); - // Last round row in the last block has padding_occurred = 1 - // This is the end of the suffix - builder - .when(next_is_last_row.clone()) - .assert_one(local.control.padding_occurred); - - // Digest row in the last block has padding_occurred = 0 - builder - .when(next_is_last_row.clone()) - .assert_zero(next.control.padding_occurred); - - // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, - // unless next is the last digest row - builder - .when(local.control.padding_occurred - next_is_last_row.clone()) - .assert_one(next.control.padding_occurred); - - // If next row is not first 4 rows of a block, then next.padding_occurred = - // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a - // block. - builder - .when_transition() - .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row) - .assert_eq( - next.control.padding_occurred, - local.control.padding_occurred, - ); - - // Constrain the that the start of the padding is correct - let next_is_first_padding_row = - next.control.padding_occurred - local.control.padding_occurred; - // Row index if its between 0..4, else 0 - let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::( - &next.inner.flags.row_idx, - &(0..4).map(|x| (x, x)).collect::>(), - ); - // How many non-padding cells there are in the next row. - // Will be 0 on non-padding rows. - let next_padding_offset = self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..16) - .map(|i| (FirstPadding0 as usize + i, i)) - .collect::>(), - ) + self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..8) - .map(|i| (FirstPadding0_LastRow as usize + i, i)) - .collect::>(), - ); - - // Will be 0 on last digest row since: - // - padding_occurred = 0 is constrained above - // - next_row_idx = 0 since row_idx is not in 0..4 - // - and next_padding_offset = 0 since `pad_flags = NotConsidered` - let expected_len = next.inner.flags.local_block_idx - * next.control.padding_occurred - * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S) - + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE) - + next_padding_offset; - - // Note: `next_is_first_padding_row` is either -1,0,1 - // If 1, then this constrains the length of message - // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 - builder.when(next_is_first_padding_row).assert_eq( - expected_len, - next.control.len * next.control.padding_occurred, - ); - - // Constrain the padding flags are of correct type (eg is not padding or first padding) - let is_next_first_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0 as usize..=FirstPadding7_LastRow as usize, - ); - - let is_next_last_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - let is_next_entire_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - EntirePaddingLastRow as usize..=EntirePadding as usize, - ); - - let is_next_not_considered = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotConsidered as usize]); - - let is_next_not_padding = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotPadding as usize]); - - let is_next_4th_row = self - .sha256_subair - .row_idx_encoder - .contains_flag::(&next.inner.flags.row_idx, &[3]); - - // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block - builder.assert_eq( - not(next.inner.flags.is_first_4_rows), - is_next_not_considered, - ); - - // `pad_flags` is `EntirePadding` if the previous row is padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - local.control.padding_occurred * next.control.padding_occurred, - is_next_entire_padding, - ); - - // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not - // padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - not(local.control.padding_occurred) * next.control.padding_occurred, - is_next_first_padding, - ); - - // `pad_flags` is `NotPadding` if current row is not padding - builder - .when(next.inner.flags.is_first_4_rows) - .assert_eq(not(next.control.padding_occurred), is_next_not_padding); - - // `pad_flags` is `*LastRow` on the row that contains the last four words of the message - builder - .when(next.inner.flags.is_last_block) - .assert_eq(is_next_4th_row, is_next_last_padding); - } - - fn eval_padding_row( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - ) { - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)] - [i % (SHA256_WORD_U8S)] - }); - - let get_ith_byte = |i: usize| { - let word_idx = i / SHA256_ROUNDS_PER_ROW; - let word = local.inner.message_schedule.w[word_idx].map(|x| x.into()); - // Need to reverse the byte order to match the endianness of the memory - let byte_idx = 4 - i % 4 - 1; - compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) - }; - - let is_not_padding = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[NotPadding as usize]); - - // Check the `w`s on case by case basis - for (i, message_byte) in message.iter().enumerate() { - let w = get_ith_byte(i); - let should_be_message = is_not_padding.clone() - + if i < 15 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize + i + 1..=FirstPadding15 as usize, - ) - } else { - AB::Expr::ZERO - } - + if i < 7 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize, - ) - } else { - AB::Expr::ZERO - }; - builder - .when(should_be_message) - .assert_eq(w.clone(), *message_byte); - - let should_be_zero = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[EntirePadding as usize]) - + if i < 12 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[EntirePaddingLastRow as usize], - ) + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize - ..=min( - FirstPadding0_LastRow as usize + i - 1, - FirstPadding7_LastRow as usize, - ), - ) - } else { - AB::Expr::ZERO - } - } else { - AB::Expr::ZERO - } - + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, - ) - } else { - AB::Expr::ZERO - }; - builder.when(should_be_zero).assert_zero(w.clone()); - - // Assumes bit-length of message is a multiple of 8 (message is bytes) - // This is true because the message is given as &[u8] - let should_be_128 = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[FirstPadding0 as usize + i]) - + if i < 8 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[FirstPadding0_LastRow as usize + i], - ) - } else { - AB::Expr::ZERO - }; - - builder - .when(should_be_128) - .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); - - // should be len is handled outside of the loop - } - let appended_len = compose::( - &[ - get_ith_byte(15), - get_ith_byte(14), - get_ith_byte(13), - get_ith_byte(12), - ], - RV32_CELL_BITS, - ); - - let actual_len = local.control.len; - - let is_last_padding_row = self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - builder.when(is_last_padding_row.clone()).assert_eq( - appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion - actual_len, - ); - - // We constrain that the appended length is in bytes - builder.when(is_last_padding_row.clone()).assert_zero( - local.inner.message_schedule.w[3][0] - + local.inner.message_schedule.w[3][1] - + local.inner.message_schedule.w[3][2], - ); - - // We can't support messages longer than 2^30 bytes because the length has to fit in a field - // element. So, constrain that the first 4 bytes of the length are 0. - // Thus, the bit-length is < 2^32 so the message is < 2^29 bytes. - for i in 8..12 { - builder - .when(is_last_padding_row.clone()) - .assert_zero(get_ith_byte(i)); - } - } - /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` - fn eval_transitions(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - // Len should be the same for the entire message - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq(next_cols.control.len, local_cols.control.len); - - // Read ptr should increment by [SHA256_READ_SIZE] for the first 4 rows and stay the same - // otherwise - let read_ptr_delta = local_cols.inner.flags.is_first_4_rows - * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.read_ptr, - local_cols.control.read_ptr + read_ptr_delta, - ); - - // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise - let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.cur_timestamp, - local_cols.control.cur_timestamp + timestamp_delta, - ); - } - - /// Implement the reads for the first 4 rows of a block - fn eval_reads(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)] - [i % (SHA256_WORD_U16S * 2)] - }); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local_cols.control.read_ptr, - ), - message, - local_cols.control.cur_timestamp, - &local_cols.read_aux, - ) - .eval(builder, local_cols.inner.flags.is_first_4_rows); - } - /// Implement the constraints for the last row of a message - fn eval_last_row(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmDigestCols = local[..SHA256VM_DIGEST_WIDTH].borrow(); - - let timestamp: AB::Var = local_cols.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) - }; - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rd_ptr, - ), - local_cols.dst_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[0], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs1_ptr, - ), - local_cols.src_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[1], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs2_ptr, - ), - local_cols.len_data, - timestamp_pp(), - &local_cols.register_reads_aux[2], - ) - .eval(builder, is_last_row.clone()); - - // range check that the memory pointers don't overflow - // Note: no need to range check the length since we read from memory step by step and - // the memory bus will catch any memory accesses beyond ptr_max_bits - let shift = AB::Expr::from_canonical_usize( - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), - ); - // This only works if self.ptr_max_bits >= 24 which is typically the case - self.bitwise_lookup_bus - .send_range( - // It is fine to shift like this since we already know that dst_ptr and src_ptr - // have [RV32_CELL_BITS] bits - local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - ) - .eval(builder, is_last_row.clone()); - - // the number of reads that happened to read the entire message: we do 4 reads per block - let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) - * AB::Expr::from_canonical_usize(4); - // Every time we read the message we increment the read pointer by SHA256_READ_SIZE - let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - - let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| { - // The limbs are written in big endian order to the memory so need to be reversed - local_cols.inner.final_hash[i / SHA256_WORD_U8S] - [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1] - }); - - let dst_ptr_val = - compose::(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS); - - // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block write of - // 32 cells This could be beneficial as the output is often an input for - // another hash - self.memory_bridge - .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val), - result, - timestamp_pp() + time_delta.clone(), - &local_cols.writes_aux, - ) - .eval(builder, is_last_row.clone()); - - self.execution_bridge - .execute_and_increment_pc( - AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()), - [ - local_cols.rd_ptr.into(), - local_cols.rs1_ptr.into(), - local_cols.rs2_ptr.into(), - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - ], - local_cols.from_state, - AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), - ) - .eval(builder, is_last_row.clone()); - - // Assert that we read the correct length of the message - let len_val = compose::(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.len, len_val); - // Assert that we started reading from the correct pointer initially - let src_val = compose::(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta); - // Assert that we started reading from the correct timestamp - builder.when(is_last_row.clone()).assert_eq( - local_cols.control.cur_timestamp, - local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, - ); - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/columns.rs b/extensions/sha256/circuit/src/sha256_chip/columns.rs deleted file mode 100644 index 38c13a0f73..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/columns.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit::{ - arch::ExecutionState, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; -use openvm_circuit_primitives::AlignedBorrow; -use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; -use openvm_sha256_air::{Sha256DigestCols, Sha256RoundCols}; - -use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; - -/// the first 16 rows of every SHA256 block will be of type Sha256VmRoundCols and the last row will -/// be of type Sha256VmDigestCols -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmRoundCols { - pub control: Sha256VmControlCols, - pub inner: Sha256RoundCols, - pub read_aux: MemoryReadAuxCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmDigestCols { - pub control: Sha256VmControlCols, - pub inner: Sha256DigestCols, - - pub from_state: ExecutionState, - /// It is counter intuitive, but we will constrain the register reads on the very last row of - /// every message - pub rd_ptr: T, - pub rs1_ptr: T, - pub rs2_ptr: T, - pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub len_data: [T; RV32_REGISTER_NUM_LIMBS], - pub register_reads_aux: [MemoryReadAuxCols; SHA256_REGISTER_READS], - pub writes_aux: MemoryWriteAuxCols, -} - -/// These are the columns that are used on both round and digest rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmControlCols { - /// Note: We will use the buffer in `inner.message_schedule` as the message data - /// This is the length of the entire message in bytes - pub len: T, - /// Need to keep timestamp and read_ptr since block reads don't have the necessary information - pub cur_timestamp: T, - pub read_ptr: T, - /// Padding flags which will be used to encode the the number of non-padding cells in the - /// current row - pub pad_flags: [T; 6], - /// A boolean flag that indicates whether a padding already occurred - pub padding_occurred: T, -} - -/// Width of the Sha256VmControlCols -pub const SHA256VM_CONTROL_WIDTH: usize = Sha256VmControlCols::::width(); -/// Width of the Sha256VmRoundCols -pub const SHA256VM_ROUND_WIDTH: usize = Sha256VmRoundCols::::width(); -/// Width of the Sha256VmDigestCols -pub const SHA256VM_DIGEST_WIDTH: usize = Sha256VmDigestCols::::width(); -/// Width of the Sha256Cols -pub const SHA256VM_WIDTH: usize = if SHA256VM_ROUND_WIDTH > SHA256VM_DIGEST_WIDTH { - SHA256VM_ROUND_WIDTH -} else { - SHA256VM_DIGEST_WIDTH -}; diff --git a/extensions/sha256/circuit/src/sha256_chip/cuda.rs b/extensions/sha256/circuit/src/sha256_chip/cuda.rs deleted file mode 100644 index 9a673dc795..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/cuda.rs +++ /dev/null @@ -1,129 +0,0 @@ -// crates/tracegen/src/extensions/sha256/mod.rs - -use std::{iter::repeat_n, sync::Arc}; - -use derive_new::new; -use openvm_circuit::{ - arch::{DenseRecordArena, MultiRowLayout, RecordSeeker}, - utils::next_power_of_two_or_zero, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU, -}; -use openvm_cuda_backend::{ - base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend, -}; -use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer}; -use openvm_instructions::riscv::RV32_CELL_BITS; -use openvm_sha256_air::{get_sha256_num_blocks, SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK}; -use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; - -use crate::{ - cuda_abi::sha256::{ - sha256_fill_invalid_rows, sha256_first_pass_tracegen, sha256_hash_computation, - sha256_second_pass_dependencies, - }, - Sha256VmMetadata, Sha256VmRecordMut, SHA256VM_WIDTH, -}; - -// ===== SHA256 GPU CHIP IMPLEMENTATION ===== -#[derive(new)] -pub struct Sha256VmChipGpu { - pub range_checker: Arc, - pub bitwise_lookup: Arc>, - pub ptr_max_bits: u32, - pub timestamp_max_bits: u32, -} - -impl Chip for Sha256VmChipGpu { - fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext { - let records = arena.allocated_mut(); - if records.is_empty() { - return get_empty_air_proving_ctx::(); - } - - let mut record_offsets = Vec::::new(); - let mut block_to_record_idx = Vec::::new(); - let mut block_offsets = Vec::::new(); - let mut offset_so_far = 0; - let mut num_blocks_so_far: u32 = 0; - - while offset_so_far < records.len() { - record_offsets.push(offset_so_far); - block_offsets.push(num_blocks_so_far); - - let record = RecordSeeker::< - DenseRecordArena, - Sha256VmRecordMut, - MultiRowLayout, - >::get_record_at(&mut offset_so_far, records); - - let num_blocks = get_sha256_num_blocks(record.inner.len); - let record_idx = record_offsets.len() - 1; - - block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks as usize)); - num_blocks_so_far += num_blocks; - } - - assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len()); - assert_eq!(offset_so_far, records.len()); - assert_eq!(block_offsets.len(), record_offsets.len()); - - let d_records = records.to_device().unwrap(); - let d_record_offsets = record_offsets.to_device().unwrap(); - let d_block_offsets = block_offsets.to_device().unwrap(); - let d_block_to_record_idx = block_to_record_idx.to_device().unwrap(); - - let d_prev_hashes = DeviceBuffer::::with_capacity( - num_blocks_so_far as usize * SHA256_HASH_WORDS, // 8 words per SHA256 hash block - ); - - unsafe { - sha256_hash_computation( - &d_records, - record_offsets.len(), - &d_record_offsets, - &d_block_offsets, - &d_prev_hashes, - num_blocks_so_far, - ) - .expect("Hash computation kernel failed"); - } - - let rows_used = num_blocks_so_far as usize * SHA256_ROWS_PER_BLOCK; - let trace_height = next_power_of_two_or_zero(rows_used); - let d_trace = DeviceMatrix::::with_capacity(trace_height, SHA256VM_WIDTH); - - unsafe { - sha256_first_pass_tracegen( - d_trace.buffer(), - trace_height, - &d_records, - record_offsets.len(), - &d_record_offsets, - &d_block_offsets, - &d_block_to_record_idx, - num_blocks_so_far, - &d_prev_hashes, - self.ptr_max_bits, - &self.range_checker.count, - &self.bitwise_lookup.count, - RV32_CELL_BITS as u32, - self.timestamp_max_bits, - ) - .expect("First pass trace generation failed"); - } - - unsafe { - sha256_fill_invalid_rows(d_trace.buffer(), trace_height, rows_used) - .expect("Invalid rows filling failed"); - } - - unsafe { - sha256_second_pass_dependencies(d_trace.buffer(), trace_height, rows_used) - .expect("Second pass trace generation failed"); - } - - AirProvingContext::simple_no_pis(d_trace) - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs deleted file mode 100644 index eb027bd6d1..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Sha256 hasher. Handles full sha256 hashing with padding. -//! variable length inputs read from VM memory. - -use openvm_circuit::arch::*; -use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, -}; -use openvm_instructions::riscv::RV32_CELL_BITS; -use openvm_sha256_air::{Sha256FillerHelper, SHA256_BLOCK_BITS}; -use sha2::{Digest, Sha256}; - -mod air; -mod columns; -mod execution; -mod trace; - -pub use air::*; -pub use columns::*; -pub use trace::*; - -#[cfg(feature = "cuda")] -mod cuda; -#[cfg(feature = "cuda")] -pub use cuda::*; - -#[cfg(test)] -mod tests; - -// ==== Constants for register/memory adapter ==== -/// Register reads to get dst, src, len -const SHA256_REGISTER_READS: usize = 3; -/// Number of cells to read in a single memory access -const SHA256_READ_SIZE: usize = 16; -/// Number of cells to write in a single memory access -const SHA256_WRITE_SIZE: usize = 32; -/// Number of rv32 cells read in a SHA256 block -pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; -/// Number of rows we will do a read on for each SHA256 block -pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE; -/// Maximum message length that this chip supports in bytes -pub const SHA256_MAX_MESSAGE_LEN: usize = 1 << 29; - -pub type Sha256VmChip = VmChipWrapper; - -#[derive(derive_new::new, Clone)] -pub struct Sha256VmExecutor { - pub offset: usize, - pub pointer_max_bits: usize, -} - -pub struct Sha256VmFiller { - pub inner: Sha256FillerHelper, - pub padding_encoder: Encoder, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - pub pointer_max_bits: usize, -} - -impl Sha256VmFiller { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - pointer_max_bits: usize, - ) -> Self { - Self { - inner: Sha256FillerHelper::new(), - padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), - bitwise_lookup_chip, - pointer_max_bits, - } - } -} - -pub fn sha256_solve(input_message: &[u8]) -> [u8; SHA256_WRITE_SIZE] { - let mut hasher = Sha256::new(); - hasher.update(input_message); - let mut output = [0u8; SHA256_WRITE_SIZE]; - output.copy_from_slice(hasher.finalize().as_ref()); - output -} diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs deleted file mode 100644 index 4f72ffd333..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ /dev/null @@ -1,376 +0,0 @@ -use std::{array, sync::Arc}; - -use hex::FromHex; -use openvm_circuit::{ - arch::{ - testing::{ - memory::gen_pointer, TestBuilder, TestChipHarness, VmChipTestBuilder, - BITWISE_OP_LOOKUP_BUS, - }, - Arena, MatrixRecordArena, PreflightExecutor, - }, - system::{memory::SharedMemoryHelper, SystemPort}, - utils::get_random_message, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{ - instruction::Instruction, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS}, - LocalOpcode, -}; -use openvm_sha256_air::{get_sha256_num_blocks, SHA256_BLOCK_U8S}; -use openvm_sha256_transpiler::Rv32Sha256Opcode::{self, *}; -use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; -#[cfg(feature = "cuda")] -use { - crate::{Sha256VmChipGpu, Sha256VmRecordMut}, - openvm_circuit::arch::testing::{ - default_bitwise_lookup_bus, GpuChipTestBuilder, GpuTestChipHarness, - }, -}; - -use super::{Sha256VmAir, Sha256VmChip, Sha256VmExecutor}; -use crate::{sha256_solve, Sha256VmDigestCols, Sha256VmFiller, Sha256VmRoundCols}; - -type F = BabyBear; -const SELF_BUS_IDX: BusIndex = 28; -const MAX_INS_CAPACITY: usize = 4096; -type Harness = TestChipHarness, RA>; - -fn create_harness_fields( - system_port: SystemPort, - bitwise_chip: Arc>, - memory_helper: SharedMemoryHelper, - address_bits: usize, -) -> (Sha256VmAir, Sha256VmExecutor, Sha256VmChip) { - let air = Sha256VmAir::new(system_port, bitwise_chip.bus(), address_bits, SELF_BUS_IDX); - let executor = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, address_bits); - let chip = Sha256VmChip::new( - Sha256VmFiller::new(bitwise_chip, address_bits), - memory_helper, - ); - (air, executor, chip) -} - -fn create_harness( - tester: &mut VmChipTestBuilder, -) -> ( - Harness, - ( - BitwiseOperationLookupAir, - SharedBitwiseOperationLookupChip, - ), -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let (air, executor, chip) = create_harness_fields( - tester.system_port(), - bitwise_chip.clone(), - tester.memory_helper(), - tester.address_bits(), - ); - let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - (harness, (bitwise_chip.air, bitwise_chip)) -} - -fn set_and_execute>( - tester: &mut impl TestBuilder, - executor: &mut E, - arena: &mut RA, - rng: &mut StdRng, - opcode: Rv32Sha256Opcode, - message: Option<&[u8]>, - len: Option, -) { - let len = len.unwrap_or(rng.gen_range(1..3000)); - let tmp = get_random_message(rng, len); - let message: &[u8] = message.unwrap_or(&tmp); - let len = message.len(); - - let rd = gen_pointer(rng, 4); - let rs1 = gen_pointer(rng, 4); - let rs2 = gen_pointer(rng, 4); - - let dst_ptr = gen_pointer(rng, 4); - let src_ptr = gen_pointer(rng, 4); - tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); - tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); - tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - - // Adding random memory after the message - let num_blocks = get_sha256_num_blocks(len as u32) as usize; - for offset in (0..num_blocks * SHA256_BLOCK_U8S).step_by(4) { - let chunk: [F; 4] = array::from_fn(|i| { - if offset + i < message.len() { - F::from_canonical_u8(message[offset + i]) - } else { - F::from_canonical_u8(rng.gen()) - } - }); - - tester.write(RV32_MEMORY_AS as usize, src_ptr + offset, chunk); - } - - tester.execute( - executor, - arena, - &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), - ); - - let output = sha256_solve(message); - assert_eq!( - output.map(F::from_canonical_u8), - tester.read::<32>(RV32_MEMORY_AS as usize, dst_ptr) - ); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// POSITIVE TESTS -/// -/// Randomly generate computations and execute, ensuring that the generated trace -/// passes all constraints. -/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_sha256_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, bitwise) = create_harness(&mut tester); - - let num_ops: usize = 10; - for _ in 0..num_ops { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - None, - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn sha256_edge_test_lengths() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, bitwise) = create_harness(&mut tester); - - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ( - "98c1c0bdb7d5fea9a88859f06c6c439f", - "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", - ), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - - for (input, _) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - Some(&input), - None, - ); - } - - // check every possible input length modulo 64 - for i in 65..=128 { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - Some(i), - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// SANITY TESTS -/// -/// Ensure that solve functions produce the correct results. -/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, _) = create_harness::>(&mut tester); - - println!( - "Sha256VmDigestCols::width(): {}", - Sha256VmDigestCols::::width() - ); - println!( - "Sha256VmRoundCols::width(): {}", - Sha256VmRoundCols::::width() - ); - let num_tests: usize = 1; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - None, - ); - } -} - -#[test] -fn sha256_solve_sanity_check() { - let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; - let output = sha256_solve(input); - let expected: [u8; 32] = [ - 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, - 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, - ]; - assert_eq!(output, expected); -} - -// //////////////////////////////////////////////////////////////////////////////////// -// CUDA TESTS -// -// Ensure GPU tracegen is equivalent to CPU tracegen -// //////////////////////////////////////////////////////////////////////////////////// - -#[cfg(feature = "cuda")] -type GpuHarness = - GpuTestChipHarness>; - -#[cfg(feature = "cuda")] -fn create_cuda_harness(tester: &GpuChipTestBuilder) -> GpuHarness { - const GPU_MAX_INS_CAPACITY: usize = 8192; - - let bitwise_bus = default_bitwise_lookup_bus(); - let dummy_bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - - let (air, executor, cpu_chip) = create_harness_fields( - tester.system_port(), - dummy_bitwise_chip, - tester.dummy_memory_helper(), - tester.address_bits(), - ); - let gpu_chip = Sha256VmChipGpu::new( - tester.range_checker(), - tester.bitwise_op_lookup(), - tester.address_bits() as u32, - tester.timestamp_max_bits() as u32, - ); - - GpuTestChipHarness::with_capacity(executor, air, gpu_chip, cpu_chip, GPU_MAX_INS_CAPACITY) -} - -#[cfg(feature = "cuda")] -#[test] -fn test_cuda_sha256_tracegen() { - let mut rng = create_seeded_rng(); - let mut tester = - GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); - - let mut harness = create_cuda_harness(&tester); - - let num_ops = 70; - for i in 1..=num_ops { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.dense_arena, - &mut rng, - SHA256, - None, - Some(i), - ); - } - - harness - .dense_arena - .get_record_seeker::() - .transfer_to_matrix_arena(&mut harness.matrix_arena); - - tester - .build() - .load_gpu_harness(harness) - .finalize() - .simple_test() - .unwrap(); -} - -#[cfg(feature = "cuda")] -#[test] -fn test_cuda_sha256_known_vectors() { - let mut rng = create_seeded_rng(); - let mut tester = - GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); - - let mut harness = create_cuda_harness(&tester); - - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ( - "98c1c0bdb7d5fea9a88859f06c6c439f", - "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", - ), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - - for (input, _) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.dense_arena, - &mut rng, - SHA256, - Some(&input), - None, - ); - } - - harness - .dense_arena - .get_record_seeker::() - .transfer_to_matrix_arena(&mut harness.matrix_arena); - - tester - .build() - .load_gpu_harness(harness) - .finalize() - .simple_test() - .unwrap(); -} diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs deleted file mode 100644 index 7fc5c7062c..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ /dev/null @@ -1,624 +0,0 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, - cmp::min, -}; - -use openvm_circuit::{ - arch::*, - system::memory::{ - offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, - online::TracingMemory, - MemoryAuxColsFactory, - }, -}; -use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; -use openvm_sha256_air::{ - get_flag_pt_array, get_sha256_num_blocks, Sha256FillerHelper, SHA256_BLOCK_BITS, SHA256_H, - SHA256_ROWS_PER_BLOCK, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - p3_field::PrimeField32, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmExecutor, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, -}; -use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE}, - sha256_solve, Sha256VmControlCols, Sha256VmFiller, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, - SHA256_BLOCK_CELLS, SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS, -}; - -#[derive(Clone, Copy)] -pub struct Sha256VmMetadata { - pub num_blocks: u32, -} - -impl MultiRowMetadata for Sha256VmMetadata { - #[inline(always)] - fn get_num_rows(&self) -> usize { - self.num_blocks as usize * SHA256_ROWS_PER_BLOCK - } -} - -pub(crate) type Sha256VmRecordLayout = MultiRowLayout; - -#[repr(C)] -#[derive(AlignedBytesBorrow, Debug, Clone)] -pub struct Sha256VmRecordHeader { - pub from_pc: u32, - pub timestamp: u32, - pub rd_ptr: u32, - pub rs1_ptr: u32, - pub rs2_ptr: u32, - pub dst_ptr: u32, - pub src_ptr: u32, - pub len: u32, - - pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS], - pub write_aux: MemoryWriteBytesAuxRecord, -} - -pub struct Sha256VmRecordMut<'a> { - pub inner: &'a mut Sha256VmRecordHeader, - // Having a continuous slice of the input is useful for fast hashing in `execute` - pub input: &'a mut [u8], - pub read_aux: &'a mut [MemoryReadAuxRecord], -} - -/// Custom borrowing that splits the buffer into a fixed `Sha256VmRecord` header -/// followed by a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks` where `num_blocks` is -/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length -/// `SHA256_NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly -/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the -/// slices. -impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] { - fn custom_borrow(&'a mut self, layout: Sha256VmRecordLayout) -> Sha256VmRecordMut<'a> { - // SAFETY: - // - Caller guarantees through the layout that self has sufficient length for all splits and - // constants are guaranteed <= self.len() by layout precondition - let (header_buf, rest) = - unsafe { self.split_at_mut_unchecked(size_of::()) }; - - // SAFETY: - // - layout guarantees rest has sufficient length for input data - // - The layout size calculation includes num_blocks * SHA256_BLOCK_CELLS bytes after header - // - num_blocks is derived from the message length ensuring correct sizing - // - SHA256_BLOCK_CELLS is a compile-time constant (64 bytes per block) - let (input, rest) = unsafe { - rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * SHA256_BLOCK_CELLS) - }; - - // SAFETY: - // - rest is a valid mutable slice from the previous split - // - align_to_mut guarantees the middle slice is properly aligned for MemoryReadAuxRecord - // - The subslice operation [..num_blocks * SHA256_NUM_READ_ROWS] validates sufficient - // capacity - // - Layout calculation ensures space for alignment padding plus required aux records - let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; - Sha256VmRecordMut { - inner: header_buf.borrow_mut(), - input, - read_aux: &mut read_aux_buf - [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS], - } - } - - unsafe fn extract_layout(&self) -> Sha256VmRecordLayout { - let header: &Sha256VmRecordHeader = self.borrow(); - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: get_sha256_num_blocks(header.len), - }, - } - } -} - -impl SizedRecord for Sha256VmRecordMut<'_> { - fn size(layout: &Sha256VmRecordLayout) -> usize { - let mut total_len = size_of::(); - total_len += layout.metadata.num_blocks as usize * SHA256_BLOCK_CELLS; - // Align the pointer to the alignment of `MemoryReadAuxRecord` - total_len = total_len.next_multiple_of(align_of::()); - total_len += layout.metadata.num_blocks as usize - * SHA256_NUM_READ_ROWS - * size_of::(); - total_len - } - - fn alignment(_layout: &Sha256VmRecordLayout) -> usize { - align_of::() - } -} - -impl PreflightExecutor for Sha256VmExecutor -where - F: PrimeField32, - for<'buf> RA: RecordArena<'buf, Sha256VmRecordLayout, Sha256VmRecordMut<'buf>>, -{ - fn get_opcode_name(&self, _: usize) -> String { - format!("{:?}", Rv32Sha256Opcode::SHA256) - } - - fn execute( - &self, - state: VmStateMut, - instruction: &Instruction, - ) -> Result<(), ExecutionError> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction; - debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode()); - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - // Reading the length first to allocate a record of correct size - let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); - - let num_blocks = get_sha256_num_blocks(len); - let record = state.ctx.alloc(MultiRowLayout { - metadata: Sha256VmMetadata { num_blocks }, - }); - - record.inner.from_pc = *state.pc; - record.inner.timestamp = state.memory.timestamp(); - record.inner.rd_ptr = a.as_canonical_u32(); - record.inner.rs1_ptr = b.as_canonical_u32(); - record.inner.rs2_ptr = c.as_canonical_u32(); - - record.inner.dst_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rd_ptr, - &mut record.inner.register_reads_aux[0].prev_timestamp, - )); - record.inner.src_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs1_ptr, - &mut record.inner.register_reads_aux[1].prev_timestamp, - )); - record.inner.len = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs2_ptr, - &mut record.inner.register_reads_aux[2].prev_timestamp, - )); - - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - record.inner.src_ptr as usize + num_blocks as usize * SHA256_BLOCK_CELLS - <= (1 << self.pointer_max_bits) - ); - debug_assert!( - record.inner.dst_ptr as usize + SHA256_WRITE_SIZE <= (1 << self.pointer_max_bits) - ); - // We don't support messages longer than 2^29 bytes - debug_assert!(record.inner.len < SHA256_MAX_MESSAGE_LEN as u32); - - for block_idx in 0..num_blocks as usize { - // Reads happen on the first 4 rows of each block - for row in 0..SHA256_NUM_READ_ROWS { - let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = tracing_read( - state.memory, - RV32_MEMORY_AS, - record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32, - &mut record.read_aux[read_idx].prev_timestamp, - ); - record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE] - .copy_from_slice(&row_input); - } - } - - let output = sha256_solve(&record.input[..len as usize]); - tracing_write( - state.memory, - RV32_MEMORY_AS, - record.inner.dst_ptr, - output, - &mut record.inner.write_aux.prev_timestamp, - &mut record.inner.write_aux.prev_data, - ); - - *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - - Ok(()) - } -} - -impl TraceFiller for Sha256VmFiller { - fn fill_trace( - &self, - mem_helper: &MemoryAuxColsFactory, - trace_matrix: &mut RowMajorMatrix, - rows_used: usize, - ) { - if rows_used == 0 { - return; - } - - let mut chunks = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut sizes = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut trace = &mut trace_matrix.values[..]; - let mut num_blocks_so_far = 0; - - // First pass over the trace to get the number of blocks for each instruction - // and divide the matrix into chunks of needed sizes - loop { - if num_blocks_so_far * SHA256_ROWS_PER_BLOCK >= rows_used { - // Push all the padding rows as a single chunk and break - chunks.push(trace); - sizes.push((0, num_blocks_so_far)); - break; - } else { - // SAFETY: - // - caller ensures `trace` contains a valid record representation that was - // previously written by the executor - // - header is the first element of the record - let record: &Sha256VmRecordHeader = - unsafe { get_record_from_slice(&mut trace, ()) }; - let num_blocks = ((record.len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - let (chunk, rest) = - trace.split_at_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK * num_blocks); - chunks.push(chunk); - sizes.push((num_blocks, num_blocks_so_far)); - num_blocks_so_far += num_blocks; - trace = rest; - } - } - - // During the first pass we will fill out most of the matrix - // But there are some cells that can't be generated by the first pass so we will do a second - // pass over the matrix later - chunks.par_iter_mut().zip(sizes.par_iter()).for_each( - |(slice, (num_blocks, global_block_offset))| { - if global_block_offset * SHA256_ROWS_PER_BLOCK >= rows_used { - // Fill in the invalid rows - slice.par_chunks_mut(SHA256VM_WIDTH).for_each(|row| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - // SAFETY: - // - row has exactly SHA256VM_WIDTH elements - // - We're zeroing all SHA256VM_WIDTH elements to clear any garbage data - // that might overflow the field - // - Casting F* to u8* preserves validity for write_bytes operation - // - SHA256VM_WIDTH * size_of::() correctly calculates total bytes to - // zero - unsafe { - std::ptr::write_bytes( - row.as_mut_ptr() as *mut u8, - 0, - SHA256VM_WIDTH * size_of::(), - ); - } - let cols: &mut Sha256VmRoundCols = - row[..SHA256VM_ROUND_WIDTH].borrow_mut(); - self.inner.generate_default_row(&mut cols.inner); - }); - return; - } - - // SAFETY: - // - caller ensures `trace` contains a valid record representation that was - // previously written by the executor - // - slice contains a valid Sha256VmRecord with the exact layout specified - // - get_record_from_slice will correctly split the buffer into header, input, and - // aux components based on this layout - let record: Sha256VmRecordMut = unsafe { - get_record_from_slice( - slice, - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: *num_blocks as u32, - }, - }, - ) - }; - - let mut input: Vec = Vec::with_capacity(SHA256_BLOCK_CELLS * num_blocks); - input.extend_from_slice(record.input); - let mut padded_input = input.clone(); - let len = record.inner.len as usize; - let padded_input_len = padded_input.len(); - padded_input[len] = 1 << (RV32_CELL_BITS - 1); - padded_input[len + 1..padded_input_len - 4].fill(0); - padded_input[padded_input_len - 4..] - .copy_from_slice(&((len as u32) << 3).to_be_bytes()); - - let mut prev_hashes = Vec::with_capacity(*num_blocks); - prev_hashes.push(SHA256_H); - for i in 0..*num_blocks - 1 { - prev_hashes.push(Sha256FillerHelper::get_block_hash( - &prev_hashes[i], - padded_input[i * SHA256_BLOCK_CELLS..(i + 1) * SHA256_BLOCK_CELLS] - .try_into() - .unwrap(), - )); - } - // Copy the read aux records and input to another place to safely fill in the trace - // matrix without overwriting the record - let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks); - read_aux_records.extend_from_slice(record.read_aux); - let vm_record = record.inner.clone(); - - slice - .par_chunks_exact_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .enumerate() - .for_each(|(block_idx, block_slice)| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - // SAFETY: - // - block_slice comes from par_chunks_exact_mut with exact size guarantee - // - Length is SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::() bytes - // - Zeroing entire blocks prevents using garbage data - // - The subsequent trace filling will overwrite with valid values - unsafe { - std::ptr::write_bytes( - block_slice.as_mut_ptr() as *mut u8, - 0, - SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::(), - ); - } - self.fill_block_trace::( - block_slice, - &vm_record, - &read_aux_records[block_idx * SHA256_NUM_READ_ROWS - ..(block_idx + 1) * SHA256_NUM_READ_ROWS], - &input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - &padded_input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - block_idx == *num_blocks - 1, - *global_block_offset + block_idx, - block_idx, - prev_hashes[block_idx], - mem_helper, - ); - }); - }, - ); - - // Do a second pass over the trace to fill in the missing values - // Note, we need to skip the very first row - trace_matrix.values[SHA256VM_WIDTH..] - .par_chunks_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .take(rows_used / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.inner - .generate_missing_cells(chunk, SHA256VM_WIDTH, SHA256VM_CONTROL_WIDTH); - }); - } -} - -impl Sha256VmFiller { - #[allow(clippy::too_many_arguments)] - fn fill_block_trace( - &self, - block_slice: &mut [F], - record: &Sha256VmRecordHeader, - read_aux_records: &[MemoryReadAuxRecord], - input: &[u8], - padded_input: &[u8], - is_last_block: bool, - global_block_idx: usize, - local_block_idx: usize, - prev_hash: [u32; 8], - mem_helper: &MemoryAuxColsFactory, - ) { - debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS); - - let padded_input = array::from_fn(|i| { - u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap()) - }); - - let block_start_timestamp = record.timestamp - + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32; - - let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32; - let block_start_read_ptr = record.src_ptr + read_cells; - - let message_left = if record.len <= read_cells { - 0 - } else { - (record.len - read_cells) as usize - }; - - // -1 means that padding occurred before the start of the block - // 18 means that no padding occurred on this block - let first_padding_row = if record.len < read_cells { - -1 - } else if message_left < SHA256_BLOCK_CELLS { - (message_left / SHA256_READ_SIZE) as i32 - } else { - 18 - }; - - // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in - block_slice - .par_chunks_exact_mut(SHA256VM_WIDTH) - .enumerate() - .for_each(|(row_idx, row_slice)| { - // Handle round rows and digest row separately - if row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // This is a digest row - let digest_cols: &mut Sha256VmDigestCols = - row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); - digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); - digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); - digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); - digest_cols.dst_ptr = record.dst_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.src_ptr = record.src_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.len_data = record.len.to_le_bytes().map(F::from_canonical_u8); - if is_last_block { - digest_cols - .register_reads_aux - .iter_mut() - .zip(record.register_reads_aux.iter()) - .enumerate() - .for_each(|(idx, (cols_read, record_read))| { - mem_helper.fill( - record_read.prev_timestamp, - record.timestamp + idx as u32, - cols_read.as_mut(), - ); - }); - digest_cols - .writes_aux - .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); - // In the last block we do `SHA256_NUM_READ_ROWS` reads and then write the - // result thus the timestamp of the write is - // `block_start_timestamp + SHA256_NUM_READ_ROWS` - mem_helper.fill( - record.write_aux.prev_timestamp, - block_start_timestamp + SHA256_NUM_READ_ROWS as u32, - digest_cols.writes_aux.as_mut(), - ); - // Need to range check the destination and source pointers - let msl_rshift: u32 = - ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; - let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - - self.pointer_max_bits) - as u32; - self.bitwise_lookup_chip.request_range( - (record.dst_ptr >> msl_rshift) << msl_lshift, - (record.src_ptr >> msl_rshift) << msl_lshift, - ); - } else { - // Filling in zeros to make sure the accidental garbage data doesn't - // overflow the prime - digest_cols.register_reads_aux.iter_mut().for_each(|aux| { - mem_helper.fill_zero(aux.as_mut()); - }); - digest_cols - .writes_aux - .set_prev_data([F::ZERO; SHA256_WRITE_SIZE]); - mem_helper.fill_zero(digest_cols.writes_aux.as_mut()); - } - digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); - digest_cols.inner.flags.is_digest_row = F::from_bool(true); - } else { - // This is a round row - let round_cols: &mut Sha256VmRoundCols = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - // Take care of the first 4 round rows (aka read rows) - if row_idx < SHA256_NUM_READ_ROWS { - round_cols - .inner - .message_schedule - .carry_or_buffer - .as_flattened_mut() - .iter_mut() - .zip( - input[row_idx * SHA256_READ_SIZE..(row_idx + 1) * SHA256_READ_SIZE] - .iter(), - ) - .for_each(|(cell, data)| { - *cell = F::from_canonical_u8(*data); - }); - mem_helper.fill( - read_aux_records[row_idx].prev_timestamp, - block_start_timestamp + row_idx as u32, - round_cols.read_aux.as_mut(), - ); - } else { - mem_helper.fill_zero(round_cols.read_aux.as_mut()); - } - } - // Fill in the control cols, doesn't matter if it is a round or digest row - let control_cols: &mut Sha256VmControlCols = - row_slice[..SHA256VM_CONTROL_WIDTH].borrow_mut(); - control_cols.len = F::from_canonical_u32(record.len); - // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr - control_cols.cur_timestamp = F::from_canonical_u32( - block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32, - ); - control_cols.read_ptr = F::from_canonical_u32( - block_start_read_ptr - + (SHA256_READ_SIZE * min(row_idx, SHA256_NUM_READ_ROWS)) as u32, - ); - - // Fill in the padding flags - if row_idx < SHA256_NUM_READ_ROWS { - #[allow(clippy::comparison_chain)] - if (row_idx as i32) < first_padding_row { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(F::from_canonical_u32); - } else if row_idx as i32 == first_padding_row { - let len = message_left - row_idx * SHA256_READ_SIZE; - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(F::from_canonical_u32); - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(F::from_canonical_u32); - } - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(F::from_canonical_u32); - } - if is_last_block && row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // If last digest row, then we set padding_occurred = 0 - control_cols.padding_occurred = F::ZERO; - } else { - control_cols.padding_occurred = - F::from_bool((row_idx as i32) >= first_padding_row); - } - }); - - // Fill in the inner trace when the `buffer_or_carry` is filled in - self.inner.generate_block_trace::( - block_slice, - SHA256VM_WIDTH, - SHA256VM_CONTROL_WIDTH, - &padded_input, - self.bitwise_lookup_chip.as_ref(), - &prev_hash, - is_last_block, - global_block_idx as u32 + 1, // global block index is 1-indexed - local_block_idx as u32, - ); - } -} diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs deleted file mode 100644 index 8f7c072f4a..0000000000 --- a/extensions/sha256/guest/src/lib.rs +++ /dev/null @@ -1,69 +0,0 @@ -#![no_std] - -#[cfg(target_os = "zkvm")] -use openvm_platform::alloc::AlignedBuf; - -/// This is custom-0 defined in RISC-V spec document -pub const OPCODE: u8 = 0x0b; -pub const SHA256_FUNCT3: u8 = 0b100; -pub const SHA256_FUNCT7: u8 = 0x1; - -/// Native hook for sha256 -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// -/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf -#[cfg(target_os = "zkvm")] -#[inline(always)] -#[no_mangle] -pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { - // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or - // `output` are not aligned to 4 bytes. - // The minimum alignment required for the input and output buffers - const MIN_ALIGN: usize = 4; - // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes - const INPUT_ALIGN: usize = 16; - // The preferred alignment for the output buffer, since the output is written in chunks of 32 - // bytes - const OUTPUT_ALIGN: usize = 32; - unsafe { - if bytes as usize % MIN_ALIGN != 0 { - let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(aligned_buff.ptr, len, output); - } - } else { - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(bytes, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(bytes, len, output); - } - }; - } -} - -/// sha256 intrinsic binding -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// - `bytes` and `output` must be 4-byte aligned. -#[cfg(target_os = "zkvm")] -#[inline(always)] -fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { - openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); -} diff --git a/extensions/sha256/transpiler/src/lib.rs b/extensions/sha256/transpiler/src/lib.rs deleted file mode 100644 index 6b13efe055..0000000000 --- a/extensions/sha256/transpiler/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; -use openvm_instructions_derive::LocalOpcode; -use openvm_sha256_guest::{OPCODE, SHA256_FUNCT3, SHA256_FUNCT7}; -use openvm_stark_backend::p3_field::PrimeField32; -use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; -use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x320] -#[repr(usize)] -pub enum Rv32Sha256Opcode { - SHA256, -} - -#[derive(Default)] -pub struct Sha256TranspilerExtension; - -impl TranspilerExtension for Sha256TranspilerExtension { - fn process_custom(&self, instruction_stream: &[u32]) -> Option> { - if instruction_stream.is_empty() { - return None; - } - let instruction_u32 = instruction_stream[0]; - let opcode = (instruction_u32 & 0x7f) as u8; - let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - - if (opcode, funct3) != (OPCODE, SHA256_FUNCT3) { - return None; - } - let dec_insn = RType::new(instruction_u32); - - if dec_insn.funct7 != SHA256_FUNCT7 as u32 { - return None; - } - let instruction = from_r_type( - Rv32Sha256Opcode::SHA256.global_opcode().as_usize(), - RV32_MEMORY_AS as usize, - &dec_insn, - true, - ); - Some(TranspilerOutput::one_to_one(instruction)) - } -} diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 11b2abccfa..bde91433c3 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -35,8 +35,8 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -71,14 +71,14 @@ pkcs8 = ["ecdsa-core/pkcs8", "elliptic-curve/pkcs8"] precomputed-tables = ["arithmetic", "once_cell"] schnorr = ["arithmetic", "signature"] serde = ["ecdsa-core/serde", "elliptic-curve/serde"] -sha256 = [] +sha2 = [] test-vectors = [] # Internal feature for testing only. cuda = [ "openvm-circuit/cuda", "openvm-ecc-circuit/cuda", - "openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", ] tco = ["openvm-circuit/tco"] aot = ["openvm-circuit/aot"] diff --git a/guest-libs/k256/src/ecdsa.rs b/guest-libs/k256/src/ecdsa.rs index 69128c0384..29b10f882a 100644 --- a/guest-libs/k256/src/ecdsa.rs +++ b/guest-libs/k256/src/ecdsa.rs @@ -3,7 +3,7 @@ // Use these types instead of unpatched k256::ecdsa::{Signature, VerifyingKey} // because those are type aliases that use non-zkvm implementations -#[cfg(any(feature = "ecdsa", feature = "sha256"))] +#[cfg(any(feature = "ecdsa", feature = "sha2"))] pub use ecdsa_core::hazmat; pub use ecdsa_core::{ signature::{self, Error}, diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 59eb42dde3..7573d7a49d 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -13,7 +13,7 @@ mod guest_tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -99,7 +99,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, Rv32WeierstrassConfigExecutor, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256ProverExt}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2ProverExt}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cuda")] use { @@ -128,14 +128,14 @@ mod guest_tests { #[config(generics = true)] pub weierstrass: Rv32WeierstrassConfig, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { pub fn new(curves: Vec) -> Self { Self { weierstrass: Rv32WeierstrassConfig::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -179,8 +179,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -214,8 +214,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -237,7 +237,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/k256/tests/programs/examples/ecdsa.rs b/guest-libs/k256/tests/programs/examples/ecdsa.rs index d886e80ed0..f66fdf7622 100644 --- a/guest-libs/k256/tests/programs/examples/ecdsa.rs +++ b/guest-libs/k256/tests/programs/examples/ecdsa.rs @@ -3,13 +3,15 @@ extern crate alloc; +use core::hint::black_box; + use ecdsa::RecoveryId; use hex_literal::hex; use openvm_k256::ecdsa::{Signature, VerifyingKey}; // clippy thinks this is unused, but it's used in the init! macro #[allow(unused)] use openvm_k256::Secp256k1Point; -use openvm_sha2::sha256; +use openvm_sha2::{Digest, Sha256}; openvm::init!("openvm_init_ecdsa.rs"); @@ -49,10 +51,12 @@ const RECOVERY_TEST_VECTORS: &[RecoveryTestVector] = &[ // Test public key recovery fn main() { for vector in RECOVERY_TEST_VECTORS { - let digest = sha256(vector.msg); + let mut hasher = Sha256::new(); + hasher.update(black_box(&vector.msg)); + let digest = hasher.finalize(); let sig = Signature::try_from(vector.sig.as_slice()).unwrap(); let recid = vector.recid; - let pk = VerifyingKey::recover_from_prehash(digest.as_slice(), &sig, recid).unwrap(); + let pk = VerifyingKey::recover_from_prehash(&digest, &sig, recid).unwrap(); assert_eq!(&vector.pk[..], &pk.to_sec1_bytes(true)); } } diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 5bcb547846..c591114a23 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -32,8 +32,8 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -60,7 +60,7 @@ jwk = ["elliptic-curve/jwk"] pem = ["elliptic-curve/pem", "ecdsa-core/pem", "pkcs8"] pkcs8 = ["ecdsa-core?/pkcs8", "elliptic-curve/pkcs8"] serde = ["ecdsa-core?/serde", "elliptic-curve/serde"] -sha256 = [] +sha2 = [] test-vectors = [] voprf = ["elliptic-curve/voprf"] @@ -68,7 +68,7 @@ voprf = ["elliptic-curve/voprf"] cuda = [ "openvm-circuit/cuda", "openvm-ecc-circuit/cuda", - "openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", ] tco = ["openvm-circuit/tco"] aot = ["openvm-circuit/aot"] diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 9eaf2b2c74..19f8eb464e 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -13,7 +13,7 @@ mod guest_tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -99,7 +99,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, Rv32WeierstrassConfigExecutor, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256ProverExt}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2ProverExt}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cuda")] use { @@ -128,14 +128,14 @@ mod guest_tests { #[config(generics = true)] pub weierstrass: Rv32WeierstrassConfig, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { pub fn new(curves: Vec) -> Self { Self { weierstrass: Rv32WeierstrassConfig::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -179,8 +179,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -214,8 +214,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -237,7 +237,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/pairing/src/bls12_381/pairing.rs b/guest-libs/pairing/src/bls12_381/pairing.rs index db13c785e1..8ed9df2f68 100644 --- a/guest-libs/pairing/src/bls12_381/pairing.rs +++ b/guest-libs/pairing/src/bls12_381/pairing.rs @@ -25,7 +25,7 @@ use { openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3}, openvm_platform::custom_insn_r, openvm_rv32im_guest, - openvm_rv32im_guest::hint_buffer_u32, + openvm_rv32im_guest::hint_buffer_chunked, }; use super::{Bls12_381, Fp, Fp12, Fp2}; @@ -280,7 +280,7 @@ impl PairingCheck for Bls12_381 { } #[cfg(target_os = "zkvm")] { - let hint = MaybeUninit::<(Fp12, Fp12)>::uninit(); + let mut hint = MaybeUninit::<(Fp12, Fp12)>::uninit(); // We do not rely on the slice P's memory layout since rust does not guarantee it across // compiler versions. let p_fat_ptr = (P.as_ptr() as u32, P.len() as u32); @@ -294,8 +294,8 @@ impl PairingCheck for Bls12_381 { rs1 = In &p_fat_ptr, rs2 = In &q_fat_ptr ); - let ptr = hint.as_ptr() as *const u8; - hint_buffer_u32!(ptr, (48 * 12 * 2) / 4); + let ptr = hint.as_mut_ptr() as *mut u8; + hint_buffer_chunked(ptr, (48 * 12 * 2) / 4 as usize); hint.assume_init() } } diff --git a/guest-libs/pairing/src/bn254/pairing.rs b/guest-libs/pairing/src/bn254/pairing.rs index c0f1cc35f2..9fb160511b 100644 --- a/guest-libs/pairing/src/bn254/pairing.rs +++ b/guest-libs/pairing/src/bn254/pairing.rs @@ -21,7 +21,7 @@ use { core::mem::MaybeUninit, openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3}, openvm_platform::custom_insn_r, - openvm_rv32im_guest::hint_buffer_u32, + openvm_rv32im_guest::hint_buffer_chunked, }; use super::{Bn254, Fp, Fp12, Fp2}; @@ -314,7 +314,7 @@ impl PairingCheck for Bn254 { } #[cfg(target_os = "zkvm")] { - let hint = MaybeUninit::<(Fp12, Fp12)>::uninit(); + let mut hint = MaybeUninit::<(Fp12, Fp12)>::uninit(); // We do not rely on the slice P's memory layout since rust does not guarantee it across // compiler versions. let p_fat_ptr = (P.as_ptr() as u32, P.len() as u32); @@ -328,8 +328,8 @@ impl PairingCheck for Bn254 { rs1 = In &p_fat_ptr, rs2 = In &q_fat_ptr ); - let ptr = hint.as_ptr() as *const u8; - hint_buffer_u32!(ptr, (32 * 12 * 2) / 4); + let ptr = hint.as_mut_ptr() as *mut u8; + hint_buffer_chunked(ptr, (32 * 12 * 2) / 4 as usize); hint.assume_init() } } diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index 573930affb..0ae715d364 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -10,22 +10,21 @@ repository.workspace = true license.workspace = true [dependencies] -openvm-sha256-guest = { workspace = true } +openvm = { workspace = true } +openvm-sha2-guest = { workspace = true } +sha2 = { workspace = true, default-features = false } [dev-dependencies] openvm-instructions = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler = { workspace = true } -openvm-sha256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } -[target.'cfg(not(target_os = "zkvm"))'.dependencies] -sha2 = { workspace = true } - [features] # Internal feature for testing only. -cuda = ["openvm-sha256-circuit/cuda"] \ No newline at end of file +cuda = ["openvm-sha2-circuit/cuda"] \ No newline at end of file diff --git a/guest-libs/sha2/src/host_impl.rs b/guest-libs/sha2/src/host_impl.rs new file mode 100644 index 0000000000..da74208f91 --- /dev/null +++ b/guest-libs/sha2/src/host_impl.rs @@ -0,0 +1,3 @@ +// On a host execution environment, the zkvm impl's input buffering is not necessary, and we can +// use the sha2 crate directly. +pub use sha2::{Sha256, Sha384, Sha512}; diff --git a/guest-libs/sha2/src/lib.rs b/guest-libs/sha2/src/lib.rs index 43d90ba822..2aaeabbfe0 100644 --- a/guest-libs/sha2/src/lib.rs +++ b/guest-libs/sha2/src/lib.rs @@ -1,28 +1,13 @@ #![no_std] -/// The sha256 cryptographic hash function. -#[inline(always)] -pub fn sha256(input: &[u8]) -> [u8; 32] { - let mut output = [0u8; 32]; - set_sha256(input, &mut output); - output -} +pub use sha2::Digest; -/// Sets `output` to the sha256 hash of `input`. -pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { - #[cfg(not(target_os = "zkvm"))] - { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(input); - output.copy_from_slice(hasher.finalize().as_ref()); - } - #[cfg(target_os = "zkvm")] - { - openvm_sha256_guest::zkvm_sha256_impl( - input.as_ptr(), - input.len(), - output.as_mut_ptr() as *mut u8, - ); - } -} +#[cfg(not(target_os = "zkvm"))] +mod host_impl; +#[cfg(target_os = "zkvm")] +mod zkvm_impl; + +#[cfg(not(target_os = "zkvm"))] +pub use host_impl::*; +#[cfg(target_os = "zkvm")] +pub use zkvm_impl::*; diff --git a/guest-libs/sha2/src/zkvm_impl.rs b/guest-libs/sha2/src/zkvm_impl.rs new file mode 100644 index 0000000000..1f32ba0a70 --- /dev/null +++ b/guest-libs/sha2/src/zkvm_impl.rs @@ -0,0 +1,292 @@ +use core::cmp::min; + +use sha2::digest::{ + consts::{U32, U64}, + FixedOutput, HashMarker, Output, OutputSizeUser, Update, +}; + +// We store static padding bytes here so that we don't need to allocate a vector when padding in +// finalize(). +// Padding always consists of a single 0x80 byte, followed by zeros, (and then the length of the +// message in bits but we don't include that here because it's not static). +// Length of this array is chosen to be the maximum block size between SHA-256 and SHA-512, since +// padding can be at most BLOCK_BYTES bytes. +const PADDING_BYTES: [u8; SHA512_BLOCK_BYTES] = { + let mut padding_bytes = [0u8; SHA512_BLOCK_BYTES]; + padding_bytes[0] = 0x80; + padding_bytes +}; + +const SHA256_STATE_WORDS: usize = 8; +const SHA256_BLOCK_BYTES: usize = 64; +const SHA256_DIGEST_BYTES: usize = 32; + +// Initial state for SHA-256 in 32-bit words +const SHA256_H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +#[derive(Debug, Clone, Copy)] +pub struct Sha256 { + // the current hasher state, in 32-bit words + state: [u32; SHA256_STATE_WORDS], + // the next block of input + buffer: [u8; SHA256_BLOCK_BYTES], + // idx of next byte to write to buffer (equal to len mod SHA256_BLOCK_BYTES) + idx: usize, + // accumulated length of the input data, in bytes + len: usize, +} + +impl Default for Sha256 { + fn default() -> Self { + Self::new() + } +} + +impl Sha256 { + pub fn new() -> Self { + Self { + state: SHA256_H, + buffer: [0; SHA256_BLOCK_BYTES], + idx: 0, + len: 0, + } + } + + fn update(&mut self, mut input: &[u8]) { + self.len += input.len(); + while !input.is_empty() { + let to_copy = min(input.len(), SHA256_BLOCK_BYTES - self.idx); + self.buffer[self.idx..self.idx + to_copy].copy_from_slice(&input[..to_copy]); + self.idx += to_copy; + if self.idx == SHA256_BLOCK_BYTES { + self.idx = 0; + self.compress(); + } + input = &input[to_copy..]; + } + } + + fn finalize(mut self) -> [u8; SHA256_DIGEST_BYTES] { + // pad until length in bytes is 56 mod 64 (leave 8 bytes for the message length) + let num_bytes_of_padding = SHA256_BLOCK_BYTES - 8 - self.idx; + // ensure num_bytes_of_padding is positive + let num_bytes_of_padding = (num_bytes_of_padding + SHA256_BLOCK_BYTES) % SHA256_BLOCK_BYTES; + let message_len_in_bits = self.len * 8; + self.update(&PADDING_BYTES[..num_bytes_of_padding]); + self.update(&(message_len_in_bits as u64).to_be_bytes()); + let mut output = [0u8; SHA256_DIGEST_BYTES]; + output + .chunks_exact_mut(4) + .zip(self.state.iter()) + .for_each(|(chunk, x)| { + chunk.copy_from_slice(&x.to_be_bytes()); + }); + output + } + + fn compress(&mut self) { + openvm_sha2_guest::zkvm_sha256_impl( + self.state.as_ptr() as *const u8, + self.buffer.as_ptr(), + self.state.as_mut_ptr() as *mut u8, + ); + } +} + +// We will implement FixedOutput, Default, Update, and HashMarker for Sha256 so that +// the blanket implementation of sha2::Digest is available. +// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D +impl Update for Sha256 { + fn update(&mut self, input: &[u8]) { + self.update(input); + } +} + +// OutputSizeUser is required for FixedOutput +// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html +impl OutputSizeUser for Sha256 { + type OutputSize = U32; +} + +impl FixedOutput for Sha256 { + fn finalize_into(self, out: &mut Output) { + out.copy_from_slice(&self.finalize()); + } +} + +impl HashMarker for Sha256 {} + +const SHA512_STATE_WORDS: usize = 8; +const SHA512_BLOCK_BYTES: usize = 128; +const SHA512_DIGEST_BYTES: usize = 64; + +// Initial state for SHA-512 in 64-bit words +pub const SHA512_H: [u64; 8] = [ + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +]; + +#[derive(Debug, Clone, Copy)] +pub struct Sha512 { + // the current hasher state + state: [u64; SHA512_STATE_WORDS], + // the next block of input + buffer: [u8; SHA512_BLOCK_BYTES], + // idx of next byte to write to buffer + idx: usize, + // accumulated length of the input data, in bytes + len: usize, +} + +impl Default for Sha512 { + fn default() -> Self { + Self::new() + } +} + +impl Sha512 { + pub fn new() -> Self { + Self { + state: SHA512_H, + buffer: [0; SHA512_BLOCK_BYTES], + idx: 0, + len: 0, + } + } + + fn update(&mut self, mut input: &[u8]) { + self.len += input.len(); + while !input.is_empty() { + let to_copy = min(input.len(), SHA512_BLOCK_BYTES - self.idx); + self.buffer[self.idx..self.idx + to_copy].copy_from_slice(&input[..to_copy]); + self.idx += to_copy; + if self.idx == SHA512_BLOCK_BYTES { + self.idx = 0; + self.compress(); + } + input = &input[to_copy..]; + } + } + + fn finalize(mut self) -> [u8; SHA512_DIGEST_BYTES] { + // pad until length in bytes is 112 mod 128 (leave 16 bytes for the message length) + let num_bytes_of_padding = SHA512_BLOCK_BYTES - 16 - self.idx; + // ensure num_bytes_of_padding is positive + let num_bytes_of_padding = (num_bytes_of_padding + SHA512_BLOCK_BYTES) % SHA512_BLOCK_BYTES; + let message_len_in_bits = self.len * 8; + self.update(&PADDING_BYTES[..num_bytes_of_padding]); + self.update(&(message_len_in_bits as u128).to_be_bytes()); + let mut output = [0u8; SHA512_DIGEST_BYTES]; + output + .chunks_exact_mut(8) + .zip(self.state.iter()) + .for_each(|(chunk, x)| { + chunk.copy_from_slice(&x.to_be_bytes()); + }); + output + } + + fn compress(&mut self) { + openvm_sha2_guest::zkvm_sha512_impl( + self.state.as_ptr() as *const u8, + self.buffer.as_ptr(), + self.state.as_mut_ptr() as *mut u8, + ); + } +} + +// We will implement FixedOutput, Default, Update, and HashMarker for Sha512 so that +// the blanket implementation of sha2::Digest is available. +// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D +impl Update for Sha512 { + fn update(&mut self, input: &[u8]) { + self.update(input); + } +} + +// OutputSizeUser is required for FixedOutput +// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html +impl OutputSizeUser for Sha512 { + type OutputSize = U64; +} + +impl FixedOutput for Sha512 { + fn finalize_into(self, out: &mut Output) { + out.copy_from_slice(&self.finalize()); + } +} + +impl HashMarker for Sha512 {} + +const SHA384_DIGEST_BYTES: usize = 48; + +// Initial state for SHA-384 in 64-bit words +pub const SHA384_H: [u64; 8] = [ + 0xcbbb9d5dc1059ed8, + 0x629a292a367cd507, + 0x9159015a3070dd17, + 0x152fecd8f70e5939, + 0x67332667ffc00b31, + 0x8eb44a8768581511, + 0xdb0c2e0d64f98fa7, + 0x47b5481dbefa4fa4, +]; + +#[derive(Debug, Clone, Copy)] +pub struct Sha384 { + inner: Sha512, +} + +impl Default for Sha384 { + fn default() -> Self { + Self::new() + } +} + +impl Sha384 { + pub fn new() -> Self { + let mut inner = Sha512::new(); + inner.state = SHA384_H; + Self { inner } + } + + pub fn update(&mut self, input: &[u8]) { + self.inner.update(input); + } + + pub fn finalize(self) -> [u8; SHA384_DIGEST_BYTES] { + let digest = self.inner.finalize(); + digest[..SHA384_DIGEST_BYTES].try_into().unwrap() + } +} + +// We will implement FixedOutput, Default, Update, and HashMarker for Sha384 so that +// the blanket implementation of sha2::Digest is available. +// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D +impl Update for Sha384 { + fn update(&mut self, input: &[u8]) { + self.update(input); + } +} + +// OutputSizeUser is required for FixedOutput +// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html +impl OutputSizeUser for Sha384 { + type OutputSize = U64; +} + +impl FixedOutput for Sha384 { + fn finalize_into(self, out: &mut Output) { + out.copy_from_slice(&self.finalize()); + } +} + +impl HashMarker for Sha384 {} diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index adfae8e764..26319f7e4d 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -6,8 +6,8 @@ mod tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_circuit::{Sha256Rv32Builder, Sha256Rv32Config}; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_circuit::{Sha2Rv32Builder, Sha2Rv32Config}; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -15,19 +15,19 @@ mod tests { type F = BabyBear; #[test] - fn test_sha256() -> Result<()> { - let config = Sha256Rv32Config::default(); + fn test_sha2() -> Result<()> { + let config = Sha2Rv32Config::default(); let elf = - build_example_program_at_path(get_programs_dir!("tests/programs"), "sha", &config)?; + build_example_program_at_path(get_programs_dir!("tests/programs"), "sha2", &config)?; let openvm_exe = VmExe::from_elf( elf, Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; - air_test(Sha256Rv32Builder, config, openvm_exe); + air_test(Sha2Rv32Builder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/sha2/tests/programs/Cargo.toml b/guest-libs/sha2/tests/programs/Cargo.toml index df13f8dfc7..60618893c9 100644 --- a/guest-libs/sha2/tests/programs/Cargo.toml +++ b/guest-libs/sha2/tests/programs/Cargo.toml @@ -6,18 +6,12 @@ edition = "2021" [dependencies] openvm = { path = "../../../../crates/toolchain/openvm" } -openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-sha2 = { path = "../../" } - hex = { version = "0.4.3", default-features = false, features = ["alloc"] } -serde = { version = "1.0", default-features = false, features = [ - "alloc", - "derive", -] } [features] default = [] -std = ["serde/std", "openvm/std"] +std = ["openvm/std"] [profile.release] panic = "abort" diff --git a/guest-libs/sha2/tests/programs/examples/sha.rs b/guest-libs/sha2/tests/programs/examples/sha.rs deleted file mode 100644 index ebfd50cbee..0000000000 --- a/guest-libs/sha2/tests/programs/examples/sha.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ("98c1c0bdb7d5fea9a88859f06c6c439f", "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05"), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/guest-libs/sha2/tests/programs/examples/sha2.rs b/guest-libs/sha2/tests/programs/examples/sha2.rs new file mode 100644 index 0000000000..49c0dafeb4 --- /dev/null +++ b/guest-libs/sha2/tests/programs/examples/sha2.rs @@ -0,0 +1,95 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +extern crate alloc; + +use alloc::vec::Vec; +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{Digest, Sha256, Sha384, Sha512}; + +openvm::entry!(main); + +struct ShaTestVector { + input: &'static str, + expected_output_sha256: &'static str, + expected_output_sha512: &'static str, + expected_output_sha384: &'static str, +} + +pub fn main() { + let test_vectors = [ + ShaTestVector { + input: "", + expected_output_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expected_output_sha512: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + expected_output_sha384: "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + }, + ShaTestVector { + input: "98c1c0bdb7d5fea9a88859f06c6c439f", + expected_output_sha256: "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + expected_output_sha512: "eb576959c531f116842c0cc915a29c8f71d7a285c894c349b83469002ef093d51f9f14ce4248488bff143025e47ed27c12badb9cd43779cb147408eea062d583", + expected_output_sha384: "63e3061aab01f335ea3a4e617b9d14af9b63a5240229164ee962f6d5335ff25f0f0bf8e46723e83c41b9d17413b6a3c7", + }, + ShaTestVector { + input: "5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", + expected_output_sha256: "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7", + expected_output_sha512: "a20d5fb14814d045a7d2861e80d2b688f1cd1daaba69e6bb1cc5233f514141ea4623b3373af702e78e3ec5dc8c1b716a37a9a2f5fbc9493b9df7043f5e99a8da", + expected_output_sha384: "eac4b72b0540486bc088834860873338e31e9e4062532bf509191ef63b9298c67db5654a28fe6f07e4cc6ff466d1be24", + }, + ShaTestVector { + input: "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + expected_output_sha256: "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23", + expected_output_sha512: "8d215ee6dc26757c210db0dd00c1c6ed16cc34dbd4bb0fa10c1edb6b62d5ab16aea88c881001b173d270676daf2d6381b5eab8711fa2f5589c477c1d4b84774f", + expected_output_sha384: "904a90010d772a904a35572fdd4bdf1dd253742e47872c8a18e2255f66fa889e44781e65487a043f435daa53c496a53e", + } + ]; + + for ( + i, + ShaTestVector { + input, + expected_output_sha256, + expected_output_sha512, + expected_output_sha384, + }, + ) in test_vectors.iter().enumerate() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let mut hasher = Sha256::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output.as_slice() != expected_output_sha256.as_slice() { + panic!( + "sha256 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, + expected_output_sha256, + output.as_slice() + ); + } + let expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let mut hasher = Sha512::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output.as_slice() != expected_output_sha512.as_slice() { + panic!( + "sha512 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, + expected_output_sha512, + output.as_slice() + ); + } + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let mut hasher = Sha384::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output.as_slice() != expected_output_sha384.as_slice() { + panic!( + "sha384 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha384, output + ); + } + } +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 25c9b640da..a6a8cdcacb 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -2,4 +2,4 @@ channel = "1.90.0" # To use the "tco" feature, switch to Rust nightly: # channel = "nightly-2025-08-19" -components = ["clippy", "rustfmt"] +components = ["clippy", "rustfmt", "rust-analyzer"]