From 41d177a6c8193b61f69c5c0b73b6acab547508bf Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 5 Jan 2026 14:20:06 +0800 Subject: [PATCH 01/37] block-multiplier: bench WASI compatible. --- skyscraper/block-multiplier/Cargo.toml | 4 +- skyscraper/block-multiplier/benches/bench.rs | 222 ++++++++++--------- skyscraper/block-multiplier/src/lib.rs | 1 + skyscraper/block-multiplier/src/scalar.rs | 1 + tooling/provekit-bench/Cargo.toml | 2 +- 5 files changed, 126 insertions(+), 104 deletions(-) diff --git a/skyscraper/block-multiplier/Cargo.toml b/skyscraper/block-multiplier/Cargo.toml index ab66b0aa..3960da90 100644 --- a/skyscraper/block-multiplier/Cargo.toml +++ b/skyscraper/block-multiplier/Cargo.toml @@ -24,9 +24,11 @@ ark-ff.workspace = true # 3rd party divan.workspace = true primitive-types.workspace = true -proptest.workspace = true rand.workspace = true +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +proptest.workspace = true + [build-dependencies] # Workspace crates block-multiplier-codegen.workspace = true diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index 3e5c6f17..bda9be3a 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -1,9 +1,7 @@ #![feature(portable_simd)] use { - core::{array, simd::u64x2}, divan::Bencher, - fp_rounding::with_rounding_mode, rand::{rng, Rng}, }; @@ -33,69 +31,78 @@ mod mul { .bench_local_values(|(a, b)| a * b); } - #[divan::bench] - fn simd_mul(bencher: Bencher) { - bencher - //.counter(ItemsCount::new(2usize)) - .with_inputs(|| rng().random()) - .bench_local_values(|(a, b, c, d)| block_multiplier::simd_mul(a, b, c, d)); - } + #[cfg(target_arch = "aarch64")] + mod aarch64 { + use { + super::*, + core::{array, simd::u64x2}, + fp_rounding::with_rounding_mode, + }; - #[divan::bench] - fn block_mul(bencher: Bencher) { - let bencher = bencher - //.counter(ItemsCount::new(3usize)) - .with_inputs(|| rng().random()); - unsafe { - with_rounding_mode((), |guard, _| { - bencher.bench_local_values(|(a, b, c, d, e, f)| { - block_multiplier::block_mul(guard, a, b, c, d, e, f) + #[divan::bench] + fn simd_mul(bencher: Bencher) { + bencher + //.counter(ItemsCount::new(2usize)) + .with_inputs(|| rng().random()) + .bench_local_values(|(a, b, c, d)| block_multiplier::simd_mul(a, b, c, d)); + } + + #[divan::bench] + fn block_mul(bencher: Bencher) { + let bencher = bencher + //.counter(ItemsCount::new(3usize)) + .with_inputs(|| rng().random()); + unsafe { + with_rounding_mode((), |guard, _| { + bencher.bench_local_values(|(a, b, c, d, e, f)| { + block_multiplier::block_mul(guard, a, b, c, d, e, f) + }); }); - }); + } } - } - #[divan::bench] - fn montgomery_interleaved_3(bencher: Bencher) { - let bencher = bencher - //.counter(ItemsCount::new(3usize)) - .with_inputs(|| { - ( - rng().random(), - rng().random(), - array::from_fn(|_| u64x2::from_array(rng().random())), - array::from_fn(|_| u64x2::from_array(rng().random())), - ) - }); - unsafe { - with_rounding_mode((), |mode_guard, _| { - bencher.bench_local_values(|(a, b, c, d)| { - block_multiplier::montgomery_interleaved_3(mode_guard, a, b, c, d) + #[divan::bench] + fn montgomery_interleaved_3(bencher: Bencher) { + let bencher = bencher + //.counter(ItemsCount::new(3usize)) + .with_inputs(|| { + ( + rng().random(), + rng().random(), + array::from_fn(|_| u64x2::from_array(rng().random())), + array::from_fn(|_| u64x2::from_array(rng().random())), + ) }); - }); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher.bench_local_values(|(a, b, c, d)| { + block_multiplier::montgomery_interleaved_3(mode_guard, a, b, c, d) + }); + }); + } } - } - #[divan::bench] - fn montgomery_interleaved_4(bencher: Bencher) { - let bencher = bencher - //.counter(ItemsCount::new(4usize)) - .with_inputs(|| { - ( - rng().random(), - rng().random(), - rng().random(), - rng().random(), - array::from_fn(|_| u64x2::from_array(rng().random())), - array::from_fn(|_| u64x2::from_array(rng().random())), - ) - }); - unsafe { - with_rounding_mode((), |mode_guard, _| { - bencher.bench_local_values(|(a, b, c, d, e, f)| { - block_multiplier::montgomery_interleaved_4(mode_guard, a, b, c, d, e, f) + #[divan::bench] + fn montgomery_interleaved_4(bencher: Bencher) { + let bencher = bencher + //.counter(ItemsCount::new(4usize)) + .with_inputs(|| { + ( + rng().random(), + rng().random(), + rng().random(), + rng().random(), + array::from_fn(|_| u64x2::from_array(rng().random())), + array::from_fn(|_| u64x2::from_array(rng().random())), + ) }); - }); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher.bench_local_values(|(a, b, c, d, e, f)| { + block_multiplier::montgomery_interleaved_4(mode_guard, a, b, c, d, e, f) + }); + }); + } } } } @@ -121,38 +128,47 @@ mod sqr { .bench_local_values(|a: Fr| a.square()); } - #[divan::bench] - fn montgomery_square_log_interleaved_3(bencher: Bencher) { - let bencher = bencher.with_inputs(|| { - ( - rng().random(), - array::from_fn(|_| u64x2::from_array(rng().random())), - ) - }); - unsafe { - with_rounding_mode((), |mode_guard, _| { - bencher.bench_local_values(|(a, b)| { - block_multiplier::montgomery_square_log_interleaved_3(mode_guard, a, b) - }); + #[cfg(target_arch = "aarch64")] + mod aarch64 { + use { + super::*, + core::{array, simd::u64x2}, + fp_rounding::with_rounding_mode, + }; + + #[divan::bench] + fn montgomery_square_log_interleaved_3(bencher: Bencher) { + let bencher = bencher.with_inputs(|| { + ( + rng().random(), + array::from_fn(|_| u64x2::from_array(rng().random())), + ) }); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher.bench_local_values(|(a, b)| { + block_multiplier::montgomery_square_log_interleaved_3(mode_guard, a, b) + }); + }); + } } - } - #[divan::bench] - fn montgomery_square_log_interleaved_4(bencher: Bencher) { - let bencher = bencher.with_inputs(|| { - ( - rng().random(), - rng().random(), - array::from_fn(|_| u64x2::from_array(rng().random())), - ) - }); - unsafe { - with_rounding_mode((), |mode_guard, _| { - bencher.bench_local_values(|(a, b, c)| { - block_multiplier::montgomery_square_log_interleaved_4(mode_guard, a, b, c) - }); + #[divan::bench] + fn montgomery_square_log_interleaved_4(bencher: Bencher) { + let bencher = bencher.with_inputs(|| { + ( + rng().random(), + rng().random(), + array::from_fn(|_| u64x2::from_array(rng().random())), + ) }); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher.bench_local_values(|(a, b, c)| { + block_multiplier::montgomery_square_log_interleaved_4(mode_guard, a, b, c) + }); + }); + } } #[divan::bench] @@ -189,25 +205,27 @@ mod sqr { }); } } - } - #[divan::bench] - fn simd_sqr(bencher: Bencher) { - bencher - //.counter(ItemsCount::new(2usize)) - .with_inputs(|| rng().random()) - .bench_local_values(|(a, b)| block_multiplier::simd_sqr(a, b)); - } + #[divan::bench] + fn simd_sqr(bencher: Bencher) { + bencher + //.counter(ItemsCount::new(2usize)) + .with_inputs(|| rng().random()) + .bench_local_values(|(a, b)| block_multiplier::simd_sqr(a, b)); + } - #[divan::bench] - fn block_sqr(bencher: Bencher) { - let bencher = bencher - //.counter(ItemsCount::new(3usize)) - .with_inputs(|| rng().random()); - unsafe { - with_rounding_mode((), |guard, _| { - bencher.bench_local_values(|(a, b, c)| block_multiplier::block_sqr(guard, a, b, c)); - }); + #[divan::bench] + fn block_sqr(bencher: Bencher) { + let bencher = bencher + //.counter(ItemsCount::new(3usize)) + .with_inputs(|| rng().random()); + unsafe { + with_rounding_mode((), |guard, _| { + bencher.bench_local_values(|(a, b, c)| { + block_multiplier::block_sqr(guard, a, b, c) + }); + }); + } } } } diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index fe54fa53..f18ad733 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -17,6 +17,7 @@ mod simd_utils; pub mod constants; mod scalar; +#[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; mod utils; diff --git a/skyscraper/block-multiplier/src/scalar.rs b/skyscraper/block-multiplier/src/scalar.rs index ff7250ec..93bb5c48 100644 --- a/skyscraper/block-multiplier/src/scalar.rs +++ b/skyscraper/block-multiplier/src/scalar.rs @@ -131,6 +131,7 @@ pub fn scalar_mul(a: [u64; 4], b: [u64; 4]) -> [u64; 4] { reduce_ct(subarray!(addv(s, mp), 1, 4)) } +#[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI #[cfg(test)] mod tests { use { diff --git a/tooling/provekit-bench/Cargo.toml b/tooling/provekit-bench/Cargo.toml index 5c6aaddc..b90f5c9a 100644 --- a/tooling/provekit-bench/Cargo.toml +++ b/tooling/provekit-bench/Cargo.toml @@ -34,4 +34,4 @@ workspace = true [[bench]] name = "bench" -harness = false \ No newline at end of file +harness = false From 4be79b37e61d9768eefecb9081ef4373e2492cb4 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 5 Jan 2026 16:11:50 +0800 Subject: [PATCH 02/37] divan: codspeed only on CI, use regular to build with WASI --- .cargo/config.toml | 6 ++++++ .github/workflows/benchmark.yml | 4 ++++ Cargo.toml | 4 +++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index e757e115..2aa77d57 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,9 @@ # This enables KaTex in docs, but requires running `cargo doc --no-deps`. [build] rustdocflags = "--html-in-header .cargo/katex-header.html" + +[target.wasm32-wasip2] +runner = "wasmtime run --dir . " + +[target.wasm32-wasip1] +runner = "wasmtime run --dir . " diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c9c4bf6a..a7a18c56 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -18,6 +18,10 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Replace divan with codspeed-divan-compat + run: | + sed -i 's/^divan = .*/divan = { package = "codspeed-divan-compat", version = "3.0.1" }/' Cargo.toml + - name: Setup Rust toolchain, cache and cargo-codspeed binary uses: moonrepo/setup-rust@v1 with: diff --git a/Cargo.toml b/Cargo.toml index 97664360..9c51196c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,7 +94,9 @@ axum = "0.8.4" base64 = "0.22.1" bytes = "1.10.1" chrono = "0.4.41" -divan = { package = "codspeed-divan-compat", version = "3.0.1" } +# On CI divan get replaced by divan = { package = "codspeed-divan-compat", version = "3.0.1" } for benchmark tracking. +# This is a workaround because different package selection based on target does not mix well with workspace dependencies. +divan = "0.1.21" hex = "0.4.3" itertools = "0.14.0" paste = "1.0.15" From 11b03662eaf471010eb3c8facb9231859ea78729 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 5 Jan 2026 16:55:14 +0800 Subject: [PATCH 03/37] block-multiplier: widening mul optimised for WASM --- skyscraper/block-multiplier/src/utils.rs | 29 ++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/skyscraper/block-multiplier/src/utils.rs b/skyscraper/block-multiplier/src/utils.rs index b4e92777..88a14022 100644 --- a/skyscraper/block-multiplier/src/utils.rs +++ b/skyscraper/block-multiplier/src/utils.rs @@ -68,7 +68,32 @@ pub fn sub(a: [u64; N], b: [u64; N]) -> [u64; N] { } #[inline(always)] -pub fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { - let c: u128 = a as u128 * b as u128 + carry as u128 + add as u128; +// Based on ark-ff +// On WASM first doing a widening on the operands will cause __multi3 called +// which is u128xu128 -> u128 causing unnecessary multiplications +pub const fn widening_mul(a: u64, b: u64) -> u128 { + #[cfg(not(target_family = "wasm"))] + { + a as u128 * b as u128 + } + #[cfg(target_family = "wasm")] + { + let a0 = a as u32 as u64; + let a1 = a >> 32; + let b0 = b as u32 as u64; + let b1 = b >> 32; + + let c00 = (a0 * b0) as u128; + let c01 = (a0 * b1) as u128; + let c10 = (a1 * b0) as u128; + let cxx = (c01 + c10) << 32; + let c11 = ((a1 * b1) as u128) << 64; + (c00 | c11) + cxx + } +} + +#[inline(always)] +pub const fn carrying_mul_add(a: u64, b: u64, add: u64, carry: u64) -> (u64, u64) { + let c: u128 = widening_mul(a, b) + carry as u128 + add as u128; (c as u64, (c >> 64) as u64) } From be45a0981d87fb51c2d0f3f0bb92022b1563ca27 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 15:10:02 +0800 Subject: [PATCH 04/37] wasi runners: enable relaxed simd --- .cargo/config.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 2aa77d57..1bcde2a1 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,7 +3,8 @@ rustdocflags = "--html-in-header .cargo/katex-header.html" [target.wasm32-wasip2] -runner = "wasmtime run --dir . " +rustflags = ["-C", "target-feature=+relaxed-simd"] [target.wasm32-wasip1] runner = "wasmtime run --dir . " +rustflags = ["-C", "target-feature=+relaxed-simd"] From 2d42c76cf41a4fe3b006aaaf6bd1e9eb6c6ddb2d Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 15:10:39 +0800 Subject: [PATCH 05/37] wasm: bench portable_simd on wasm --- skyscraper/block-multiplier/benches/bench.rs | 10 ++++++++++ skyscraper/block-multiplier/src/lib.rs | 3 +++ 2 files changed, 13 insertions(+) diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index bda9be3a..338a9446 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -31,6 +31,16 @@ mod mul { .bench_local_values(|(a, b)| a * b); } + #[divan::bench] + fn simd_mul(bencher: Bencher) { + bencher + //.counter(ItemsCount::new(2usize)) + .with_inputs(|| rng().random()) + .bench_local_values(|(a, b, c, d)| { + block_multiplier::portable_simd_wasm::simd_mul(a, b, c, d) + }); + } + #[cfg(target_arch = "aarch64")] mod aarch64 { use { diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index f18ad733..dbe70504 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -15,8 +15,11 @@ mod portable_simd; #[cfg(target_arch = "aarch64")] mod simd_utils; +// pub mod block_simd_wasm; pub mod constants; +pub mod portable_simd_wasm; mod scalar; +mod simd_utils_wasm; #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; mod utils; From 813b59270c714b61d4204e25ec72af12aff257bc Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 17:57:38 +0800 Subject: [PATCH 06/37] wasm: Add simd flags --- .cargo/config.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 1bcde2a1..262a07a0 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,8 +3,8 @@ rustdocflags = "--html-in-header .cargo/katex-header.html" [target.wasm32-wasip2] -rustflags = ["-C", "target-feature=+relaxed-simd"] +rustflags = ["-C", "target-feature=+simd128,+relaxed-simd"] [target.wasm32-wasip1] runner = "wasmtime run --dir . " -rustflags = ["-C", "target-feature=+relaxed-simd"] +rustflags = ["-C", "target-feature=+simd128,+relaxed-simd"] From 1a94a3e5a11913b8529f3c98c724644b5a535e74 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 17:57:46 +0800 Subject: [PATCH 07/37] wasm: Add test to portable_simd --- .../block-multiplier/src/portable_simd.rs | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/skyscraper/block-multiplier/src/portable_simd.rs b/skyscraper/block-multiplier/src/portable_simd.rs index 39ca34f2..5881d8bf 100644 --- a/skyscraper/block-multiplier/src/portable_simd.rs +++ b/skyscraper/block-multiplier/src/portable_simd.rs @@ -377,3 +377,36 @@ pub fn simd_mul( let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) } + +#[cfg(test)] +mod tests { + use { + super::*, + crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, + ark_bn254::Fr, + ark_ff::BigInt, + fp_rounding::{with_rounding_mode, Zero}, + proptest::proptest, + }; + + #[test] + fn test_simd_mul() { + proptest!(|( + a in safe_bn254_montgomery_input(), + b in safe_bn254_montgomery_input(), + c in safe_bn254_montgomery_input(), + )| { + unsafe { + with_rounding_mode((), |rtz : &fp_rounding::RoundingGuard, _| { + + let (ab, bc) = simd_mul(a, b, b,c); + let ab_ref = ark_ff_reference(a, b); + let bc_ref = ark_ff_reference(b, c); + let ab = Fr::new(BigInt(ab)); + let bc = Fr::new(BigInt(bc)); + assert_eq!(ab_ref, ab); + assert_eq!(bc_ref, bc); + });} + }); + } +} From 0143939936c796ab587da1583b14863239768cd4 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 17:59:43 +0800 Subject: [PATCH 08/37] wasm: add portable_simd_wasm --- .../src/portable_simd_wasm.rs | 411 ++++++++++++++++++ 1 file changed, 411 insertions(+) create mode 100644 skyscraper/block-multiplier/src/portable_simd_wasm.rs diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs new file mode 100644 index 00000000..35b7f18b --- /dev/null +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -0,0 +1,411 @@ +use { + crate::{ + constants::*, + simd_utils_wasm::{ + addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, + transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, + }, + }, + core::{ + ops::BitAnd, + simd::{num::SimdFloat, Simd}, + }, + std::simd::{num::SimdUint, StdFloat}, +}; + +#[inline] +pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = v0_a[0].cast(); + let bvj: Simd = v0_a[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1] += p_hi.to_bits(); + t[0] += p_lo.to_bits(); + let bvj: Simd = v0_a[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1] += p_hi.to_bits(); + t[1] += p_lo.to_bits(); + let bvj: Simd = v0_a[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1] += p_hi.to_bits(); + t[2] += p_lo.to_bits(); + let bvj: Simd = v0_a[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1] += p_hi.to_bits(); + t[3] += p_lo.to_bits(); + let bvj: Simd = v0_a[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1] += p_hi.to_bits(); + t[4] += p_lo.to_bits(); + + let avi: Simd = v0_a[1].cast(); + let bvj: Simd = v0_a[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1] += p_hi.to_bits(); + t[1] += p_lo.to_bits(); + let bvj: Simd = v0_a[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits(); + t[1 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_a[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits(); + t[1 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_a[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits(); + t[1 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_a[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits(); + t[1 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[2].cast(); + let bvj: Simd = v0_a[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1] += p_hi.to_bits(); + t[2] += p_lo.to_bits(); + let bvj: Simd = v0_a[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1 + 1] += p_hi.to_bits(); + t[2 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_a[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits(); + t[2 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_a[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits(); + t[2 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_a[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits(); + t[2 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[3].cast(); + let bvj: Simd = v0_a[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1] += p_hi.to_bits(); + t[3] += p_lo.to_bits(); + let bvj: Simd = v0_a[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1 + 1] += p_hi.to_bits(); + t[3 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_a[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 2 + 1] += p_hi.to_bits(); + t[3 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_a[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits(); + t[3 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_a[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits(); + t[3 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[4].cast(); + let bvj: Simd = v0_a[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1] += p_hi.to_bits(); + t[4] += p_lo.to_bits(); + let bvj: Simd = v0_a[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1 + 1] += p_hi.to_bits(); + t[4 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_a[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 2 + 1] += p_hi.to_bits(); + t[4 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_a[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 3 + 1] += p_hi.to_bits(); + t[4 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_a[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits(); + t[4 + 4] += p_lo.to_bits(); + + t[1] += t[0] >> 52; + t[2] += t[1] >> 52; + t[3] += t[2] >> 52; + t[4] += t[3] >> 52; + + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); + let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let mp = smult_noinit_simd(m, U52_P); + + let reduced = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = u260_to_u256_simd(reduced); + let v = transpose_simd_to_u256(u256_result); + (v[0], v[1]) +} + +#[inline] +pub fn simd_mul( + v0_a: [u64; 4], + v0_b: [u64; 4], + v1_a: [u64; 4], + v1_b: [u64; 4], +) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); + let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = v0_a[0].cast(); + let bvj: Simd = v0_b[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1] += p_hi.to_bits(); + t[0] += p_lo.to_bits(); + let bvj: Simd = v0_b[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1] += p_hi.to_bits(); + t[1] += p_lo.to_bits(); + let bvj: Simd = v0_b[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1] += p_hi.to_bits(); + t[2] += p_lo.to_bits(); + let bvj: Simd = v0_b[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1] += p_hi.to_bits(); + t[3] += p_lo.to_bits(); + let bvj: Simd = v0_b[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1] += p_hi.to_bits(); + t[4] += p_lo.to_bits(); + + let avi: Simd = v0_a[1].cast(); + let bvj: Simd = v0_b[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1] += p_hi.to_bits(); + t[1] += p_lo.to_bits(); + let bvj: Simd = v0_b[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits(); + t[1 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_b[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits(); + t[1 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_b[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits(); + t[1 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_b[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits(); + t[1 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[2].cast(); + let bvj: Simd = v0_b[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1] += p_hi.to_bits(); + t[2] += p_lo.to_bits(); + let bvj: Simd = v0_b[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 1 + 1] += p_hi.to_bits(); + t[2 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_b[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits(); + t[2 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_b[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits(); + t[2 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_b[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits(); + t[2 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[3].cast(); + let bvj: Simd = v0_b[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1] += p_hi.to_bits(); + t[3] += p_lo.to_bits(); + let bvj: Simd = v0_b[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 1 + 1] += p_hi.to_bits(); + t[3 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_b[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 2 + 1] += p_hi.to_bits(); + t[3 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_b[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits(); + t[3 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_b[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits(); + t[3 + 4] += p_lo.to_bits(); + + let avi: Simd = v0_a[4].cast(); + let bvj: Simd = v0_b[0].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1] += p_hi.to_bits(); + t[4] += p_lo.to_bits(); + let bvj: Simd = v0_b[1].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 1 + 1] += p_hi.to_bits(); + t[4 + 1] += p_lo.to_bits(); + let bvj: Simd = v0_b[2].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 2 + 1] += p_hi.to_bits(); + t[4 + 2] += p_lo.to_bits(); + let bvj: Simd = v0_b[3].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 3 + 1] += p_hi.to_bits(); + t[4 + 3] += p_lo.to_bits(); + let bvj: Simd = v0_b[4].cast(); + let p_hi = avi.mul_add(bvj, Simd::splat(C1)); + let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits(); + t[4 + 4] += p_lo.to_bits(); + + t[1] += t[0] >> 52; + t[2] += t[1] >> 52; + t[3] += t[2] >> 52; + t[4] += t[3] >> 52; + + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); + let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let mp = smult_noinit_simd(m, U52_P); + + let reduced = reduce_ct_simd(addv_simd(s, mp)); + let u256_result = u260_to_u256_simd(reduced); + let v = transpose_simd_to_u256(u256_result); + (v[0], v[1]) +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, + ark_bn254::Fr, + ark_ff::BigInt, + fp_rounding::{with_rounding_mode, Zero}, + proptest::proptest, + }; + + #[test] + fn test_simd_mul() { + proptest!(|( + a in safe_bn254_montgomery_input(), + b in safe_bn254_montgomery_input(), + c in safe_bn254_montgomery_input(), + )| { + unsafe { + with_rounding_mode((), |rtz : &fp_rounding::RoundingGuard, _| { + + let (ab, bc) = simd_mul(a, b, b,c); + let ab_ref = ark_ff_reference(a, b); + let bc_ref = ark_ff_reference(b, c); + let ab = Fr::new(BigInt(ab)); + let bc = Fr::new(BigInt(bc)); + assert_eq!(ab_ref, ab); + assert_eq!(bc_ref, bc); + });} + }); + } +} From ceee4a2397123d8d4aa0029d5eaf897890aa5978 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 6 Jan 2026 19:45:47 +0800 Subject: [PATCH 09/37] wasm: optimising 52 bit - not final --- .../src/portable_simd_wasm.rs | 346 +++++------------- .../block-multiplier/src/simd_utils_wasm.rs | 158 ++++++++ 2 files changed, 242 insertions(+), 262 deletions(-) create mode 100644 skyscraper/block-multiplier/src/simd_utils_wasm.rs diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 35b7f18b..6283d00e 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -2,196 +2,17 @@ use { crate::{ constants::*, simd_utils_wasm::{ - addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, - transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, + addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, + transpose_simd_to_u256, transpose_u256_to_simd, u256_to_u260_shl2_simd, + u260_to_u256_simd, }, }, core::{ ops::BitAnd, simd::{num::SimdFloat, Simd}, }, - std::simd::{num::SimdUint, StdFloat}, }; -#[inline] -pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { - let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); - - let mut t: [Simd; 10] = [Simd::splat(0); 10]; - t[0] = Simd::splat(make_initial(1, 0)); - t[9] = Simd::splat(make_initial(0, 6)); - t[1] = Simd::splat(make_initial(2, 1)); - t[8] = Simd::splat(make_initial(6, 7)); - t[2] = Simd::splat(make_initial(3, 2)); - t[7] = Simd::splat(make_initial(7, 8)); - t[3] = Simd::splat(make_initial(4, 3)); - t[6] = Simd::splat(make_initial(8, 9)); - t[4] = Simd::splat(make_initial(10, 4)); - t[5] = Simd::splat(make_initial(9, 10)); - - let avi: Simd = v0_a[0].cast(); - let bvj: Simd = v0_a[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1] += p_hi.to_bits(); - t[0] += p_lo.to_bits(); - let bvj: Simd = v0_a[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits(); - t[1] += p_lo.to_bits(); - let bvj: Simd = v0_a[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits(); - t[2] += p_lo.to_bits(); - let bvj: Simd = v0_a[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits(); - t[3] += p_lo.to_bits(); - let bvj: Simd = v0_a[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits(); - t[4] += p_lo.to_bits(); - - let avi: Simd = v0_a[1].cast(); - let bvj: Simd = v0_a[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits(); - t[1] += p_lo.to_bits(); - let bvj: Simd = v0_a[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 1 + 1] += p_hi.to_bits(); - t[1 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_a[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 2 + 1] += p_hi.to_bits(); - t[1 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_a[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 3 + 1] += p_hi.to_bits(); - t[1 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_a[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[1 + 4 + 1] += p_hi.to_bits(); - t[1 + 4] += p_lo.to_bits(); - - let avi: Simd = v0_a[2].cast(); - let bvj: Simd = v0_a[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits(); - t[2] += p_lo.to_bits(); - let bvj: Simd = v0_a[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 1 + 1] += p_hi.to_bits(); - t[2 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_a[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 2 + 1] += p_hi.to_bits(); - t[2 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_a[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 3 + 1] += p_hi.to_bits(); - t[2 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_a[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[2 + 4 + 1] += p_hi.to_bits(); - t[2 + 4] += p_lo.to_bits(); - - let avi: Simd = v0_a[3].cast(); - let bvj: Simd = v0_a[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits(); - t[3] += p_lo.to_bits(); - let bvj: Simd = v0_a[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 1 + 1] += p_hi.to_bits(); - t[3 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_a[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 2 + 1] += p_hi.to_bits(); - t[3 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_a[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 3 + 1] += p_hi.to_bits(); - t[3 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_a[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[3 + 4 + 1] += p_hi.to_bits(); - t[3 + 4] += p_lo.to_bits(); - - let avi: Simd = v0_a[4].cast(); - let bvj: Simd = v0_a[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits(); - t[4] += p_lo.to_bits(); - let bvj: Simd = v0_a[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 1 + 1] += p_hi.to_bits(); - t[4 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_a[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 2 + 1] += p_hi.to_bits(); - t[4 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_a[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 3 + 1] += p_hi.to_bits(); - t[4 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_a[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); - t[4 + 4 + 1] += p_hi.to_bits(); - t[4 + 4] += p_lo.to_bits(); - - t[1] += t[0] >> 52; - t[2] += t[1] >> 52; - t[3] += t[2] >> 52; - t[4] += t[3] >> 52; - - let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); - let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); - let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); - let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); - - let s = [ - r0[0] + r1[0] + r2[0] + r3[0] + t[4], - r0[1] + r1[1] + r2[1] + r3[1] + t[5], - r0[2] + r1[2] + r2[2] + r3[2] + t[6], - r0[3] + r1[3] + r2[3] + r3[3] + t[7], - r0[4] + r1[4] + r2[4] + r3[4] + t[8], - r0[5] + r1[5] + r2[5] + r3[5] + t[9], - ]; - - let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); - let mp = smult_noinit_simd(m, U52_P); - - let reduced = reduce_ct_simd(addv_simd(s, mp)); - let u256_result = u260_to_u256_simd(reduced); - let v = transpose_simd_to_u256(u256_result); - (v[0], v[1]) -} - #[inline] pub fn simd_mul( v0_a: [u64; 4], @@ -214,138 +35,138 @@ pub fn simd_mul( t[4] = Simd::splat(make_initial(10, 4)); t[5] = Simd::splat(make_initial(9, 10)); - let avi: Simd = v0_a[0].cast(); - let bvj: Simd = v0_b[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let avi: Simd = i2f(v0_a[0]); + let bvj: Simd = i2f(v0_b[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1] += p_hi.to_bits(); t[0] += p_lo.to_bits(); - let bvj: Simd = v0_b[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 1] += p_hi.to_bits(); t[1] += p_lo.to_bits(); - let bvj: Simd = v0_b[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 1] += p_hi.to_bits(); t[2] += p_lo.to_bits(); - let bvj: Simd = v0_b[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 1] += p_hi.to_bits(); t[3] += p_lo.to_bits(); - let bvj: Simd = v0_b[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 1] += p_hi.to_bits(); t[4] += p_lo.to_bits(); - let avi: Simd = v0_a[1].cast(); - let bvj: Simd = v0_b[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let avi: Simd = i2f(v0_a[1]); + let bvj: Simd = i2f(v0_b[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 1] += p_hi.to_bits(); t[1] += p_lo.to_bits(); - let bvj: Simd = v0_b[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 1 + 1] += p_hi.to_bits(); t[1 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_b[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 2 + 1] += p_hi.to_bits(); t[1 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_b[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 3 + 1] += p_hi.to_bits(); t[1 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_b[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[1 + 4 + 1] += p_hi.to_bits(); t[1 + 4] += p_lo.to_bits(); - let avi: Simd = v0_a[2].cast(); - let bvj: Simd = v0_b[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let avi: Simd = i2f(v0_a[2]); + let bvj: Simd = i2f(v0_b[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 1] += p_hi.to_bits(); t[2] += p_lo.to_bits(); - let bvj: Simd = v0_b[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 1 + 1] += p_hi.to_bits(); t[2 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_b[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 2 + 1] += p_hi.to_bits(); t[2 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_b[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 3 + 1] += p_hi.to_bits(); t[2 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_b[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[2 + 4 + 1] += p_hi.to_bits(); t[2 + 4] += p_lo.to_bits(); - let avi: Simd = v0_a[3].cast(); - let bvj: Simd = v0_b[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let avi: Simd = i2f(v0_a[3]); + let bvj: Simd = i2f(v0_b[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 1] += p_hi.to_bits(); t[3] += p_lo.to_bits(); - let bvj: Simd = v0_b[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 1 + 1] += p_hi.to_bits(); t[3 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_b[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 2 + 1] += p_hi.to_bits(); t[3 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_b[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 3 + 1] += p_hi.to_bits(); t[3 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_b[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[3 + 4 + 1] += p_hi.to_bits(); t[3 + 4] += p_lo.to_bits(); - let avi: Simd = v0_a[4].cast(); - let bvj: Simd = v0_b[0].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let avi: Simd = i2f(v0_a[4]); + let bvj: Simd = i2f(v0_b[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 1] += p_hi.to_bits(); t[4] += p_lo.to_bits(); - let bvj: Simd = v0_b[1].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 1 + 1] += p_hi.to_bits(); t[4 + 1] += p_lo.to_bits(); - let bvj: Simd = v0_b[2].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 2 + 1] += p_hi.to_bits(); t[4 + 2] += p_lo.to_bits(); - let bvj: Simd = v0_b[3].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 3 + 1] += p_hi.to_bits(); t[4 + 3] += p_lo.to_bits(); - let bvj: Simd = v0_b[4].cast(); - let p_hi = avi.mul_add(bvj, Simd::splat(C1)); - let p_lo = avi.mul_add(bvj, Simd::splat(C2) - p_hi); + let bvj: Simd = i2f(v0_b[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); t[4 + 4 + 1] += p_hi.to_bits(); t[4 + 4] += p_lo.to_bits(); @@ -377,6 +198,7 @@ pub fn simd_mul( (v[0], v[1]) } +#[cfg(not(target_arch = "wasm32"))] #[cfg(test)] mod tests { use { diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs new file mode 100644 index 00000000..eade332a --- /dev/null +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -0,0 +1,158 @@ +use { + crate::constants::{C1, C2, MASK52, U52_2P}, + core::{ + array, + ops::BitAnd, + simd::{ + cmp::SimdPartialEq, + num::{SimdFloat, SimdInt, SimdUint}, + Simd, + }, + }, +}; + +// -- [SIMD UTILS] +// --------------------------------------------------------------------------------- +#[inline(always)] +// 52 bit conversion does not have to go through and expensive +pub fn i2f(a: Simd) -> Simd { + unsafe { core::mem::transmute(a) } + // TODO: add the addition for proper conversion +} + +#[inline(always)] +pub fn fma(a: Simd, b: Simd, c: Simd) -> Simd { + #[cfg(not(target_arch = "wasm32"))] + { + use std::simd::StdFloat; + + a.mul_add(b, c) + } + #[cfg(target_arch = "wasm32")] + { + use core::arch::wasm32::*; + f64x2_relaxed_madd(a.into(), b.into(), c.into()).into() + } +} + +#[inline(always)] +pub const fn make_initial(low_count: usize, high_count: usize) -> u64 { + let val = high_count * 0x467 + low_count * 0x433; + -((val as i64 & 0xfff) << 52) as u64 +} + +#[inline(always)] +pub fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { + // This does not issue multiple ldp and zip which might be marginally faster. + [ + Simd::from_array([limbs[0][0], limbs[1][0]]), + Simd::from_array([limbs[0][1], limbs[1][1]]), + Simd::from_array([limbs[0][2], limbs[1][2]]), + Simd::from_array([limbs[0][3], limbs[1][3]]), + ] +} + +#[inline(always)] +pub fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { + let tmp0 = limbs[0].to_array(); + let tmp1 = limbs[1].to_array(); + let tmp2 = limbs[2].to_array(); + let tmp3 = limbs[3].to_array(); + [[tmp0[0], tmp1[0], tmp2[0], tmp3[0]], [ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], + ]] +} + +#[inline(always)] +pub fn u256_to_u260_shl2_simd(limbs: [Simd; 4]) -> [Simd; 5] { + let [l0, l1, l2, l3] = limbs; + [ + (l0 << 2) & Simd::splat(MASK52), + ((l0 >> 50) | (l1 << 14)) & Simd::splat(MASK52), + ((l1 >> 38) | (l2 << 26)) & Simd::splat(MASK52), + ((l2 >> 26) | (l3 << 38)) & Simd::splat(MASK52), + l3 >> 14, + ] +} + +#[inline(always)] +pub fn u260_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] { + let [l0, l1, l2, l3, l4] = limbs; + [ + l0 | (l1 << 52), + (l1 >> 12) | (l2 << 40), + (l2 >> 24) | (l3 << 28), + (l3 >> 36) | (l4 << 16), + ] +} + +#[inline(always)] +pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { + let mut t = [Simd::splat(0); 6]; + let s: Simd = i2f(s); + + let p_hi_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C1)); + let p_lo_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C2) - p_hi_0); + t[1] += p_hi_0.to_bits(); + t[0] += p_lo_0.to_bits(); + + let p_hi_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C1)); + let p_lo_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C2) - p_hi_1); + t[2] += p_hi_1.to_bits(); + t[1] += p_lo_1.to_bits(); + + let p_hi_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C1)); + let p_lo_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C2) - p_hi_2); + t[3] += p_hi_2.to_bits(); + t[2] += p_lo_2.to_bits(); + + let p_hi_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C1)); + let p_lo_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C2) - p_hi_3); + t[4] += p_hi_3.to_bits(); + t[3] += p_lo_3.to_bits(); + + let p_hi_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C1)); + let p_lo_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C2) - p_hi_4); + t[5] += p_hi_4.to_bits(); + t[4] += p_lo_4.to_bits(); + + t +} + +#[inline(always)] +/// Resolve the carry bits in the upper parts 12b and reduce the result to +/// within < 3p +pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { + // The lowest limb contains carries that still need to be applied. + let mut borrow: Simd = (red[0] >> 52).cast(); + let a = [red[1], red[2], red[3], red[4], red[5]]; + + // To reduce Check whether the most significant bit is set + let mask = (a[4] >> 47).bitand(Simd::splat(1)).simd_eq(Simd::splat(0)); + + // Select values based on the mask: if mask lane is true, use zeros, else use + // U52_2P + let zeros = [Simd::splat(0); 5]; + let twop = U52_2P.map(Simd::splat); + let b: [_; 5] = array::from_fn(|i| mask.select(zeros[i], twop[i])); + + let mut c = [Simd::splat(0); 5]; + for i in 0..c.len() { + let tmp: Simd = a[i].cast::() - b[i].cast() + borrow; + c[i] = tmp.cast().bitand(Simd::splat(MASK52)); + borrow = tmp >> 52 + } + + c +} + +#[inline(always)] +pub fn addv_simd( + mut va: [Simd; N], + vb: [Simd; N], +) -> [Simd; N] { + for i in 0..va.len() { + va[i] += vb[i]; + } + va +} From 493367c63234153538efc46d9cb6d219be56f837 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 7 Jan 2026 12:57:55 +0800 Subject: [PATCH 10/37] wasm: optimised 52/51-bit integer-to-float conversion --- skyscraper/block-multiplier/src/lib.rs | 2 +- .../block-multiplier/src/simd_utils_wasm.rs | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index dbe70504..7fea383e 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -19,7 +19,7 @@ mod simd_utils; pub mod constants; pub mod portable_simd_wasm; mod scalar; -mod simd_utils_wasm; +pub mod simd_utils_wasm; #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; mod utils; diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index eade332a..bc620bb6 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -14,10 +14,21 @@ use { // -- [SIMD UTILS] // --------------------------------------------------------------------------------- #[inline(always)] -// 52 bit conversion does not have to go through and expensive +/// On WASSM there is no single specialised instruction to cast an integer to a +/// float. Since we are only interested in 52 bits, we can emulate it with fewer +/// instructions. pub fn i2f(a: Simd) -> Simd { - unsafe { core::mem::transmute(a) } - // TODO: add the addition for proper conversion + // This function has not target gating as we want to verify this function with + // kani and proptest on a different platform than wasm + + // By adding 2^52 represented as float (0x1p52) -> 0x433 << 52, we align the + // 52bit number fully in the mantissa. This can be done with a simple or. Then + // to convert a to it's floating point number we subtract this again. This way + // we only pay for the conversion of the lower bits and not the full 64 bits. + let exponent = Simd::splat(0x433 << 52); + let a: Simd = unsafe { core::mem::transmute(a | exponent) }; + let b: Simd = unsafe { core::mem::transmute(exponent) }; + a - b } #[inline(always)] From 74cc61cb7b6d9c5849514cb7277fd7c772d2a705 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 12 Jan 2026 10:13:27 +0800 Subject: [PATCH 11/37] b51: add constants --- skyscraper/block-multiplier/src/constants.rs | 2 ++ skyscraper/block-multiplier/src/simd_utils_wasm.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/skyscraper/block-multiplier/src/constants.rs b/skyscraper/block-multiplier/src/constants.rs index 171273f5..f9b8d82b 100644 --- a/skyscraper/block-multiplier/src/constants.rs +++ b/skyscraper/block-multiplier/src/constants.rs @@ -133,6 +133,8 @@ pub const C1: f64 = pow_2(104); // 2.0^104 pub const C2: f64 = pow_2(104) + pow_2(52); // 2.0^104 + 2.0^52 // const C3: f64 = pow_2(52); // 2.0^52 // ------------------------------------------------------------------------------------------------- +pub const C1F51: f64 = pow_2(103); +pub const C2F51: f64 = pow_2(103) + pow_2(52) + pow_2(51); const fn pow_2(n: u32) -> f64 { // Unfortunately we can't use f64::powi in const fn yet diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index bc620bb6..aba10796 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -49,7 +49,7 @@ pub fn fma(a: Simd, b: Simd, c: Simd) -> Simd { #[inline(always)] pub const fn make_initial(low_count: usize, high_count: usize) -> u64 { let val = high_count * 0x467 + low_count * 0x433; - -((val as i64 & 0xfff) << 52) as u64 + -((val as i64) << 52) as u64 } #[inline(always)] From 9500a7b5f3cc88ca044a94be0062cb8c8f9106a3 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 20 Jan 2026 10:19:51 +0800 Subject: [PATCH 12/37] Montgomery table: use correct prime and add 51bit --- .../src/aarch64/generate_montgomery_table.py | 146 ++++++++++++------ 1 file changed, 102 insertions(+), 44 deletions(-) diff --git a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py index bf8d78d3..2e3b2695 100644 --- a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py +++ b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py @@ -1,19 +1,21 @@ -p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +from math import log2 + +p = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001 U52_i1 = [ - 0x82e644ee4c3d2, - 0xf93893c98b1de, - 0xd46fe04d0a4c7, - 0x8f0aad55e2a1f, - 0x005ed0447de83, + 0x82E644EE4C3D2, + 0xF93893C98B1DE, + 0xD46FE04D0A4C7, + 0x8F0AAD55E2A1F, + 0x005ED0447DE83, ] U52_i2 = [ - 0x74eccce9a797a, - 0x16ddcc30bd8a4, - 0x49ecd3539499e, - 0xb23a6fcc592b8, - 0x00e3bd49f6ee5, + 0x74ECCCE9A797A, + 0x16DDCC30BD8A4, + 0x49ECD3539499E, + 0xB23A6FCC592B8, + 0x00E3BD49F6EE5, ] U52_i3 = [ @@ -33,17 +35,17 @@ ] U64_I1 = [ - 0x2d3e8053e396ee4d, - 0xca478dbeab3c92cd, - 0xb2d8f06f77f52a93, - 0x24d6ba07f7aa8f04, + 0x2D3E8053E396EE4D, + 0xCA478DBEAB3C92CD, + 0xB2D8F06F77F52A93, + 0x24D6BA07F7AA8F04, ] U64_I2 = [ - 0x18ee753c76f9dc6f, - 0x54ad7e14a329e70f, - 0x2b16366f4f7684df, - 0x133100d71fdf3579, + 0x18EE753C76F9DC6F, + 0x54AD7E14A329E70F, + 0x2B16366F4F7684DF, + 0x133100D71FDF3579, ] U64_I3 = [ @@ -53,13 +55,37 @@ 0x2B062AAA49F80C7D, ] + +U51_i1 = pow( + 2**51, + -1, + 21888242871839275222246405745257275088548364400416034343698204186575808495617, +) +U51_i2 = pow( + 2**51, + -2, + 21888242871839275222246405745257275088548364400416034343698204186575808495617, +) +U51_i3 = pow( + 2**51, + -3, + 21888242871839275222246405745257275088548364400416034343698204186575808495617, +) +U51_i4 = pow( + 2**51, + -4, + 21888242871839275222246405745257275088548364400416034343698204186575808495617, +) + + def limbs_to_int(size, xs): total = 0 - for (i, x) in enumerate(xs): - total += x << (size*i) + for i, x in enumerate(xs): + total += x << (size * i) return total + u64_i1 = limbs_to_int(64, U64_I1) u64_i2 = limbs_to_int(64, U64_I2) u64_i3 = limbs_to_int(64, U64_I3) @@ -69,44 +95,76 @@ def limbs_to_int(size, xs): u52_i3 = limbs_to_int(52, U52_i3) u52_i4 = limbs_to_int(52, U52_i4) - -def log_jump(single_input_bound): +def log_jump(single_input_bound): product_bound = single_input_bound**2 - first_round = (product_bound>>2*64) + u64_i2 * (2**128-1) - second_round = (first_round >> 64) + u64_i1 * (2**64-1) - mont_round = second_round + p*(2**64-1) + first_round = (product_bound >> 2 * 64) + u64_i2 * (2**128 - 1) + second_round = (first_round >> 64) + u64_i1 * (2**64 - 1) + mont_round = second_round + p * (2**64 - 1) final = mont_round >> 64 return final -def single_step(single_input_bound): + +def single_step(single_input_bound): product_bound = single_input_bound**2 - first_round = (product_bound>>3*64) + (u64_i3 + u64_i2 + u64_i1) * (2**64-1) - mont_round = first_round + p*(2**64-1) + first_round = (product_bound >> 3 * 64) + (u64_i3 + u64_i2 + u64_i1) * (2**64 - 1) + mont_round = first_round + p * (2**64 - 1) final = mont_round >> 64 + # print(log2(final)) + return final -def single_step_simd(single_input_bound): - product_bound = (single_input_bound<<2)**2 - first_round = (product_bound>>4*52) + (u52_i4 + u52_i3 + u52_i2 + u52_i1) * (2**52-1) - mont_round = first_round + p*(2**52-1) +def single_step_simd(single_input_bound): + product_bound = (single_input_bound << 2) ** 2 + + first_round = (product_bound >> 4 * 52) + (u52_i4 + u52_i3 + u52_i2 + u52_i1) * ( + 2**52 - 1 + ) + mont_round = first_round + p * (2**52 - 1) final = mont_round >> 52 + # print(log2(final)) return final + +def single_step_simd_wasm(single_input_bound): + product_bound = (single_input_bound) ** 2 + + first_round = (product_bound >> 4 * 51) + (U51_i1 + U51_i2 + U51_i3 + U51_i4) * ( + 2**51 - 1 + ) + mont_round = first_round + p * (2**51 - 1) + final = mont_round >> 51 + # print(log2(final)) + # print(log2(final + p)) + + reduced = (final + p) >> 1 if final & 1 else final >> 1 + # print(log2(reduced)) + return reduced + + if __name__ == "__main__": # Test bounds for different input sizes - test_bounds = [("p", p),("2p", 2*p), ("3p", 3*p), ("2ˆ256-2p",2**256-2*p)] - print("Input Size | single_step | single_step_simd | log_jump") - print("-----------|-------------|------------------|---------") + test_bounds = [ + ("p", p), + ("2p", 2 * p), + ("2ˆ255", 2**255), + ("3p", 3 * p), + ("2ˆ256-2p", 2**256 - 2 * p), + ] + print("Input Size | single_step | single_step_simd | log_jump| single_step_wasm ") + print("-----------|-------------|------------------|---------|-----------------|") for name, bound in test_bounds: - single = single_step(bound)/p - simd = single_step_simd(bound)/p - log = log_jump(bound)/p - single_space = (2**256-1-single_step(bound))/p - simd_space = (2**256-1-single_step_simd(bound))/p - log_space = (2**256-1-log_jump(bound))/p - print(f"{name:10} | {single:4.2f} [{single_space:4.2f}] | {simd:7.2f} [{simd_space:.4f}] | {log:4.2f} [{log_space:.2f}]") - + single = single_step(bound) / p + simd = single_step_simd(bound) / p + simd_wasm = single_step_simd_wasm(bound) / p + log = log_jump(bound) / p + single_space = (2**256 - 1 - single_step(bound)) / p + simd_space = (2**256 - 1 - single_step_simd(bound)) / p + simd_wasm_space = (2**256 - 1 - single_step_simd_wasm(bound)) / p + log_space = (2**256 - 1 - log_jump(bound)) / p + print( + f"{name:10} | {single:4.2f} [{single_space:4.2f}] | {simd:7.2f} [{simd_space:.4f}] | {log:4.2f} [{log_space:.2f}] | {simd_wasm:4.2f} [{simd_wasm_space:.2f}]" + ) From f309c499a93eb1bf8e6a43190efe70fb90e68cd0 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 20 Jan 2026 10:21:04 +0800 Subject: [PATCH 13/37] start 51 bit conversion --- .../block-multiplier/src/constants_wasm.rs | 148 ++++++++++++++++++ skyscraper/block-multiplier/src/lib.rs | 1 + .../src/portable_simd_wasm.rs | 25 ++- .../block-multiplier/src/simd_utils_wasm.rs | 26 +-- 4 files changed, 172 insertions(+), 28 deletions(-) create mode 100644 skyscraper/block-multiplier/src/constants_wasm.rs diff --git a/skyscraper/block-multiplier/src/constants_wasm.rs b/skyscraper/block-multiplier/src/constants_wasm.rs new file mode 100644 index 00000000..54a3084a --- /dev/null +++ b/skyscraper/block-multiplier/src/constants_wasm.rs @@ -0,0 +1,148 @@ +pub const U64_NP0: u64 = 0xc2e1f593efffffff; + +pub const U64_P: [u64; 4] = [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, +]; + +pub const U64_2P: [u64; 4] = [ + 0x87c3eb27e0000002, + 0x5067d090f372e122, + 0x70a08b6d0302b0ba, + 0x60c89ce5c2634053, +]; + +// R mod P +pub const U64_R: [u64; 4] = [ + 0xac96341c4ffffffb, + 0x36fc76959f60cd29, + 0x666ea36f7879462e, + 0x0e0a77c19a07df2f, +]; + +// R^2 mod P +pub const U64_R2: [u64; 4] = [ + 0x1bb8e645ae216da7, + 0x53fe3ab1e35c59e3, + 0x8c49833d53bb8085, + 0x0216d0b17f4e44a5, +]; + +// R^-1 mod P +pub const U64_R_INV: [u64; 4] = [ + 0xdc5ba0056db1194e, + 0x090ef5a9e111ec87, + 0xc8260de4aeb85d5d, + 0x15ebf95182c5551c, +]; + +pub const U52_NP0: u64 = 0x1f593efffffff; +pub const U52_R2: [u64; 5] = [ + 0x0b852d16da6f5, + 0xc621620cddce3, + 0xaf1b95343ffb6, + 0xc3c15e103e7c2, + 0x00281528fa122, +]; + +pub const U52_P: [u64; 5] = [ + 0x1f593f0000001, + 0x4879b9709143e, + 0x181585d2833e8, + 0xa029b85045b68, + 0x030644e72e131, +]; + +pub const U52_2P: [u64; 5] = [ + 0x3eb27e0000002, + 0x90f372e12287c, + 0x302b0ba5067d0, + 0x405370a08b6d0, + 0x060c89ce5c263, +]; + +pub const F52_P: [f64; 5] = [ + 0x1f593f0000001_u64 as f64, + 0x4879b9709143e_u64 as f64, + 0x181585d2833e8_u64 as f64, + 0xa029b85045b68_u64 as f64, + 0x030644e72e131_u64 as f64, +]; + +pub const MASK51: u64 = 2_u64.pow(51) - 1; + +pub const U64_I1: [u64; 4] = [ + 0x2d3e8053e396ee4d, + 0xca478dbeab3c92cd, + 0xb2d8f06f77f52a93, + 0x24d6ba07f7aa8f04, +]; +pub const U64_I2: [u64; 4] = [ + 0x18ee753c76f9dc6f, + 0x54ad7e14a329e70f, + 0x2b16366f4f7684df, + 0x133100d71fdf3579, +]; + +pub const U64_I3: [u64; 4] = [ + 0x9bacb016127cbe4e, + 0x0b2051fa31944124, + 0xb064eea46091c76c, + 0x2b062aaa49f80c7d, +]; +pub const U64_MU0: u64 = 0xc2e1f593efffffff; + +// -- [FP SIMD CONSTANTS] +// -------------------------------------------------------------------------- +pub const RHO_1: [u64; 5] = [ + 0x82e644ee4c3d2, + 0xf93893c98b1de, + 0xd46fe04d0a4c7, + 0x8f0aad55e2a1f, + 0x005ed0447de83, +]; + +pub const RHO_2: [u64; 5] = [ + 0x74eccce9a797a, + 0x16ddcc30bd8a4, + 0x49ecd3539499e, + 0xb23a6fcc592b8, + 0x00e3bd49f6ee5, +]; + +pub const RHO_3: [u64; 5] = [ + 0x0e8c656567d77, + 0x430d05713ae61, + 0xea3ba6b167128, + 0xa7dae55c5a296, + 0x01b4afd513572, +]; + +pub const RHO_4: [u64; 5] = [ + 0x22e2400e2f27d, + 0x323b46ea19686, + 0xe6c43f0df672d, + 0x7824014c39e8b, + 0x00c6b48afe1b8, +]; + +pub const C1: f64 = pow_2(103); +pub const C2: f64 = pow_2(103) + pow_2(52) + pow_2(51); + +const fn pow_2(n: u32) -> f64 { + // Unfortunately we can't use f64::powi in const fn yet + // This is a workaround that creates the bit pattern directly + let exp = ((n as u64 + 1023) & 0x7ff) << 52; + f64::from_bits(exp) +} + +// BOUNDS +/// Upper bound of 2**256-2p +pub const OUTPUT_MAX: [u64; 4] = [ + 0x783c14d81ffffffe, + 0xaf982f6f0c8d1edd, + 0x8f5f7492fcfd4f45, + 0x9f37631a3d9cbfac, +]; diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index 7fea383e..b1a19da3 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -17,6 +17,7 @@ mod simd_utils; // pub mod block_simd_wasm; pub mod constants; +pub mod constants_wasm; pub mod portable_simd_wasm; mod scalar; pub mod simd_utils_wasm; diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 6283d00e..0825afd6 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -1,10 +1,9 @@ use { crate::{ - constants::*, + constants_wasm::*, simd_utils_wasm::{ addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, - transpose_simd_to_u256, transpose_u256_to_simd, u256_to_u260_shl2_simd, - u260_to_u256_simd, + transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_simd, u256_to_u255_simd, }, }, core::{ @@ -20,8 +19,8 @@ pub fn simd_mul( v1_a: [u64; 4], v1_b: [u64; 4], ) -> ([u64; 4], [u64; 4]) { - let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); - let v0_b = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_b, v1_b])); + let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); + let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); let mut t: [Simd; 10] = [Simd::splat(0); 10]; t[0] = Simd::splat(make_initial(1, 0)); @@ -175,10 +174,10 @@ pub fn simd_mul( t[3] += t[2] >> 52; t[4] += t[3] >> 52; - let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK52)), RHO_4); - let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK52)), RHO_3); - let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK52)), RHO_2); - let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK52)), RHO_1); + let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK51)), RHO_4); + let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK51)), RHO_3); + let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK51)), RHO_2); + let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK51)), RHO_1); let s = [ r0[0] + r1[0] + r2[0] + r3[0] + t[4], @@ -189,11 +188,11 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK52)); + let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U52_P); let reduced = reduce_ct_simd(addv_simd(s, mp)); - let u256_result = u260_to_u256_simd(reduced); + let u256_result = u255_to_u256_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) } @@ -206,7 +205,6 @@ mod tests { crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, ark_bn254::Fr, ark_ff::BigInt, - fp_rounding::{with_rounding_mode, Zero}, proptest::proptest, }; @@ -217,8 +215,6 @@ mod tests { b in safe_bn254_montgomery_input(), c in safe_bn254_montgomery_input(), )| { - unsafe { - with_rounding_mode((), |rtz : &fp_rounding::RoundingGuard, _| { let (ab, bc) = simd_mul(a, b, b,c); let ab_ref = ark_ff_reference(a, b); @@ -227,7 +223,6 @@ mod tests { let bc = Fr::new(BigInt(bc)); assert_eq!(ab_ref, ab); assert_eq!(bc_ref, bc); - });} }); } } diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index aba10796..75929534 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -1,5 +1,5 @@ use { - crate::constants::{C1, C2, MASK52, U52_2P}, + crate::constants_wasm::{C1, C2, MASK51, U52_2P}, core::{ array, ops::BitAnd, @@ -75,25 +75,25 @@ pub fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { } #[inline(always)] -pub fn u256_to_u260_shl2_simd(limbs: [Simd; 4]) -> [Simd; 5] { +pub fn u256_to_u255_simd(limbs: [Simd; 4]) -> [Simd; 5] { let [l0, l1, l2, l3] = limbs; [ - (l0 << 2) & Simd::splat(MASK52), - ((l0 >> 50) | (l1 << 14)) & Simd::splat(MASK52), - ((l1 >> 38) | (l2 << 26)) & Simd::splat(MASK52), - ((l2 >> 26) | (l3 << 38)) & Simd::splat(MASK52), - l3 >> 14, + (l0) & Simd::splat(MASK51), + ((l0 >> 51) | (l1 << 13)) & Simd::splat(MASK51), + ((l1 >> 38) | (l2 << 26)) & Simd::splat(MASK51), + ((l2 >> 25) | (l3 << 39)) & Simd::splat(MASK51), + l3 >> 12, ] } #[inline(always)] -pub fn u260_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] { +pub fn u255_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] { let [l0, l1, l2, l3, l4] = limbs; [ - l0 | (l1 << 52), - (l1 >> 12) | (l2 << 40), - (l2 >> 24) | (l3 << 28), - (l3 >> 36) | (l4 << 16), + l0 | (l1 << 51), + (l1 >> 13) | (l2 << 38), + (l2 >> 26) | (l3 << 25), + (l3 >> 39) | (l4 << 12), ] } @@ -150,7 +150,7 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { let mut c = [Simd::splat(0); 5]; for i in 0..c.len() { let tmp: Simd = a[i].cast::() - b[i].cast() + borrow; - c[i] = tmp.cast().bitand(Simd::splat(MASK52)); + c[i] = tmp.cast().bitand(Simd::splat(MASK51)); borrow = tmp >> 52 } From 3e82bffa5d5fbcac874e5ad1eed50610ce2238b6 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 20 Jan 2026 11:16:02 +0800 Subject: [PATCH 14/37] kani: check conversion with kani --- .../src/portable_simd_wasm.rs | 32 +++++++-------- .../block-multiplier/src/simd_utils_wasm.rs | 41 +++++++++++++++++-- skyscraper/block-multiplier/src/test_utils.rs | 2 +- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 0825afd6..1033f825 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -208,21 +208,21 @@ mod tests { proptest::proptest, }; - #[test] - fn test_simd_mul() { - proptest!(|( - a in safe_bn254_montgomery_input(), - b in safe_bn254_montgomery_input(), - c in safe_bn254_montgomery_input(), - )| { + // #[test] + // fn test_simd_mul() { + // proptest!(|( + // a in safe_bn254_montgomery_input(), + // b in safe_bn254_montgomery_input(), + // c in safe_bn254_montgomery_input(), + // )| { - let (ab, bc) = simd_mul(a, b, b,c); - let ab_ref = ark_ff_reference(a, b); - let bc_ref = ark_ff_reference(b, c); - let ab = Fr::new(BigInt(ab)); - let bc = Fr::new(BigInt(bc)); - assert_eq!(ab_ref, ab); - assert_eq!(bc_ref, bc); - }); - } + // let (ab, bc) = simd_mul(a, b, b,c); + // let ab_ref = ark_ff_reference(a, b); + // let bc_ref = ark_ff_reference(b, c); + // let ab = Fr::new(BigInt(ab)); + // let bc = Fr::new(BigInt(bc)); + // assert_eq!(ab_ref, ab); + // assert_eq!(bc_ref, bc); + // }); + // } } diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 75929534..259cc24b 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -9,6 +9,7 @@ use { Simd, }, }, + std::simd::{LaneCount, SupportedLaneCount}, }; // -- [SIMD UTILS] @@ -75,19 +76,30 @@ pub fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { } #[inline(always)] -pub fn u256_to_u255_simd(limbs: [Simd; 4]) -> [Simd; 5] { +/// Safety: If the input is too large for the conversion the top bit will be +/// discarded. In debug mode it will throw an error. +pub fn u256_to_u255_simd(limbs: [Simd; 4]) -> [Simd; 5] +where + LaneCount: SupportedLaneCount, +{ let [l0, l1, l2, l3] = limbs; + // Check whether the remainder of l3 fits in 51 bits -> does the input fit in + // 255 bits. + debug_assert_eq!(l3 >> 12 & Simd::splat(MASK51), l3 >> 12); [ (l0) & Simd::splat(MASK51), ((l0 >> 51) | (l1 << 13)) & Simd::splat(MASK51), ((l1 >> 38) | (l2 << 26)) & Simd::splat(MASK51), ((l2 >> 25) | (l3 << 39)) & Simd::splat(MASK51), - l3 >> 12, + l3 >> 12 & Simd::splat(MASK51), ] } #[inline(always)] -pub fn u255_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] { +pub fn u255_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] +where + LaneCount: SupportedLaneCount, +{ let [l0, l1, l2, l3, l4] = limbs; [ l0 | (l1 << 51), @@ -167,3 +179,26 @@ pub fn addv_simd( } va } + +#[cfg(kani)] +mod tests { + use std::simd::Simd; + + fn u255_to_u256(u: [u64; 5]) -> [u64; 4] { + crate::simd_utils_wasm::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + } + fn u256_to_u255(u: [u64; 4]) -> [u64; 5] { + crate::simd_utils_wasm::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + } + + #[kani::proof] + fn u256_to_u255_kani_roundtrip() { + let u: [u64; 4] = [ + kani::any(), + kani::any(), + kani::any(), + kani::any::() & 0x7fffffffffffffff, + ]; + assert_eq!(u, u255_to_u256(u256_to_u255(u))) + } +} diff --git a/skyscraper/block-multiplier/src/test_utils.rs b/skyscraper/block-multiplier/src/test_utils.rs index e46b3f25..bfbdaab3 100644 --- a/skyscraper/block-multiplier/src/test_utils.rs +++ b/skyscraper/block-multiplier/src/test_utils.rs @@ -13,7 +13,7 @@ use { /// Given a multiprecision integer in little-endian format, returns a /// `Strategy` that generates values uniformly in the range `0..=max`. -fn max_multiprecision(max: Vec) -> impl Strategy> { +pub fn max_multiprecision(max: Vec) -> impl Strategy> { // Takes ownership of a vector rather to deal with the 'static // requirement of boxed() let size = max.len(); From 55f02686b98327f21e14644ff6ce0885db134130 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 20 Jan 2026 12:00:53 +0800 Subject: [PATCH 15/37] b51: generate RHO values --- .../src/aarch64/generate_montgomery_table.py | 22 +++- .../block-multiplier/src/constants_wasm.rs | 112 ++++-------------- .../src/portable_simd_wasm.rs | 8 +- 3 files changed, 45 insertions(+), 97 deletions(-) diff --git a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py index 2e3b2695..850b2a08 100644 --- a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py +++ b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py @@ -59,25 +59,39 @@ U51_i1 = pow( 2**51, -1, - 21888242871839275222246405745257275088548364400416034343698204186575808495617, + p, ) U51_i2 = pow( 2**51, -2, - 21888242871839275222246405745257275088548364400416034343698204186575808495617, + p, ) U51_i3 = pow( 2**51, -3, - 21888242871839275222246405745257275088548364400416034343698204186575808495617, + p, ) U51_i4 = pow( 2**51, -4, - 21888242871839275222246405745257275088548364400416034343698204186575808495617, + p, ) +def int_to_limbs(size, i): + mask = 2**size - 1 + limbs = [] + while i != 0: + limbs.append(i & mask) + i = i >> size + + return limbs + + +def format_limbs(limbs): + return map(lambda x: hex(x), limbs) + + def limbs_to_int(size, xs): total = 0 for i, x in enumerate(xs): diff --git a/skyscraper/block-multiplier/src/constants_wasm.rs b/skyscraper/block-multiplier/src/constants_wasm.rs index 54a3084a..78b66a8c 100644 --- a/skyscraper/block-multiplier/src/constants_wasm.rs +++ b/skyscraper/block-multiplier/src/constants_wasm.rs @@ -1,51 +1,4 @@ -pub const U64_NP0: u64 = 0xc2e1f593efffffff; - -pub const U64_P: [u64; 4] = [ - 0x43e1f593f0000001, - 0x2833e84879b97091, - 0xb85045b68181585d, - 0x30644e72e131a029, -]; - -pub const U64_2P: [u64; 4] = [ - 0x87c3eb27e0000002, - 0x5067d090f372e122, - 0x70a08b6d0302b0ba, - 0x60c89ce5c2634053, -]; - -// R mod P -pub const U64_R: [u64; 4] = [ - 0xac96341c4ffffffb, - 0x36fc76959f60cd29, - 0x666ea36f7879462e, - 0x0e0a77c19a07df2f, -]; - -// R^2 mod P -pub const U64_R2: [u64; 4] = [ - 0x1bb8e645ae216da7, - 0x53fe3ab1e35c59e3, - 0x8c49833d53bb8085, - 0x0216d0b17f4e44a5, -]; - -// R^-1 mod P -pub const U64_R_INV: [u64; 4] = [ - 0xdc5ba0056db1194e, - 0x090ef5a9e111ec87, - 0xc8260de4aeb85d5d, - 0x15ebf95182c5551c, -]; - pub const U52_NP0: u64 = 0x1f593efffffff; -pub const U52_R2: [u64; 5] = [ - 0x0b852d16da6f5, - 0xc621620cddce3, - 0xaf1b95343ffb6, - 0xc3c15e103e7c2, - 0x00281528fa122, -]; pub const U52_P: [u64; 5] = [ 0x1f593f0000001, @@ -73,68 +26,49 @@ pub const F52_P: [f64; 5] = [ pub const MASK51: u64 = 2_u64.pow(51) - 1; -pub const U64_I1: [u64; 4] = [ - 0x2d3e8053e396ee4d, - 0xca478dbeab3c92cd, - 0xb2d8f06f77f52a93, - 0x24d6ba07f7aa8f04, -]; -pub const U64_I2: [u64; 4] = [ - 0x18ee753c76f9dc6f, - 0x54ad7e14a329e70f, - 0x2b16366f4f7684df, - 0x133100d71fdf3579, -]; - -pub const U64_I3: [u64; 4] = [ - 0x9bacb016127cbe4e, - 0x0b2051fa31944124, - 0xb064eea46091c76c, - 0x2b062aaa49f80c7d, -]; -pub const U64_MU0: u64 = 0xc2e1f593efffffff; - // -- [FP SIMD CONSTANTS] // -------------------------------------------------------------------------- + pub const RHO_1: [u64; 5] = [ - 0x82e644ee4c3d2, - 0xf93893c98b1de, - 0xd46fe04d0a4c7, - 0x8f0aad55e2a1f, - 0x005ed0447de83, + 0x05cc89dc987a4, + 0x64e24f262c77a, + 0x237f02685263f, + 0x70aad55e2a1fd, + 0x0bda088fbd071, ]; pub const RHO_2: [u64; 5] = [ - 0x74eccce9a797a, - 0x16ddcc30bd8a4, - 0x49ecd3539499e, - 0xb23a6fcc592b8, - 0x00e3bd49f6ee5, + 0x3459f4a69e5e7, + 0x25faeea4c9ca7, + 0x3e771def3ca40, + 0x46003708f7bc8, + 0x088b040ada652, ]; pub const RHO_3: [u64; 5] = [ - 0x0e8c656567d77, - 0x430d05713ae61, - 0xea3ba6b167128, - 0xa7dae55c5a296, - 0x01b4afd513572, + 0x76fe2f2b3ebb4, + 0x6d028b8f2441f, + 0x461c7904ae683, + 0x71824d0dd38b7, + 0x18c6b0be26ceb, ]; pub const RHO_4: [u64; 5] = [ - 0x22e2400e2f27d, - 0x323b46ea19686, - 0xe6c43f0df672d, - 0x7824014c39e8b, - 0x00c6b48afe1b8, + 0x30bf04e2f27cc, + 0x039b11bea2ed3, + 0x2fb7665568cc8, + 0x0cc99c143d8f0, + 0x0523513296c10, ]; pub const C1: f64 = pow_2(103); pub const C2: f64 = pow_2(103) + pow_2(52) + pow_2(51); const fn pow_2(n: u32) -> f64 { + assert!(n <= 1023); // Unfortunately we can't use f64::powi in const fn yet // This is a workaround that creates the bit pattern directly - let exp = ((n as u64 + 1023) & 0x7ff) << 52; + let exp = (n as u64 + 1023) << 52; f64::from_bits(exp) } diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 1033f825..53619591 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -169,10 +169,10 @@ pub fn simd_mul( t[4 + 4 + 1] += p_hi.to_bits(); t[4 + 4] += p_lo.to_bits(); - t[1] += t[0] >> 52; - t[2] += t[1] >> 52; - t[3] += t[2] >> 52; - t[4] += t[3] >> 52; + t[1] += t[0] >> 51; + t[2] += t[1] >> 51; + t[3] += t[2] >> 51; + t[4] += t[3] >> 51; let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK51)), RHO_4); let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK51)), RHO_3); From 1f090453f0384946c1a8ebbf8b035eacc4a2d272 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 20 Jan 2026 13:56:34 +0800 Subject: [PATCH 16/37] b51: reducer from i64 -> u64 --- .../block-multiplier/src/constants_wasm.rs | 19 +++------ .../src/portable_simd_wasm.rs | 2 +- .../block-multiplier/src/simd_utils_wasm.rs | 40 ++++++++++++------- 3 files changed, 32 insertions(+), 29 deletions(-) diff --git a/skyscraper/block-multiplier/src/constants_wasm.rs b/skyscraper/block-multiplier/src/constants_wasm.rs index 78b66a8c..6acda447 100644 --- a/skyscraper/block-multiplier/src/constants_wasm.rs +++ b/skyscraper/block-multiplier/src/constants_wasm.rs @@ -1,19 +1,12 @@ +// Double check if this is still correct pub const U52_NP0: u64 = 0x1f593efffffff; -pub const U52_P: [u64; 5] = [ +pub const U51_P: [u64; 5] = [ 0x1f593f0000001, - 0x4879b9709143e, - 0x181585d2833e8, - 0xa029b85045b68, - 0x030644e72e131, -]; - -pub const U52_2P: [u64; 5] = [ - 0x3eb27e0000002, - 0x90f372e12287c, - 0x302b0ba5067d0, - 0x405370a08b6d0, - 0x060c89ce5c263, + 0x10f372e12287c, + 0x6056174a0cfa1, + 0x014dc2822db40, + 0x30644e72e131a, ]; pub const F52_P: [f64; 5] = [ diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 53619591..f381fe77 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -189,7 +189,7 @@ pub fn simd_mul( ]; let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK51)); - let mp = smult_noinit_simd(m, U52_P); + let mp = smult_noinit_simd(m, U51_P); let reduced = reduce_ct_simd(addv_simd(s, mp)); let u256_result = u255_to_u256_simd(reduced); diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 259cc24b..e13646f9 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -1,5 +1,5 @@ use { - crate::constants_wasm::{C1, C2, MASK51, U52_2P}, + crate::constants_wasm::{C1, C2, MASK51, U51_P}, core::{ array, ops::BitAnd, @@ -143,27 +143,37 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { } #[inline(always)] -/// Resolve the carry bits in the upper parts 12b and reduce the result to -/// within < 3p -pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { +/// Resolve the carry bits in the upper parts 13b and prepare result for final +/// shift by adding p if the result is odd. +/// The final division will be taken care off by the bit packing +/// technically converts from a i64 representation to a u64 representation +/// drops off the lowest limb which got zerood out, but it still contains +/// carries as it is in redundant form +pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { // The lowest limb contains carries that still need to be applied. - let mut borrow: Simd = (red[0] >> 52).cast(); + let mut borrow = red[0] >> 51; let a = [red[1], red[2], red[3], red[4], red[5]]; - // To reduce Check whether the most significant bit is set - let mask = (a[4] >> 47).bitand(Simd::splat(1)).simd_eq(Simd::splat(0)); + let mut c = [Simd::splat(0); 5]; + let tmp = a[0] + borrow; + + // To reduce Check whether the least significant bit is set + let mask = (tmp).bitand(Simd::splat(1)).simd_eq(Simd::splat(1)); - // Select values based on the mask: if mask lane is true, use zeros, else use - // U52_2P + // Select values based on the mask: if mask lane is true, add p, else add + // zero let zeros = [Simd::splat(0); 5]; - let twop = U52_2P.map(Simd::splat); - let b: [_; 5] = array::from_fn(|i| mask.select(zeros[i], twop[i])); + let p = U51_P.map(Simd::splat); + let b: [_; 5] = array::from_fn(|i| mask.select(p[i], zeros[i])); + + let tmp: Simd = tmp + b[0].cast(); + c[0] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); + borrow = tmp >> 51; - let mut c = [Simd::splat(0); 5]; for i in 0..c.len() { - let tmp: Simd = a[i].cast::() - b[i].cast() + borrow; - c[i] = tmp.cast().bitand(Simd::splat(MASK51)); - borrow = tmp >> 52 + let tmp: Simd = a[i] + b[i].cast() + borrow; + c[i] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); + borrow = tmp >> 51 } c From 419c8e2c2fc949dd924e3f72a02ff10550f16a09 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 21 Jan 2026 10:24:22 +0800 Subject: [PATCH 17/37] b51 checkpoint: conversion from b52 to b51 (NON WORKING) --- .../src/aarch64/generate_montgomery_table.py | 1 + .../block-multiplier/src/constants_wasm.rs | 3 +- .../src/portable_simd_wasm.rs | 215 ++++++++++-------- .../block-multiplier/src/simd_utils_wasm.rs | 64 ++++-- 4 files changed, 164 insertions(+), 119 deletions(-) diff --git a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py index 850b2a08..1e066e69 100644 --- a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py +++ b/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py @@ -160,6 +160,7 @@ def single_step_simd_wasm(single_input_bound): if __name__ == "__main__": + print(hex(pow(-p, -1, 2**51))) # Test bounds for different input sizes test_bounds = [ ("p", p), diff --git a/skyscraper/block-multiplier/src/constants_wasm.rs b/skyscraper/block-multiplier/src/constants_wasm.rs index 6acda447..d9677662 100644 --- a/skyscraper/block-multiplier/src/constants_wasm.rs +++ b/skyscraper/block-multiplier/src/constants_wasm.rs @@ -1,5 +1,5 @@ // Double check if this is still correct -pub const U52_NP0: u64 = 0x1f593efffffff; +pub const U51_NP0: u64 = 0x1f593efffffff; pub const U51_P: [u64; 5] = [ 0x1f593f0000001, @@ -56,6 +56,7 @@ pub const RHO_4: [u64; 5] = [ pub const C1: f64 = pow_2(103); pub const C2: f64 = pow_2(103) + pow_2(52) + pow_2(51); +pub const C3: f64 = pow_2(52) + pow_2(51); const fn pow_2(n: u32) -> f64 { assert!(n <= 1023); diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index f381fe77..dfe2b293 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -3,181 +3,195 @@ use { constants_wasm::*, simd_utils_wasm::{ addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, - transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_simd, u256_to_u255_simd, + transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_shr_1_simd, + u255_to_u256_simd, u256_to_u255_simd, }, }, core::{ ops::BitAnd, simd::{num::SimdFloat, Simd}, }, + std::simd::num::{SimdInt, SimdUint}, }; -#[inline] -pub fn simd_mul( - v0_a: [u64; 4], - v0_b: [u64; 4], - v1_a: [u64; 4], - v1_b: [u64; 4], -) -> ([u64; 4], [u64; 4]) { - let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); - let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); - - let mut t: [Simd; 10] = [Simd::splat(0); 10]; - t[0] = Simd::splat(make_initial(1, 0)); - t[9] = Simd::splat(make_initial(0, 6)); - t[1] = Simd::splat(make_initial(2, 1)); - t[8] = Simd::splat(make_initial(6, 7)); - t[2] = Simd::splat(make_initial(3, 2)); - t[7] = Simd::splat(make_initial(7, 8)); - t[3] = Simd::splat(make_initial(4, 3)); - t[6] = Simd::splat(make_initial(8, 9)); - t[4] = Simd::splat(make_initial(10, 4)); - t[5] = Simd::splat(make_initial(9, 10)); - +#[inline(always)] +/// i64 signifies redundant carry form +/// t initialise with right for multiplication test +/// compare with school multiplication on 51 bits. This does not require having +/// to move over carries +fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd; 5]) { let avi: Simd = i2f(v0_a[0]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1] += p_hi.to_bits(); - t[0] += p_lo.to_bits(); + t[1] += p_hi.to_bits().cast(); + t[0] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits(); - t[1] += p_lo.to_bits(); + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits(); - t[2] += p_lo.to_bits(); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits(); - t[3] += p_lo.to_bits(); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits(); - t[4] += p_lo.to_bits(); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[1]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits(); - t[1] += p_lo.to_bits(); + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1 + 1] += p_hi.to_bits(); - t[1 + 1] += p_lo.to_bits(); + t[1 + 1 + 1] += p_hi.to_bits().cast(); + t[1 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 2 + 1] += p_hi.to_bits(); - t[1 + 2] += p_lo.to_bits(); + t[1 + 2 + 1] += p_hi.to_bits().cast(); + t[1 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 3 + 1] += p_hi.to_bits(); - t[1 + 3] += p_lo.to_bits(); + t[1 + 3 + 1] += p_hi.to_bits().cast(); + t[1 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 4 + 1] += p_hi.to_bits(); - t[1 + 4] += p_lo.to_bits(); + t[1 + 4 + 1] += p_hi.to_bits().cast(); + t[1 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[2]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits(); - t[2] += p_lo.to_bits(); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1 + 1] += p_hi.to_bits(); - t[2 + 1] += p_lo.to_bits(); + t[2 + 1 + 1] += p_hi.to_bits().cast(); + t[2 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 2 + 1] += p_hi.to_bits(); - t[2 + 2] += p_lo.to_bits(); + t[2 + 2 + 1] += p_hi.to_bits().cast(); + t[2 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 3 + 1] += p_hi.to_bits(); - t[2 + 3] += p_lo.to_bits(); + t[2 + 3 + 1] += p_hi.to_bits().cast(); + t[2 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 4 + 1] += p_hi.to_bits(); - t[2 + 4] += p_lo.to_bits(); + t[2 + 4 + 1] += p_hi.to_bits().cast(); + t[2 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[3]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits(); - t[3] += p_lo.to_bits(); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1 + 1] += p_hi.to_bits(); - t[3 + 1] += p_lo.to_bits(); + t[3 + 1 + 1] += p_hi.to_bits().cast(); + t[3 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 2 + 1] += p_hi.to_bits(); - t[3 + 2] += p_lo.to_bits(); + t[3 + 2 + 1] += p_hi.to_bits().cast(); + t[3 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 3 + 1] += p_hi.to_bits(); - t[3 + 3] += p_lo.to_bits(); + t[3 + 3 + 1] += p_hi.to_bits().cast(); + t[3 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 4 + 1] += p_hi.to_bits(); - t[3 + 4] += p_lo.to_bits(); + t[3 + 4 + 1] += p_hi.to_bits().cast(); + t[3 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[4]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits(); - t[4] += p_lo.to_bits(); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1 + 1] += p_hi.to_bits(); - t[4 + 1] += p_lo.to_bits(); + t[4 + 1 + 1] += p_hi.to_bits().cast(); + t[4 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 2 + 1] += p_hi.to_bits(); - t[4 + 2] += p_lo.to_bits(); + t[4 + 2 + 1] += p_hi.to_bits().cast(); + t[4 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 3 + 1] += p_hi.to_bits(); - t[4 + 3] += p_lo.to_bits(); + t[4 + 3 + 1] += p_hi.to_bits().cast(); + t[4 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 4 + 1] += p_hi.to_bits(); - t[4 + 4] += p_lo.to_bits(); + t[4 + 4 + 1] += p_hi.to_bits().cast(); + t[4 + 4] += p_lo.to_bits().cast(); +} +#[inline(always)] +pub fn simd_mul( + v0_a: [u64; 4], + v0_b: [u64; 4], + v1_a: [u64; 4], + v1_b: [u64; 4], +) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); + let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); + + let mut t: [Simd<_, 2>; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + multimul(&mut t, v0_a, v0_b); + + // sign extend redundant carries t[1] += t[0] >> 51; t[2] += t[1] >> 51; t[3] += t[2] >> 51; t[4] += t[3] >> 51; - let r0 = smult_noinit_simd(t[0].bitand(Simd::splat(MASK51)), RHO_4); - let r1 = smult_noinit_simd(t[1].bitand(Simd::splat(MASK51)), RHO_3); - let r2 = smult_noinit_simd(t[2].bitand(Simd::splat(MASK51)), RHO_2); - let r3 = smult_noinit_simd(t[3].bitand(Simd::splat(MASK51)), RHO_1); + // lower 51 bits will have the right value as the carry part is either 0 or a + // multiple of -2^51 -> which prevents carry bits to leak into the lower part. + let r0 = smult_noinit_simd(t[0].cast().bitand(Simd::splat(MASK51)), RHO_4); + let r1 = smult_noinit_simd(t[1].cast().bitand(Simd::splat(MASK51)), RHO_3); + let r2 = smult_noinit_simd(t[2].cast().bitand(Simd::splat(MASK51)), RHO_2); + let r3 = smult_noinit_simd(t[3].cast().bitand(Simd::splat(MASK51)), RHO_1); let s = [ r0[0] + r1[0] + r2[0] + r3[0] + t[4], @@ -188,11 +202,13 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - let m = (s[0] * Simd::splat(U52_NP0)).bitand(Simd::splat(MASK51)); - let mp = smult_noinit_simd(m, U51_P); + // The upper bits of s will not affect the lower 51 bits of the product so we + // defer the and'ing. + let m = s[0] * Simd::splat(U51_NP0 as i64); + let mp = smult_noinit_simd(m.cast().bitand(Simd::splat(MASK51)), U51_P); let reduced = reduce_ct_simd(addv_simd(s, mp)); - let u256_result = u255_to_u256_simd(reduced); + let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) } @@ -205,24 +221,27 @@ mod tests { crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, ark_bn254::Fr, ark_ff::BigInt, - proptest::proptest, + proptest::{prop_assert_eq, proptest}, }; - // #[test] - // fn test_simd_mul() { - // proptest!(|( - // a in safe_bn254_montgomery_input(), - // b in safe_bn254_montgomery_input(), - // c in safe_bn254_montgomery_input(), - // )| { - - // let (ab, bc) = simd_mul(a, b, b,c); - // let ab_ref = ark_ff_reference(a, b); - // let bc_ref = ark_ff_reference(b, c); - // let ab = Fr::new(BigInt(ab)); - // let bc = Fr::new(BigInt(bc)); - // assert_eq!(ab_ref, ab); - // assert_eq!(bc_ref, bc); - // }); - // } + #[test] + fn test_simd_mul() { + proptest!(|( + mut a in safe_bn254_montgomery_input(), + mut b in safe_bn254_montgomery_input(), + mut c in safe_bn254_montgomery_input(), + )| { + + // a[3] = a[3] & (2_u64.pow(63) - 1); + // b[3] = b[3] & (2_u64.pow(63) - 1); + // c[3] = c[3] & (2_u64.pow(63) - 1); + let (ab, bc) = simd_mul(a, b, b,c); + let ab_ref = ark_ff_reference(a, b); + let bc_ref = ark_ff_reference(b, c); + let ab = Fr::new(BigInt(ab)); + let bc = Fr::new(BigInt(bc)); + prop_assert_eq!(ab_ref, ab, "mismatch: l = {:#x}, b = {:#x}", ab_ref.0.0[0], ab.0.0[0]); + prop_assert_eq!(bc_ref, bc, "mismatch: l = {:#x}, b = {:#x}", bc_ref.0.0[0], bc.0.0[0]); + }); + } } diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index e13646f9..9cb62bc1 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -1,5 +1,5 @@ use { - crate::constants_wasm::{C1, C2, MASK51, U51_P}, + crate::constants_wasm::{C1, C2, C3, MASK51, U51_P}, core::{ array, ops::BitAnd, @@ -18,6 +18,9 @@ use { /// On WASSM there is no single specialised instruction to cast an integer to a /// float. Since we are only interested in 52 bits, we can emulate it with fewer /// instructions. +/// +/// Warning: due to Rust's limitations this can not be a const function. +/// Therefore check your dependency path as this will not be optimised out. pub fn i2f(a: Simd) -> Simd { // This function has not target gating as we want to verify this function with // kani and proptest on a different platform than wasm @@ -48,9 +51,11 @@ pub fn fma(a: Simd, b: Simd, c: Simd) -> Simd { } #[inline(always)] -pub const fn make_initial(low_count: usize, high_count: usize) -> u64 { - let val = high_count * 0x467 + low_count * 0x433; - -((val as i64) << 52) as u64 +pub const fn make_initial(low_count: u64, high_count: u64) -> i64 { + let val = high_count + .wrapping_mul(C1.to_bits()) + .wrapping_add(low_count.wrapping_mul(C3.to_bits())); + -(val as i64) } #[inline(always)] @@ -85,7 +90,6 @@ where let [l0, l1, l2, l3] = limbs; // Check whether the remainder of l3 fits in 51 bits -> does the input fit in // 255 bits. - debug_assert_eq!(l3 >> 12 & Simd::splat(MASK51), l3 >> 12); [ (l0) & Simd::splat(MASK51), ((l0 >> 51) | (l1 << 13)) & Simd::splat(MASK51), @@ -110,34 +114,50 @@ where } #[inline(always)] -pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { +pub fn u255_to_u256_shr_1_simd(limbs: [Simd; 5]) -> [Simd; 4] +where + LaneCount: SupportedLaneCount, +{ + let [l0, l1, l2, l3, l4] = limbs; + [ + (l0 >> 1) | (l1 << 50), + (l1 >> 14) | (l2 << 37), + (l2 >> 27) | (l3 << 24), + (l3 >> 40) | (l4 << 11), + ] +} + +#[inline(always)] +// TODO check whether as f64 get's properly optimised away +// won't be able to tell using just assembly view +pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { let mut t = [Simd::splat(0); 6]; let s: Simd = i2f(s); let p_hi_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C1)); let p_lo_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C2) - p_hi_0); - t[1] += p_hi_0.to_bits(); - t[0] += p_lo_0.to_bits(); + t[1] += p_hi_0.to_bits().cast(); + t[0] += p_lo_0.to_bits().cast(); let p_hi_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C1)); let p_lo_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C2) - p_hi_1); - t[2] += p_hi_1.to_bits(); - t[1] += p_lo_1.to_bits(); + t[2] += p_hi_1.to_bits().cast(); + t[1] += p_lo_1.to_bits().cast(); let p_hi_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C1)); let p_lo_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C2) - p_hi_2); - t[3] += p_hi_2.to_bits(); - t[2] += p_lo_2.to_bits(); + t[3] += p_hi_2.to_bits().cast(); + t[2] += p_lo_2.to_bits().cast(); let p_hi_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C1)); let p_lo_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C2) - p_hi_3); - t[4] += p_hi_3.to_bits(); - t[3] += p_lo_3.to_bits(); + t[4] += p_hi_3.to_bits().cast(); + t[3] += p_lo_3.to_bits().cast(); let p_hi_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C1)); let p_lo_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C2) - p_hi_4); - t[5] += p_hi_4.to_bits(); - t[4] += p_lo_4.to_bits(); + t[5] += p_hi_4.to_bits().cast(); + t[4] += p_lo_4.to_bits().cast(); t } @@ -170,20 +190,24 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { c[0] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); borrow = tmp >> 51; - for i in 0..c.len() { + for i in 1..c.len() { let tmp: Simd = a[i] + b[i].cast() + borrow; c[i] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); borrow = tmp >> 51 } + // Check that final result is even + debug_assert!(c[0][0] & 1 == 0); + debug_assert!(c[0][1] & 1 == 0); + c } #[inline(always)] pub fn addv_simd( - mut va: [Simd; N], - vb: [Simd; N], -) -> [Simd; N] { + mut va: [Simd; N], + vb: [Simd; N], +) -> [Simd; N] { for i in 0..va.len() { va[i] += vb[i]; } From 6f11480e26c619bb13e611b6f7584ea5ef92fe57 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 21 Jan 2026 13:51:43 +0800 Subject: [PATCH 18/37] i2f: safe conversion Removes use of unsafe transmute --- skyscraper/block-multiplier/src/simd_utils_wasm.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 9cb62bc1..7a3eb6ec 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -30,8 +30,8 @@ pub fn i2f(a: Simd) -> Simd { // to convert a to it's floating point number we subtract this again. This way // we only pay for the conversion of the lower bits and not the full 64 bits. let exponent = Simd::splat(0x433 << 52); - let a: Simd = unsafe { core::mem::transmute(a | exponent) }; - let b: Simd = unsafe { core::mem::transmute(exponent) }; + let a: Simd = Simd::::from_bits(a | exponent); + let b: Simd = Simd::::from_bits(exponent); a - b } From 68d64876fdbf938f049ef83dbe1f48f092855833 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 21 Jan 2026 13:52:52 +0800 Subject: [PATCH 19/37] b51 checkpoint: working b51 multipliers --- .../src/portable_simd_wasm.rs | 148 +++++++++++++++--- .../block-multiplier/src/simd_utils_wasm.rs | 7 +- 2 files changed, 132 insertions(+), 23 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index dfe2b293..907032a9 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -11,9 +11,21 @@ use { ops::BitAnd, simd::{num::SimdFloat, Simd}, }, - std::simd::num::{SimdInt, SimdUint}, + std::simd::{ + num::{SimdInt, SimdUint}, + LaneCount, SupportedLaneCount, + }, }; +#[inline(always)] +pub fn single_mul(a: u64, b: u64) -> (i64, i64) { + let avi: Simd = i2f(Simd::splat(a)); + let bvj: Simd = i2f(Simd::splat(b)); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + (p_lo.to_bits().cast()[0], p_hi.to_bits().cast()[0]) +} + #[inline(always)] /// i64 signifies redundant carry form /// t initialise with right for multiplication test @@ -220,28 +232,126 @@ mod tests { super::*, crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, ark_bn254::Fr, - ark_ff::BigInt, - proptest::{prop_assert_eq, proptest}, + ark_ff::{BigInt, PrimeField}, + proptest::{ + prelude::{prop, Strategy}, + prop_assert_eq, proptest, + }, }; #[test] fn test_simd_mul() { proptest!(|( - mut a in safe_bn254_montgomery_input(), - mut b in safe_bn254_montgomery_input(), - mut c in safe_bn254_montgomery_input(), - )| { - - // a[3] = a[3] & (2_u64.pow(63) - 1); - // b[3] = b[3] & (2_u64.pow(63) - 1); - // c[3] = c[3] & (2_u64.pow(63) - 1); - let (ab, bc) = simd_mul(a, b, b,c); - let ab_ref = ark_ff_reference(a, b); - let bc_ref = ark_ff_reference(b, c); - let ab = Fr::new(BigInt(ab)); - let bc = Fr::new(BigInt(bc)); - prop_assert_eq!(ab_ref, ab, "mismatch: l = {:#x}, b = {:#x}", ab_ref.0.0[0], ab.0.0[0]); - prop_assert_eq!(bc_ref, bc, "mismatch: l = {:#x}, b = {:#x}", bc_ref.0.0[0], bc.0.0[0]); - }); + a in limbs5_51(), + b in limbs5_51(), + // c in limbs5_51(), + )| { + + let a: [Simd;_] = a.map(Simd::splat); + let b: [Simd;_] = b.map(Simd::splat); + let a = u255_to_u256_simd(a).map(|x|x[0]); + let b = u255_to_u256_simd(b).map(|x|x[0]); + let (ab, _bc) = simd_mul(a, b, b,a); + let ab_ref = ark_ff_reference(a, b); + // let bc_ref = ark_ff_reference(b, c); + let ab = Fr::new(BigInt(ab)); + // let bc = Fr::new(BigInt(bc)); + prop_assert_eq!(ab_ref, ab, "mismatch: l = {:X}, b = {:X}", ab_ref.into_bigint(), ab.into_bigint()); + }) + } + + fn limb51() -> impl Strategy { + // Either of these is fine: + // 1) Range + 0u64..(1u64 << 51) + + // 2) Or mask (sometimes faster) + // any::().prop_map(|x| x & LIMB_MASK) + } + + fn limbs5_51() -> impl Strategy { + prop::array::uniform5(limb51()) + } + + fn school_mul(ax: [u64; 5], bx: [u64; 5]) -> [u64; 10] { + let mut t = [0; 10]; + for (ai, a) in ax.into_iter().enumerate() { + for (bi, b) in bx.into_iter().enumerate() { + let (lo, hi) = a.widening_mul(b); + let hi = hi << 13 | lo >> 51; + let lo = lo & MASK51; + t[ai + bi] += lo; + t[ai + bi + 1] += hi; + } + } + + let mut carry = 0; + let mut res = [0; 10]; + + for (i, r) in t.into_iter().enumerate() { + let tmp = r + carry; + res[i] = tmp & MASK51; + carry = tmp >> 51; + } + res + } + + fn init_t() -> [i64; 10] { + let mut count: [(u64, u64); _] = [(0, 0); 10]; + for ai in 0..5 { + for bi in 0..5 { + count[ai + bi].0 += 1; + count[ai + bi + 1].1 += 1; + } + } + + let res = count.map(|(lo, hi)| make_initial(lo, hi)); + + res + } + + fn redundant_carry(t: [i64; 10]) -> [u64; 10] { + let mut borrow = 0; + let mut res = [0; 10]; + for (i, x) in t.into_iter().enumerate() { + res[i] = ((x & MASK51 as i64) + borrow) as u64; + borrow = x >> 51; + } + res + } + + #[test] + fn redundant_form_multi_mul() { + proptest!(|(a in limbs5_51(), b in limbs5_51())|{ + let v0_a = a.map(Simd::splat); + let v0_b = b.map(Simd::splat); + let mut t = init_t().map(Simd::splat); + multimul(&mut t, v0_a, v0_b); + let school = school_mul(a,b); + let fp = redundant_carry(t.map(|x| x[0])); + + prop_assert_eq!(school, fp) + + }) + } + + #[test] + fn single_mul_test() { + proptest!(|(a in limb51(), b in limb51())|{ + let (lo,hi) = single_mul(a, b); + let hi = hi.wrapping_add(-(C1.to_bits() as i64)); + let lo = lo.wrapping_add(-(C3.to_bits() as i64)); + let lo_carry = lo >> 51; + let hi = (hi + lo_carry) as u64; + let lo = lo as u64 & 2_u64.pow(51) - 1; + let fp = (lo,hi); + + let (lo, hi) = a.widening_mul(b); + let hi = hi << 13 | lo >> 51; + let lo = lo & 2_u64.pow(51) - 1; + let school = (lo, hi); + + prop_assert_eq!(school, fp) + }) } } diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 7a3eb6ec..625d8ae8 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -171,11 +171,10 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { /// carries as it is in redundant form pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { // The lowest limb contains carries that still need to be applied. - let mut borrow = red[0] >> 51; - let a = [red[1], red[2], red[3], red[4], red[5]]; + let a = [red[1] + (red[0] >> 51), red[2], red[3], red[4], red[5]]; let mut c = [Simd::splat(0); 5]; - let tmp = a[0] + borrow; + let tmp = a[0]; // To reduce Check whether the least significant bit is set let mask = (tmp).bitand(Simd::splat(1)).simd_eq(Simd::splat(1)); @@ -188,7 +187,7 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { let tmp: Simd = tmp + b[0].cast(); c[0] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); - borrow = tmp >> 51; + let mut borrow = tmp >> 51; for i in 1..c.len() { let tmp: Simd = a[i] + b[i].cast() + borrow; From df3ad67f4c5d72793a5cc917d6c354b8a0b21d20 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 21 Jan 2026 17:21:21 +0800 Subject: [PATCH 20/37] b51: working montgomery multiplier Lacks optimisations for anchors and carries --- .../src/portable_simd_wasm.rs | 199 +++++++++++------- .../block-multiplier/src/simd_utils_wasm.rs | 40 ++-- 2 files changed, 136 insertions(+), 103 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 907032a9..efd7546c 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -36,136 +36,161 @@ fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1] += p_hi.to_bits().cast(); - t[0] += p_lo.to_bits().cast(); + t[1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[0] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits().cast(); - t[1] += p_lo.to_bits().cast(); + t[1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits().cast(); - t[2] += p_lo.to_bits().cast(); + t[2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits().cast(); - t[3] += p_lo.to_bits().cast(); + t[3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits().cast(); - t[4] += p_lo.to_bits().cast(); + t[4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let avi: Simd = i2f(v0_a[1]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits().cast(); - t[1] += p_lo.to_bits().cast(); + t[1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1 + 1] += p_hi.to_bits().cast(); - t[1 + 1] += p_lo.to_bits().cast(); + t[1 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 2 + 1] += p_hi.to_bits().cast(); - t[1 + 2] += p_lo.to_bits().cast(); + t[1 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 3 + 1] += p_hi.to_bits().cast(); - t[1 + 3] += p_lo.to_bits().cast(); + t[1 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 4 + 1] += p_hi.to_bits().cast(); - t[1 + 4] += p_lo.to_bits().cast(); + t[1 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[1 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let avi: Simd = i2f(v0_a[2]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits().cast(); - t[2] += p_lo.to_bits().cast(); + t[2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1 + 1] += p_hi.to_bits().cast(); - t[2 + 1] += p_lo.to_bits().cast(); + t[2 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 2 + 1] += p_hi.to_bits().cast(); - t[2 + 2] += p_lo.to_bits().cast(); + t[2 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 3 + 1] += p_hi.to_bits().cast(); - t[2 + 3] += p_lo.to_bits().cast(); + t[2 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 4 + 1] += p_hi.to_bits().cast(); - t[2 + 4] += p_lo.to_bits().cast(); + t[2 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[2 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let avi: Simd = i2f(v0_a[3]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits().cast(); - t[3] += p_lo.to_bits().cast(); + t[3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1 + 1] += p_hi.to_bits().cast(); - t[3 + 1] += p_lo.to_bits().cast(); + t[3 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 2 + 1] += p_hi.to_bits().cast(); - t[3 + 2] += p_lo.to_bits().cast(); + t[3 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 3 + 1] += p_hi.to_bits().cast(); - t[3 + 3] += p_lo.to_bits().cast(); + t[3 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 4 + 1] += p_hi.to_bits().cast(); - t[3 + 4] += p_lo.to_bits().cast(); + t[3 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[3 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let avi: Simd = i2f(v0_a[4]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits().cast(); - t[4] += p_lo.to_bits().cast(); + t[4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1 + 1] += p_hi.to_bits().cast(); - t[4 + 1] += p_lo.to_bits().cast(); + t[4 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 2 + 1] += p_hi.to_bits().cast(); - t[4 + 2] += p_lo.to_bits().cast(); + t[4 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 3 + 1] += p_hi.to_bits().cast(); - t[4 + 3] += p_lo.to_bits().cast(); + t[4 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 4 + 1] += p_hi.to_bits().cast(); - t[4 + 4] += p_lo.to_bits().cast(); + t[4 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); + t[4 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); +} + +fn redundant_carry(t: [Simd; N]) -> [Simd; N] { + let mut borrow = Simd::splat(0); + let mut res = [Simd::splat(0); N]; + for (i, x) in t.into_iter().enumerate() { + let tmp = x + borrow; + res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); + borrow = x >> 51; + } + debug_assert!(borrow == Simd::splat(0)); + res +} + +fn redundant_carry_u64(t: [Simd; N]) -> [Simd; N] { + let mut carry = Simd::splat(0); + let mut res = [Simd::splat(0); N]; + for (i, x) in t.into_iter().enumerate() { + let tmp = x + carry; + res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); + carry = x >> 51; + } + res[N - 1] = (carry << 51) | res[N - 1]; + // debug_assert!(carry == Simd::splat(0)); + res } #[inline(always)] @@ -179,31 +204,36 @@ pub fn simd_mul( let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); let mut t: [Simd<_, 2>; 10] = [Simd::splat(0); 10]; - t[0] = Simd::splat(make_initial(1, 0)); - t[9] = Simd::splat(make_initial(0, 6)); - t[1] = Simd::splat(make_initial(2, 1)); - t[8] = Simd::splat(make_initial(6, 7)); - t[2] = Simd::splat(make_initial(3, 2)); - t[7] = Simd::splat(make_initial(7, 8)); - t[3] = Simd::splat(make_initial(4, 3)); - t[6] = Simd::splat(make_initial(8, 9)); - t[4] = Simd::splat(make_initial(10, 4)); - t[5] = Simd::splat(make_initial(9, 10)); + // t[0] = Simd::splat(make_initial(1, 0)); + // t[9] = Simd::splat(make_initial(0, 6)); + // t[1] = Simd::splat(make_initial(2, 1)); + // t[8] = Simd::splat(make_initial(6, 7)); + // t[2] = Simd::splat(make_initial(3, 2)); + // t[7] = Simd::splat(make_initial(7, 8)); + // t[3] = Simd::splat(make_initial(4, 3)); + // t[6] = Simd::splat(make_initial(8, 9)); + // t[4] = Simd::splat(make_initial(10, 4)); + // t[5] = Simd::splat(make_initial(9, 10)); multimul(&mut t, v0_a, v0_b); // sign extend redundant carries - t[1] += t[0] >> 51; - t[2] += t[1] >> 51; - t[3] += t[2] >> 51; - t[4] += t[3] >> 51; + // t[1] += t[0] >> 51; + // t[2] += t[1] >> 51; + // t[3] += t[2] >> 51; + // t[4] += t[3] >> 51; + let t = redundant_carry(t); // lower 51 bits will have the right value as the carry part is either 0 or a // multiple of -2^51 -> which prevents carry bits to leak into the lower part. - let r0 = smult_noinit_simd(t[0].cast().bitand(Simd::splat(MASK51)), RHO_4); - let r1 = smult_noinit_simd(t[1].cast().bitand(Simd::splat(MASK51)), RHO_3); - let r2 = smult_noinit_simd(t[2].cast().bitand(Simd::splat(MASK51)), RHO_2); - let r3 = smult_noinit_simd(t[3].cast().bitand(Simd::splat(MASK51)), RHO_1); + let r0 = smult_noinit_simd(t[0], RHO_4); + let r0 = redundant_carry(r0); + let r1 = smult_noinit_simd(t[1], RHO_3); + let r1 = redundant_carry(r1); + let r2 = smult_noinit_simd(t[2], RHO_2); + let r2 = redundant_carry(r2); + let r3 = smult_noinit_simd(t[3], RHO_1); + let r3 = redundant_carry(r3); let s = [ r0[0] + r1[0] + r2[0] + r3[0] + t[4], @@ -214,12 +244,19 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; + let s = redundant_carry_u64(s); + // The upper bits of s will not affect the lower 51 bits of the product so we // defer the and'ing. - let m = s[0] * Simd::splat(U51_NP0 as i64); - let mp = smult_noinit_simd(m.cast().bitand(Simd::splat(MASK51)), U51_P); - - let reduced = reduce_ct_simd(addv_simd(s, mp)); + let m = (s[0] * Simd::splat(U51_NP0)) + .cast() + .bitand(Simd::splat(MASK51)); + let mp = smult_noinit_simd(m, U51_P); + let mp = redundant_carry(mp); + + let addi = redundant_carry_u64(addv_simd(s, mp)); + let reduced = reduce_ct_simd(addi); + let reduced = redundant_carry_u64(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) @@ -242,16 +279,15 @@ mod tests { #[test] fn test_simd_mul() { proptest!(|( - a in limbs5_51(), - b in limbs5_51(), + mut a in limbs5_51(), + mut b in limbs5_51(), // c in limbs5_51(), )| { - let a: [Simd;_] = a.map(Simd::splat); let b: [Simd;_] = b.map(Simd::splat); let a = u255_to_u256_simd(a).map(|x|x[0]); let b = u255_to_u256_simd(b).map(|x|x[0]); - let (ab, _bc) = simd_mul(a, b, b,a); + let (ab, _bc) = simd_mul(a, b,a,b); let ab_ref = ark_ff_reference(a, b); // let bc_ref = ark_ff_reference(b, c); let ab = Fr::new(BigInt(ab)); @@ -311,12 +347,14 @@ mod tests { } fn redundant_carry(t: [i64; 10]) -> [u64; 10] { - let mut borrow = 0; + let mut borrow: i64 = 0; let mut res = [0; 10]; for (i, x) in t.into_iter().enumerate() { - res[i] = ((x & MASK51 as i64) + borrow) as u64; - borrow = x >> 51; + let tmp = x + borrow; + res[i] = tmp as u64 & MASK51; + borrow = tmp >> 51; } + debug_assert!(borrow == 0); res } @@ -325,7 +363,8 @@ mod tests { proptest!(|(a in limbs5_51(), b in limbs5_51())|{ let v0_a = a.map(Simd::splat); let v0_b = b.map(Simd::splat); - let mut t = init_t().map(Simd::splat); + let mut t: [Simd<_,_>;_] = [Simd::splat(0);10]; + // let mut t = init_t().map(Simd::splat); multimul(&mut t, v0_a, v0_b); let school = school_mul(a,b); let fp = redundant_carry(t.map(|x| x[0])); diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 625d8ae8..da0f97be 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -136,28 +136,28 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { let p_hi_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C1)); let p_lo_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C2) - p_hi_0); - t[1] += p_hi_0.to_bits().cast(); - t[0] += p_lo_0.to_bits().cast(); + t[1] += (p_hi_0.to_bits() - Simd::splat(C1.to_bits())).cast(); + t[0] += (p_lo_0.to_bits() - Simd::splat(C3.to_bits())).cast(); let p_hi_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C1)); let p_lo_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C2) - p_hi_1); - t[2] += p_hi_1.to_bits().cast(); - t[1] += p_lo_1.to_bits().cast(); + t[2] += (p_hi_1.to_bits() - Simd::splat(C1.to_bits())).cast(); + t[1] += (p_lo_1.to_bits() - Simd::splat(C3.to_bits())).cast(); let p_hi_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C1)); let p_lo_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C2) - p_hi_2); - t[3] += p_hi_2.to_bits().cast(); - t[2] += p_lo_2.to_bits().cast(); + t[3] += (p_hi_2.to_bits() - Simd::splat(C1.to_bits())).cast(); + t[2] += (p_lo_2.to_bits() - Simd::splat(C3.to_bits())).cast(); let p_hi_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C1)); let p_lo_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C2) - p_hi_3); - t[4] += p_hi_3.to_bits().cast(); - t[3] += p_lo_3.to_bits().cast(); + t[4] += (p_hi_3.to_bits() - Simd::splat(C1.to_bits())).cast(); + t[3] += (p_lo_3.to_bits() - Simd::splat(C3.to_bits())).cast(); let p_hi_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C1)); let p_lo_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C2) - p_hi_4); - t[5] += p_hi_4.to_bits().cast(); - t[4] += p_lo_4.to_bits().cast(); + t[5] += (p_hi_4.to_bits() - Simd::splat(C1.to_bits())).cast(); + t[4] += (p_lo_4.to_bits() - Simd::splat(C3.to_bits())).cast(); t } @@ -169,9 +169,9 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { /// technically converts from a i64 representation to a u64 representation /// drops off the lowest limb which got zerood out, but it still contains /// carries as it is in redundant form -pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { +pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { // The lowest limb contains carries that still need to be applied. - let a = [red[1] + (red[0] >> 51), red[2], red[3], red[4], red[5]]; + let a = [red[1], red[2], red[3], red[4], red[5]]; let mut c = [Simd::splat(0); 5]; let tmp = a[0]; @@ -185,14 +185,8 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { let p = U51_P.map(Simd::splat); let b: [_; 5] = array::from_fn(|i| mask.select(p[i], zeros[i])); - let tmp: Simd = tmp + b[0].cast(); - c[0] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); - let mut borrow = tmp >> 51; - - for i in 1..c.len() { - let tmp: Simd = a[i] + b[i].cast() + borrow; - c[i] = tmp.bitand(Simd::splat(MASK51 as i64)).cast(); - borrow = tmp >> 51 + for i in 0..c.len() { + c[i] = a[i] + b[i]; } // Check that final result is even @@ -204,9 +198,9 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { #[inline(always)] pub fn addv_simd( - mut va: [Simd; N], - vb: [Simd; N], -) -> [Simd; N] { + mut va: [Simd; N], + vb: [Simd; N], +) -> [Simd; N] { for i in 0..va.len() { va[i] += vb[i]; } From c0fdd6afb89dd0ad74ce8e5b207ea68072c5c4d1 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 21 Jan 2026 17:41:55 +0800 Subject: [PATCH 21/37] b51: optimise carry handling --- .../block-multiplier/src/portable_simd_wasm.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index efd7546c..0a8e5591 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -174,19 +174,19 @@ fn redundant_carry(t: [Simd; N]) -> [Simd; N] { for (i, x) in t.into_iter().enumerate() { let tmp = x + borrow; res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); - borrow = x >> 51; + borrow = tmp >> 51; } debug_assert!(borrow == Simd::splat(0)); res } -fn redundant_carry_u64(t: [Simd; N]) -> [Simd; N] { +fn redundant_carry_u64_exess(t: [Simd; N]) -> [Simd; N] { let mut carry = Simd::splat(0); let mut res = [Simd::splat(0); N]; for (i, x) in t.into_iter().enumerate() { let tmp = x + carry; res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); - carry = x >> 51; + carry = tmp >> 51; } res[N - 1] = (carry << 51) | res[N - 1]; // debug_assert!(carry == Simd::splat(0)); @@ -244,7 +244,7 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - let s = redundant_carry_u64(s); + let s = redundant_carry_u64_exess(s); // The upper bits of s will not affect the lower 51 bits of the product so we // defer the and'ing. @@ -254,9 +254,9 @@ pub fn simd_mul( let mp = smult_noinit_simd(m, U51_P); let mp = redundant_carry(mp); - let addi = redundant_carry_u64(addv_simd(s, mp)); + let addi = redundant_carry_u64_exess(addv_simd(s, mp)); let reduced = reduce_ct_simd(addi); - let reduced = redundant_carry_u64(reduced); + let reduced = redundant_carry_u64_exess(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) From 805894c9bdba0565da07257e0833d60bc6762b2c Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 12:25:07 +0800 Subject: [PATCH 22/37] b51: further optimise redundant carry mp variable --- .../block-multiplier/src/portable_simd_wasm.rs | 17 +++++++++++++++-- .../block-multiplier/src/simd_utils_wasm.rs | 11 ++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 0a8e5591..d6b47485 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -180,6 +180,20 @@ fn redundant_carry(t: [Simd; N]) -> [Simd; N] { res } +fn redundant_carry_excess(t: [Simd; N]) -> [Simd; N] { + let mut borrow = Simd::splat(0); + let mut res = [Simd::splat(0); N]; + for (i, x) in t.into_iter().enumerate() { + let tmp = x + borrow; + res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); + borrow = tmp >> 51; + } + // Check whether borrow is not negative. + debug_assert!(borrow >= Simd::splat(0)); + res[N - 1] = (borrow << 51).cast() | res[N - 1]; + res +} + fn redundant_carry_u64_exess(t: [Simd; N]) -> [Simd; N] { let mut carry = Simd::splat(0); let mut res = [Simd::splat(0); N]; @@ -252,9 +266,8 @@ pub fn simd_mul( .cast() .bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U51_P); - let mp = redundant_carry(mp); - let addi = redundant_carry_u64_exess(addv_simd(s, mp)); + let addi = redundant_carry_excess(addv_simd(s, mp)); let reduced = reduce_ct_simd(addi); let reduced = redundant_carry_u64_exess(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index da0f97be..6cb60dfb 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -198,13 +198,14 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { #[inline(always)] pub fn addv_simd( - mut va: [Simd; N], - vb: [Simd; N], -) -> [Simd; N] { + va: [Simd; N], + vb: [Simd; N], +) -> [Simd; N] { + let mut vc = [Simd::splat(0); N]; for i in 0..va.len() { - va[i] += vb[i]; + vc[i] = va[i].cast() + vb[i]; } - va + vc } #[cfg(kani)] From d45f87ee13b861d5228d47e8ee7162c17b9033ab Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 12:25:33 +0800 Subject: [PATCH 23/37] b51: optimise redundant carry for s --- skyscraper/block-multiplier/src/portable_simd_wasm.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index d6b47485..0c7f68a7 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -258,8 +258,6 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - let s = redundant_carry_u64_exess(s); - // The upper bits of s will not affect the lower 51 bits of the product so we // defer the and'ing. let m = (s[0] * Simd::splat(U51_NP0)) From 55829ba8b9bef456e21222758ba9cb5d265abe7f Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 12:34:43 +0800 Subject: [PATCH 24/37] b51: optimise carry for addi --- skyscraper/block-multiplier/src/portable_simd_wasm.rs | 9 +++++++-- skyscraper/block-multiplier/src/simd_utils_wasm.rs | 7 ++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 0c7f68a7..36562546 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -265,9 +265,14 @@ pub fn simd_mul( .bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U51_P); - let addi = redundant_carry_excess(addv_simd(s, mp)); + let mut addi = addv_simd(s, mp); + // Move over carries before dropping last limb + addi[1] += addi[0] >> 51; + let addi = [addi[1], addi[2], addi[3], addi[4], addi[5]]; + + // 1 bit reduction to go from R^-255 to R^-256 let reduced = reduce_ct_simd(addi); - let reduced = redundant_carry_u64_exess(reduced); + let reduced = redundant_carry_excess(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 6cb60dfb..6fb7e945 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -169,10 +169,7 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { /// technically converts from a i64 representation to a u64 representation /// drops off the lowest limb which got zerood out, but it still contains /// carries as it is in redundant form -pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { - // The lowest limb contains carries that still need to be applied. - let a = [red[1], red[2], red[3], red[4], red[5]]; - +pub fn reduce_ct_simd(a: [Simd; 5]) -> [Simd; 5] { let mut c = [Simd::splat(0); 5]; let tmp = a[0]; @@ -182,7 +179,7 @@ pub fn reduce_ct_simd(red: [Simd; 6]) -> [Simd; 5] { // Select values based on the mask: if mask lane is true, add p, else add // zero let zeros = [Simd::splat(0); 5]; - let p = U51_P.map(Simd::splat); + let p = U51_P.map(|x| Simd::splat(x as i64)); let b: [_; 5] = array::from_fn(|i| mask.select(p[i], zeros[i])); for i in 0..c.len() { From 0fb170a2fdfa03507eb873f3e9c18e6b2126d029 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 12:46:05 +0800 Subject: [PATCH 25/37] b51: optimises carries on t and r --- .../src/portable_simd_wasm.rs | 64 +++++-------------- .../block-multiplier/src/simd_utils_wasm.rs | 2 +- 2 files changed, 18 insertions(+), 48 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 36562546..3ecc152e 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -168,42 +168,18 @@ fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd(t: [Simd; N]) -> [Simd; N] { let mut borrow = Simd::splat(0); let mut res = [Simd::splat(0); N]; - for (i, x) in t.into_iter().enumerate() { - let tmp = x + borrow; + for i in 0..t.len() - 1 { + let tmp = t[i] + borrow; res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); borrow = tmp >> 51; } - debug_assert!(borrow == Simd::splat(0)); - res -} - -fn redundant_carry_excess(t: [Simd; N]) -> [Simd; N] { - let mut borrow = Simd::splat(0); - let mut res = [Simd::splat(0); N]; - for (i, x) in t.into_iter().enumerate() { - let tmp = x + borrow; - res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); - borrow = tmp >> 51; - } - // Check whether borrow is not negative. - debug_assert!(borrow >= Simd::splat(0)); - res[N - 1] = (borrow << 51).cast() | res[N - 1]; - res -} - -fn redundant_carry_u64_exess(t: [Simd; N]) -> [Simd; N] { - let mut carry = Simd::splat(0); - let mut res = [Simd::splat(0); N]; - for (i, x) in t.into_iter().enumerate() { - let tmp = x + carry; - res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); - carry = tmp >> 51; - } - res[N - 1] = (carry << 51) | res[N - 1]; - // debug_assert!(carry == Simd::splat(0)); + // Last limb should not be truncated to 51 bits. As the input value can be + // bigger than 2^255 bits. In that sense the upper limb has no redundant carry. + res[N - 1] = (t[N - 1] + borrow).cast(); res } @@ -232,22 +208,17 @@ pub fn simd_mul( multimul(&mut t, v0_a, v0_b); // sign extend redundant carries - // t[1] += t[0] >> 51; - // t[2] += t[1] >> 51; - // t[3] += t[2] >> 51; - // t[4] += t[3] >> 51; - let t = redundant_carry(t); + t[1] += t[0] >> 51; + t[2] += t[1] >> 51; + t[3] += t[2] >> 51; + t[4] += t[3] >> 51; // lower 51 bits will have the right value as the carry part is either 0 or a // multiple of -2^51 -> which prevents carry bits to leak into the lower part. - let r0 = smult_noinit_simd(t[0], RHO_4); - let r0 = redundant_carry(r0); - let r1 = smult_noinit_simd(t[1], RHO_3); - let r1 = redundant_carry(r1); - let r2 = smult_noinit_simd(t[2], RHO_2); - let r2 = redundant_carry(r2); - let r3 = smult_noinit_simd(t[3], RHO_1); - let r3 = redundant_carry(r3); + let r0 = smult_noinit_simd(t[0].cast().bitand(Simd::splat(MASK51)), RHO_4); + let r1 = smult_noinit_simd(t[1].cast().bitand(Simd::splat(MASK51)), RHO_3); + let r2 = smult_noinit_simd(t[2].cast().bitand(Simd::splat(MASK51)), RHO_2); + let r3 = smult_noinit_simd(t[3].cast().bitand(Simd::splat(MASK51)), RHO_1); let s = [ r0[0] + r1[0] + r2[0] + r3[0] + t[4], @@ -260,9 +231,7 @@ pub fn simd_mul( // The upper bits of s will not affect the lower 51 bits of the product so we // defer the and'ing. - let m = (s[0] * Simd::splat(U51_NP0)) - .cast() - .bitand(Simd::splat(MASK51)); + let m = (s[0].cast() * Simd::splat(U51_NP0)).bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U51_P); let mut addi = addv_simd(s, mp); @@ -272,7 +241,8 @@ pub fn simd_mul( // 1 bit reduction to go from R^-255 to R^-256 let reduced = reduce_ct_simd(addi); - let reduced = redundant_carry_excess(reduced); + // Are the following two shifts fused? + let reduced = redundant_carry(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); (v[0], v[1]) diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 6fb7e945..95aa0872 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -195,7 +195,7 @@ pub fn reduce_ct_simd(a: [Simd; 5]) -> [Simd; 5] { #[inline(always)] pub fn addv_simd( - va: [Simd; N], + va: [Simd; N], vb: [Simd; N], ) -> [Simd; N] { let mut vc = [Simd::splat(0); N]; From 08a055b6cfdc5b72873f598d38df39aed7ba0dbf Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 12:54:55 +0800 Subject: [PATCH 26/37] b51: aggregrate anchor subtractions --- .../src/portable_simd_wasm.rs | 215 ++++++------------ .../block-multiplier/src/simd_utils_wasm.rs | 20 +- 2 files changed, 76 insertions(+), 159 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 3ecc152e..b09a56f8 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -6,6 +6,7 @@ use { transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_shr_1_simd, u255_to_u256_simd, u256_to_u255_simd, }, + subarray, }, core::{ ops::BitAnd, @@ -36,136 +37,136 @@ fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[0] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1] += p_hi.to_bits().cast(); + t[0] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[1]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 1 + 1] += p_hi.to_bits().cast(); + t[1 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 2 + 1] += p_hi.to_bits().cast(); + t[1 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 3 + 1] += p_hi.to_bits().cast(); + t[1 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[1 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[1 + 4 + 1] += p_hi.to_bits().cast(); + t[1 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[2]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 1 + 1] += p_hi.to_bits().cast(); + t[2 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 2 + 1] += p_hi.to_bits().cast(); + t[2 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 3 + 1] += p_hi.to_bits().cast(); + t[2 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[2 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[2 + 4 + 1] += p_hi.to_bits().cast(); + t[2 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[3]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 1 + 1] += p_hi.to_bits().cast(); + t[3 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 2 + 1] += p_hi.to_bits().cast(); + t[3 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 3 + 1] += p_hi.to_bits().cast(); + t[3 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[3 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[3 + 4 + 1] += p_hi.to_bits().cast(); + t[3 + 4] += p_lo.to_bits().cast(); let avi: Simd = i2f(v0_a[4]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4 + 1] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 1 + 1] += p_hi.to_bits().cast(); + t[4 + 1] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[2]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 2 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4 + 2] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 2 + 1] += p_hi.to_bits().cast(); + t[4 + 2] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[3]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 3 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4 + 3] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 3 + 1] += p_hi.to_bits().cast(); + t[4 + 3] += p_lo.to_bits().cast(); let bvj: Simd = i2f(v0_b[4]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 4 + 1] += (p_hi.to_bits() - Simd::splat(C1).to_bits()).cast(); - t[4 + 4] += (p_lo.to_bits() - Simd::splat(C3).to_bits()).cast(); + t[4 + 4 + 1] += p_hi.to_bits().cast(); + t[4 + 4] += p_lo.to_bits().cast(); } /// Deal with the redundant carries @@ -194,16 +195,16 @@ pub fn simd_mul( let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); let mut t: [Simd<_, 2>; 10] = [Simd::splat(0); 10]; - // t[0] = Simd::splat(make_initial(1, 0)); - // t[9] = Simd::splat(make_initial(0, 6)); - // t[1] = Simd::splat(make_initial(2, 1)); - // t[8] = Simd::splat(make_initial(6, 7)); - // t[2] = Simd::splat(make_initial(3, 2)); - // t[7] = Simd::splat(make_initial(7, 8)); - // t[3] = Simd::splat(make_initial(4, 3)); - // t[6] = Simd::splat(make_initial(8, 9)); - // t[4] = Simd::splat(make_initial(10, 4)); - // t[5] = Simd::splat(make_initial(9, 10)); + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); multimul(&mut t, v0_a, v0_b); @@ -239,7 +240,8 @@ pub fn simd_mul( addi[1] += addi[0] >> 51; let addi = [addi[1], addi[2], addi[3], addi[4], addi[5]]; - // 1 bit reduction to go from R^-255 to R^-256 + // 1 bit reduction to go from R^-255 to R^-256. reduce_ct does the preparation + // and the final shift is done as part of the conversion back to u256 let reduced = reduce_ct_simd(addi); // Are the following two shifts fused? let reduced = redundant_carry(reduced); @@ -253,7 +255,7 @@ pub fn simd_mul( mod tests { use { super::*, - crate::test_utils::{ark_ff_reference, safe_bn254_montgomery_input}, + crate::test_utils::ark_ff_reference, ark_bn254::Fr, ark_ff::{BigInt, PrimeField}, proptest::{ @@ -265,8 +267,8 @@ mod tests { #[test] fn test_simd_mul() { proptest!(|( - mut a in limbs5_51(), - mut b in limbs5_51(), + a in limbs5_51(), + b in limbs5_51(), // c in limbs5_51(), )| { let a: [Simd;_] = a.map(Simd::splat); @@ -294,89 +296,4 @@ mod tests { fn limbs5_51() -> impl Strategy { prop::array::uniform5(limb51()) } - - fn school_mul(ax: [u64; 5], bx: [u64; 5]) -> [u64; 10] { - let mut t = [0; 10]; - for (ai, a) in ax.into_iter().enumerate() { - for (bi, b) in bx.into_iter().enumerate() { - let (lo, hi) = a.widening_mul(b); - let hi = hi << 13 | lo >> 51; - let lo = lo & MASK51; - t[ai + bi] += lo; - t[ai + bi + 1] += hi; - } - } - - let mut carry = 0; - let mut res = [0; 10]; - - for (i, r) in t.into_iter().enumerate() { - let tmp = r + carry; - res[i] = tmp & MASK51; - carry = tmp >> 51; - } - res - } - - fn init_t() -> [i64; 10] { - let mut count: [(u64, u64); _] = [(0, 0); 10]; - for ai in 0..5 { - for bi in 0..5 { - count[ai + bi].0 += 1; - count[ai + bi + 1].1 += 1; - } - } - - let res = count.map(|(lo, hi)| make_initial(lo, hi)); - - res - } - - fn redundant_carry(t: [i64; 10]) -> [u64; 10] { - let mut borrow: i64 = 0; - let mut res = [0; 10]; - for (i, x) in t.into_iter().enumerate() { - let tmp = x + borrow; - res[i] = tmp as u64 & MASK51; - borrow = tmp >> 51; - } - debug_assert!(borrow == 0); - res - } - - #[test] - fn redundant_form_multi_mul() { - proptest!(|(a in limbs5_51(), b in limbs5_51())|{ - let v0_a = a.map(Simd::splat); - let v0_b = b.map(Simd::splat); - let mut t: [Simd<_,_>;_] = [Simd::splat(0);10]; - // let mut t = init_t().map(Simd::splat); - multimul(&mut t, v0_a, v0_b); - let school = school_mul(a,b); - let fp = redundant_carry(t.map(|x| x[0])); - - prop_assert_eq!(school, fp) - - }) - } - - #[test] - fn single_mul_test() { - proptest!(|(a in limb51(), b in limb51())|{ - let (lo,hi) = single_mul(a, b); - let hi = hi.wrapping_add(-(C1.to_bits() as i64)); - let lo = lo.wrapping_add(-(C3.to_bits() as i64)); - let lo_carry = lo >> 51; - let hi = (hi + lo_carry) as u64; - let lo = lo as u64 & 2_u64.pow(51) - 1; - let fp = (lo,hi); - - let (lo, hi) = a.widening_mul(b); - let hi = hi << 13 | lo >> 51; - let lo = lo & 2_u64.pow(51) - 1; - let school = (lo, hi); - - prop_assert_eq!(school, fp) - }) - } } diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_utils_wasm.rs index 95aa0872..b15674e8 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_utils_wasm.rs @@ -136,28 +136,28 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { let p_hi_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C1)); let p_lo_0 = fma(s, Simd::splat(v[0] as f64), Simd::splat(C2) - p_hi_0); - t[1] += (p_hi_0.to_bits() - Simd::splat(C1.to_bits())).cast(); - t[0] += (p_lo_0.to_bits() - Simd::splat(C3.to_bits())).cast(); + t[1] += p_hi_0.to_bits().cast(); + t[0] += p_lo_0.to_bits().cast(); let p_hi_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C1)); let p_lo_1 = fma(s, Simd::splat(v[1] as f64), Simd::splat(C2) - p_hi_1); - t[2] += (p_hi_1.to_bits() - Simd::splat(C1.to_bits())).cast(); - t[1] += (p_lo_1.to_bits() - Simd::splat(C3.to_bits())).cast(); + t[2] += p_hi_1.to_bits().cast(); + t[1] += p_lo_1.to_bits().cast(); let p_hi_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C1)); let p_lo_2 = fma(s, Simd::splat(v[2] as f64), Simd::splat(C2) - p_hi_2); - t[3] += (p_hi_2.to_bits() - Simd::splat(C1.to_bits())).cast(); - t[2] += (p_lo_2.to_bits() - Simd::splat(C3.to_bits())).cast(); + t[3] += p_hi_2.to_bits().cast(); + t[2] += p_lo_2.to_bits().cast(); let p_hi_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C1)); let p_lo_3 = fma(s, Simd::splat(v[3] as f64), Simd::splat(C2) - p_hi_3); - t[4] += (p_hi_3.to_bits() - Simd::splat(C1.to_bits())).cast(); - t[3] += (p_lo_3.to_bits() - Simd::splat(C3.to_bits())).cast(); + t[4] += p_hi_3.to_bits().cast(); + t[3] += p_lo_3.to_bits().cast(); let p_hi_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C1)); let p_lo_4 = fma(s, Simd::splat(v[4] as f64), Simd::splat(C2) - p_hi_4); - t[5] += (p_hi_4.to_bits() - Simd::splat(C1.to_bits())).cast(); - t[4] += (p_lo_4.to_bits() - Simd::splat(C3.to_bits())).cast(); + t[5] += p_hi_4.to_bits().cast(); + t[4] += p_lo_4.to_bits().cast(); t } From d97fe8769d46a22ee4ae03e3330d661d6f200400 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Thu, 22 Jan 2026 15:01:29 +0800 Subject: [PATCH 27/37] b51: sqr reduce number of multiplications --- skyscraper/block-multiplier/benches/bench.rs | 14 +- .../src/portable_simd_wasm.rs | 188 ++++++++++++++++-- 2 files changed, 187 insertions(+), 15 deletions(-) diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index 338a9446..859ae4dc 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -32,7 +32,7 @@ mod mul { } #[divan::bench] - fn simd_mul(bencher: Bencher) { + fn simd_mul_51b(bencher: Bencher) { bencher //.counter(ItemsCount::new(2usize)) .with_inputs(|| rng().random()) @@ -50,7 +50,7 @@ mod mul { }; #[divan::bench] - fn simd_mul(bencher: Bencher) { + fn simd_mul_52b(bencher: Bencher) { bencher //.counter(ItemsCount::new(2usize)) .with_inputs(|| rng().random()) @@ -119,7 +119,7 @@ mod mul { // #[divan::bench_group] mod sqr { - use {super::*, ark_ff::Field}; + use {super::*, ark_ff::Field, block_multiplier::portable_simd_wasm}; #[divan::bench] fn scalar_sqr(bencher: Bencher) { @@ -129,6 +129,14 @@ mod sqr { .bench_local_values(block_multiplier::scalar_sqr); } + #[divan::bench] + fn simd_sqr_b51(bencher: Bencher) { + bencher + //.counter(ItemsCount::new(1usize)) + .with_inputs(|| rng().random()) + .bench_local_values(|(a, b)| portable_simd_wasm::simd_sqr(a, b)); + } + #[divan::bench] fn ark_ff(bencher: Bencher) { use {ark_bn254::Fr, ark_ff::BigInt}; diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index b09a56f8..6f5d29c7 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -4,27 +4,174 @@ use { simd_utils_wasm::{ addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_shr_1_simd, - u255_to_u256_simd, u256_to_u255_simd, + u256_to_u255_simd, }, - subarray, }, core::{ ops::BitAnd, simd::{num::SimdFloat, Simd}, }, - std::simd::{ - num::{SimdInt, SimdUint}, - LaneCount, SupportedLaneCount, - }, + std::simd::num::{SimdInt, SimdUint}, }; -#[inline(always)] -pub fn single_mul(a: u64, b: u64) -> (i64, i64) { - let avi: Simd = i2f(Simd::splat(a)); - let bvj: Simd = i2f(Simd::splat(b)); +#[inline] +pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); + + let mut t: [Simd; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + + let avi: Simd = i2f(v0_a[0]); + let bvj: Simd = i2f(v0_a[0]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[1] += p_hi.to_bits().cast(); + t[0] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[1]); let p_hi = fma(avi, bvj, Simd::splat(C1)); let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - (p_lo.to_bits().cast()[0], p_hi.to_bits().cast()[0]) + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); + t[1 + 1] += p_hi.to_bits().cast(); + t[1] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); + t[2 + 1] += p_hi.to_bits().cast(); + t[2] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); + t[3 + 1] += p_hi.to_bits().cast(); + t[3] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); + t[4 + 1] += p_hi.to_bits().cast(); + t[4] += p_lo.to_bits().cast(); + + let avi: Simd = i2f(v0_a[1]); + let bvj: Simd = i2f(v0_a[1]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[1 + 1 + 1] += p_hi.to_bits().cast(); + t[1 + 1] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[1 + 2 + 1] += p_hi.to_bits().cast(); + t[1 + 2] += p_lo.to_bits().cast(); + t[1 + 2 + 1] += p_hi.to_bits().cast(); + t[1 + 2] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[1 + 3 + 1] += p_hi.to_bits().cast(); + t[1 + 3] += p_lo.to_bits().cast(); + t[1 + 3 + 1] += p_hi.to_bits().cast(); + t[1 + 3] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[1 + 4 + 1] += p_hi.to_bits().cast(); + t[1 + 4] += p_lo.to_bits().cast(); + t[1 + 4 + 1] += p_hi.to_bits().cast(); + t[1 + 4] += p_lo.to_bits().cast(); + + let avi: Simd = i2f(v0_a[2]); + let bvj: Simd = i2f(v0_a[2]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[2 + 2 + 1] += p_hi.to_bits().cast(); + t[2 + 2] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[2 + 3 + 1] += p_hi.to_bits().cast(); + t[2 + 3] += p_lo.to_bits().cast(); + t[2 + 3 + 1] += p_hi.to_bits().cast(); + t[2 + 3] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[2 + 4 + 1] += p_hi.to_bits().cast(); + t[2 + 4] += p_lo.to_bits().cast(); + t[2 + 4 + 1] += p_hi.to_bits().cast(); + t[2 + 4] += p_lo.to_bits().cast(); + + let avi: Simd = i2f(v0_a[3]); + let bvj: Simd = i2f(v0_a[3]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[3 + 3 + 1] += p_hi.to_bits().cast(); + t[3 + 3] += p_lo.to_bits().cast(); + let bvj: Simd = i2f(v0_a[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[3 + 4 + 1] += p_hi.to_bits().cast(); + t[3 + 4] += p_lo.to_bits().cast(); + t[3 + 4 + 1] += p_hi.to_bits().cast(); + t[3 + 4] += p_lo.to_bits().cast(); + + let avi: Simd = i2f(v0_a[4]); + let bvj: Simd = i2f(v0_a[4]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[4 + 4 + 1] += p_hi.to_bits().cast(); + t[4 + 4] += p_lo.to_bits().cast(); + + t[1] += t[0] >> 51; + t[2] += t[1] >> 51; + t[3] += t[2] >> 51; + t[4] += t[3] >> 51; + + let r0 = smult_noinit_simd(t[0].cast().bitand(Simd::splat(MASK51)), RHO_4); + let r1 = smult_noinit_simd(t[1].cast().bitand(Simd::splat(MASK51)), RHO_3); + let r2 = smult_noinit_simd(t[2].cast().bitand(Simd::splat(MASK51)), RHO_2); + let r3 = smult_noinit_simd(t[3].cast().bitand(Simd::splat(MASK51)), RHO_1); + + let s = [ + r0[0] + r1[0] + r2[0] + r3[0] + t[4], + r0[1] + r1[1] + r2[1] + r3[1] + t[5], + r0[2] + r1[2] + r2[2] + r3[2] + t[6], + r0[3] + r1[3] + r2[3] + r3[3] + t[7], + r0[4] + r1[4] + r2[4] + r3[4] + t[8], + r0[5] + r1[5] + r2[5] + r3[5] + t[9], + ]; + + // The upper bits of s will not affect the lower 51 bits of the product so we + // defer the and'ing. + let m = (s[0].cast() * Simd::splat(U51_NP0)).bitand(Simd::splat(MASK51)); + let mp = smult_noinit_simd(m, U51_P); + + let mut addi = addv_simd(s, mp); + // Move over carries before dropping last limb + addi[1] += addi[0] >> 51; + let addi = [addi[1], addi[2], addi[3], addi[4], addi[5]]; + + // 1 bit reduction to go from R^-255 to R^-256. reduce_ct does the preparation + // and the final shift is done as part of the conversion back to u256 + let reduced = reduce_ct_simd(addi); + // Are the following two shifts fused? + let reduced = redundant_carry(reduced); + let u256_result = u255_to_u256_shr_1_simd(reduced); + let v = transpose_simd_to_u256(u256_result); + (v[0], v[1]) } #[inline(always)] @@ -255,7 +402,7 @@ pub fn simd_mul( mod tests { use { super::*, - crate::test_utils::ark_ff_reference, + crate::{simd_utils_wasm::u255_to_u256_simd, test_utils::ark_ff_reference}, ark_bn254::Fr, ark_ff::{BigInt, PrimeField}, proptest::{ @@ -284,6 +431,23 @@ mod tests { }) } + #[test] + fn test_simd_sqr() { + proptest!(|( + a in limbs5_51(), + b in limbs5_51(), + // c in limbs5_51(), + )| { + let a: [Simd;_] = a.map(Simd::splat); + let b: [Simd;_] = b.map(Simd::splat); + let a = u255_to_u256_simd(a).map(|x|x[0]); + let b = u255_to_u256_simd(b).map(|x|x[0]); + let (a2, _b2) = simd_mul(a, a, b, b); + let (a2s, _b2s) = simd_sqr(a, b); + prop_assert_eq!(a2, a2s); + }) + } + fn limb51() -> impl Strategy { // Either of these is fine: // 1) Range From 7a53b63da911651ad76e6264eb200cf32327a368 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 12:25:40 +0800 Subject: [PATCH 28/37] b51: sqr reduce additions --- .../src/portable_simd_wasm.rs | 143 ++++-------------- 1 file changed, 32 insertions(+), 111 deletions(-) diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_wasm.rs index 6f5d29c7..baa78202 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_wasm.rs @@ -19,121 +19,42 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); let mut t: [Simd; 10] = [Simd::splat(0); 10]; - t[0] = Simd::splat(make_initial(1, 0)); - t[9] = Simd::splat(make_initial(0, 6)); - t[1] = Simd::splat(make_initial(2, 1)); - t[8] = Simd::splat(make_initial(6, 7)); - t[2] = Simd::splat(make_initial(3, 2)); - t[7] = Simd::splat(make_initial(7, 8)); - t[3] = Simd::splat(make_initial(4, 3)); - t[6] = Simd::splat(make_initial(8, 9)); - t[4] = Simd::splat(make_initial(10, 4)); - t[5] = Simd::splat(make_initial(9, 10)); - - let avi: Simd = i2f(v0_a[0]); - let bvj: Simd = i2f(v0_a[0]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1] += p_hi.to_bits().cast(); - t[0] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[1]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1] += p_hi.to_bits().cast(); - t[1] += p_lo.to_bits().cast(); - t[1 + 1] += p_hi.to_bits().cast(); - t[1] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[2]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 1] += p_hi.to_bits().cast(); - t[2] += p_lo.to_bits().cast(); - t[2 + 1] += p_hi.to_bits().cast(); - t[2] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[3]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 1] += p_hi.to_bits().cast(); - t[3] += p_lo.to_bits().cast(); - t[3 + 1] += p_hi.to_bits().cast(); - t[3] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[4]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 1] += p_hi.to_bits().cast(); - t[4] += p_lo.to_bits().cast(); - t[4 + 1] += p_hi.to_bits().cast(); - t[4] += p_lo.to_bits().cast(); - let avi: Simd = i2f(v0_a[1]); - let bvj: Simd = i2f(v0_a[1]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 1 + 1] += p_hi.to_bits().cast(); - t[1 + 1] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[2]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 2 + 1] += p_hi.to_bits().cast(); - t[1 + 2] += p_lo.to_bits().cast(); - t[1 + 2 + 1] += p_hi.to_bits().cast(); - t[1 + 2] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[3]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 3 + 1] += p_hi.to_bits().cast(); - t[1 + 3] += p_lo.to_bits().cast(); - t[1 + 3 + 1] += p_hi.to_bits().cast(); - t[1 + 3] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[4]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[1 + 4 + 1] += p_hi.to_bits().cast(); - t[1 + 4] += p_lo.to_bits().cast(); - t[1 + 4 + 1] += p_hi.to_bits().cast(); - t[1 + 4] += p_lo.to_bits().cast(); + for i in 0..5 { + let avi: Simd = i2f(v0_a[i]); + for j in (i + 1)..5 { + let bvj: Simd = i2f(v0_a[j]); + let p_hi = fma(avi, bvj, Simd::splat(C1)); + let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); + t[i + j + 1] += p_hi.to_bits().cast(); + t[i + j] += p_lo.to_bits().cast(); + } + } - let avi: Simd = i2f(v0_a[2]); - let bvj: Simd = i2f(v0_a[2]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 2 + 1] += p_hi.to_bits().cast(); - t[2 + 2] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[3]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 3 + 1] += p_hi.to_bits().cast(); - t[2 + 3] += p_lo.to_bits().cast(); - t[2 + 3 + 1] += p_hi.to_bits().cast(); - t[2 + 3] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[4]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[2 + 4 + 1] += p_hi.to_bits().cast(); - t[2 + 4] += p_lo.to_bits().cast(); - t[2 + 4 + 1] += p_hi.to_bits().cast(); - t[2 + 4] += p_lo.to_bits().cast(); + // On most instruction sets SIMD shift left is more expensive than SIMD + // addition. While for scalar they tend to cost the same. + for i in 1..=8 { + t[i] += t[i]; + } - let avi: Simd = i2f(v0_a[3]); - let bvj: Simd = i2f(v0_a[3]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 3 + 1] += p_hi.to_bits().cast(); - t[3 + 3] += p_lo.to_bits().cast(); - let bvj: Simd = i2f(v0_a[4]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[3 + 4 + 1] += p_hi.to_bits().cast(); - t[3 + 4] += p_lo.to_bits().cast(); - t[3 + 4 + 1] += p_hi.to_bits().cast(); - t[3 + 4] += p_lo.to_bits().cast(); + for i in 0..5 { + let avi: Simd = i2f(v0_a[i]); + let p_hi = fma(avi, avi, Simd::splat(C1)); + let p_lo = fma(avi, avi, Simd::splat(C2) - p_hi); + t[i + i + 1] += p_hi.to_bits().cast(); + t[i + i] += p_lo.to_bits().cast(); + } - let avi: Simd = i2f(v0_a[4]); - let bvj: Simd = i2f(v0_a[4]); - let p_hi = fma(avi, bvj, Simd::splat(C1)); - let p_lo = fma(avi, bvj, Simd::splat(C2) - p_hi); - t[4 + 4 + 1] += p_hi.to_bits().cast(); - t[4 + 4] += p_lo.to_bits().cast(); + t[0] += Simd::splat(make_initial(1, 0)); + t[9] += Simd::splat(make_initial(0, 6)); + t[1] += Simd::splat(make_initial(2, 1)); + t[8] += Simd::splat(make_initial(6, 7)); + t[2] += Simd::splat(make_initial(3, 2)); + t[7] += Simd::splat(make_initial(7, 8)); + t[3] += Simd::splat(make_initial(4, 3)); + t[6] += Simd::splat(make_initial(8, 9)); + t[4] += Simd::splat(make_initial(10, 4)); + t[5] += Simd::splat(make_initial(9, 10)); t[1] += t[0] >> 51; t[2] += t[1] >> 51; From 613072483f819e2364677754b5ed8d75b03b8f77 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 12:37:13 +0800 Subject: [PATCH 29/37] kani: silence unexpected_cfg --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 9c51196c..0d130371 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,9 @@ license = "MIT" homepage = "https://github.com/worldfnd/ProveKit" repository = "https://github.com/worldfnd/ProveKit" +[workspace.lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(kani)'] } + [workspace.lints.clippy] cargo = "warn" perf = "warn" From c1161fffaecdf43558c41b729e6227a7b3fe0051 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 13:54:03 +0800 Subject: [PATCH 30/37] block multiplier: reorganizing --- skyscraper/block-multiplier/benches/bench.rs | 29 ++++--- skyscraper/block-multiplier/src/block_simd.rs | 3 +- skyscraper/block-multiplier/src/constants.rs | 84 ------------------- .../{constants_wasm.rs => constants_rne.rs} | 31 +------ .../block-multiplier/src/constants_rtz.rs | 71 ++++++++++++++++ skyscraper/block-multiplier/src/lib.rs | 21 +++-- ...able_simd_wasm.rs => portable_simd_rne.rs} | 6 +- ...{portable_simd.rs => portable_simd_rtz.rs} | 16 +++- .../{simd_utils_wasm.rs => simd_rne_utils.rs} | 6 +- .../src/{simd_utils.rs => simd_rtz_utils.rs} | 2 +- 10 files changed, 127 insertions(+), 142 deletions(-) rename skyscraper/block-multiplier/src/{constants_wasm.rs => constants_rne.rs} (54%) create mode 100644 skyscraper/block-multiplier/src/constants_rtz.rs rename skyscraper/block-multiplier/src/{portable_simd_wasm.rs => portable_simd_rne.rs} (99%) rename skyscraper/block-multiplier/src/{portable_simd.rs => portable_simd_rtz.rs} (98%) rename skyscraper/block-multiplier/src/{simd_utils_wasm.rs => simd_rne_utils.rs} (96%) rename skyscraper/block-multiplier/src/{simd_utils.rs => simd_rtz_utils.rs} (98%) diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index 859ae4dc..0a8d3173 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -37,7 +37,7 @@ mod mul { //.counter(ItemsCount::new(2usize)) .with_inputs(|| rng().random()) .bench_local_values(|(a, b, c, d)| { - block_multiplier::portable_simd_wasm::simd_mul(a, b, c, d) + block_multiplier::portable_simd_rne::simd_mul(a, b, c, d) }); } @@ -51,10 +51,14 @@ mod mul { #[divan::bench] fn simd_mul_52b(bencher: Bencher) { - bencher - //.counter(ItemsCount::new(2usize)) - .with_inputs(|| rng().random()) - .bench_local_values(|(a, b, c, d)| block_multiplier::simd_mul(a, b, c, d)); + let bencher = bencher.with_inputs(|| rng().random()); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher.bench_local_values(|(a, b, c, d)| { + block_multiplier::simd_mul(mode_guard, a, b, c, d) + }); + }); + } } #[divan::bench] @@ -119,7 +123,7 @@ mod mul { // #[divan::bench_group] mod sqr { - use {super::*, ark_ff::Field, block_multiplier::portable_simd_wasm}; + use {super::*, ark_ff::Field, block_multiplier::portable_simd_rne}; #[divan::bench] fn scalar_sqr(bencher: Bencher) { @@ -134,7 +138,7 @@ mod sqr { bencher //.counter(ItemsCount::new(1usize)) .with_inputs(|| rng().random()) - .bench_local_values(|(a, b)| portable_simd_wasm::simd_sqr(a, b)); + .bench_local_values(|(a, b)| portable_simd_rne::simd_sqr(a, b)); } #[divan::bench] @@ -226,10 +230,13 @@ mod sqr { #[divan::bench] fn simd_sqr(bencher: Bencher) { - bencher - //.counter(ItemsCount::new(2usize)) - .with_inputs(|| rng().random()) - .bench_local_values(|(a, b)| block_multiplier::simd_sqr(a, b)); + let bencher = bencher.with_inputs(|| rng().random()); + unsafe { + with_rounding_mode((), |mode_guard, _| { + bencher + .bench_local_values(|(a, b)| block_multiplier::simd_sqr(mode_guard, a, b)); + }); + } } #[divan::bench] diff --git a/skyscraper/block-multiplier/src/block_simd.rs b/skyscraper/block-multiplier/src/block_simd.rs index e770f557..2364cc11 100644 --- a/skyscraper/block-multiplier/src/block_simd.rs +++ b/skyscraper/block-multiplier/src/block_simd.rs @@ -1,7 +1,8 @@ use { crate::{ constants::*, - simd_utils::{ + constants_rtz::*, + simd_rtz_utils::{ addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, }, diff --git a/skyscraper/block-multiplier/src/constants.rs b/skyscraper/block-multiplier/src/constants.rs index f9b8d82b..b4997113 100644 --- a/skyscraper/block-multiplier/src/constants.rs +++ b/skyscraper/block-multiplier/src/constants.rs @@ -38,42 +38,6 @@ pub const U64_R_INV: [u64; 4] = [ 0x15ebf95182c5551c, ]; -pub const U52_NP0: u64 = 0x1f593efffffff; -pub const U52_R2: [u64; 5] = [ - 0x0b852d16da6f5, - 0xc621620cddce3, - 0xaf1b95343ffb6, - 0xc3c15e103e7c2, - 0x00281528fa122, -]; - -pub const U52_P: [u64; 5] = [ - 0x1f593f0000001, - 0x4879b9709143e, - 0x181585d2833e8, - 0xa029b85045b68, - 0x030644e72e131, -]; - -pub const U52_2P: [u64; 5] = [ - 0x3eb27e0000002, - 0x90f372e12287c, - 0x302b0ba5067d0, - 0x405370a08b6d0, - 0x060c89ce5c263, -]; - -pub const F52_P: [f64; 5] = [ - 0x1f593f0000001_u64 as f64, - 0x4879b9709143e_u64 as f64, - 0x181585d2833e8_u64 as f64, - 0xa029b85045b68_u64 as f64, - 0x030644e72e131_u64 as f64, -]; - -pub const MASK52: u64 = 2_u64.pow(52) - 1; -pub const MASK48: u64 = 2_u64.pow(48) - 1; - pub const U64_I1: [u64; 4] = [ 0x2d3e8053e396ee4d, 0xca478dbeab3c92cd, @@ -95,54 +59,6 @@ pub const U64_I3: [u64; 4] = [ ]; pub const U64_MU0: u64 = 0xc2e1f593efffffff; -// -- [FP SIMD CONSTANTS] -// -------------------------------------------------------------------------- -pub const RHO_1: [u64; 5] = [ - 0x82e644ee4c3d2, - 0xf93893c98b1de, - 0xd46fe04d0a4c7, - 0x8f0aad55e2a1f, - 0x005ed0447de83, -]; - -pub const RHO_2: [u64; 5] = [ - 0x74eccce9a797a, - 0x16ddcc30bd8a4, - 0x49ecd3539499e, - 0xb23a6fcc592b8, - 0x00e3bd49f6ee5, -]; - -pub const RHO_3: [u64; 5] = [ - 0x0e8c656567d77, - 0x430d05713ae61, - 0xea3ba6b167128, - 0xa7dae55c5a296, - 0x01b4afd513572, -]; - -pub const RHO_4: [u64; 5] = [ - 0x22e2400e2f27d, - 0x323b46ea19686, - 0xe6c43f0df672d, - 0x7824014c39e8b, - 0x00c6b48afe1b8, -]; - -pub const C1: f64 = pow_2(104); // 2.0^104 -pub const C2: f64 = pow_2(104) + pow_2(52); // 2.0^104 + 2.0^52 - // const C3: f64 = pow_2(52); // 2.0^52 - // ------------------------------------------------------------------------------------------------- -pub const C1F51: f64 = pow_2(103); -pub const C2F51: f64 = pow_2(103) + pow_2(52) + pow_2(51); - -const fn pow_2(n: u32) -> f64 { - // Unfortunately we can't use f64::powi in const fn yet - // This is a workaround that creates the bit pattern directly - let exp = ((n as u64 + 1023) & 0x7ff) << 52; - f64::from_bits(exp) -} - // BOUNDS /// Upper bound of 2**256-2p pub const OUTPUT_MAX: [u64; 4] = [ diff --git a/skyscraper/block-multiplier/src/constants_wasm.rs b/skyscraper/block-multiplier/src/constants_rne.rs similarity index 54% rename from skyscraper/block-multiplier/src/constants_wasm.rs rename to skyscraper/block-multiplier/src/constants_rne.rs index d9677662..47ade0b3 100644 --- a/skyscraper/block-multiplier/src/constants_wasm.rs +++ b/skyscraper/block-multiplier/src/constants_rne.rs @@ -1,4 +1,5 @@ -// Double check if this is still correct +use crate::pow_2; + pub const U51_NP0: u64 = 0x1f593efffffff; pub const U51_P: [u64; 5] = [ @@ -9,19 +10,8 @@ pub const U51_P: [u64; 5] = [ 0x30644e72e131a, ]; -pub const F52_P: [f64; 5] = [ - 0x1f593f0000001_u64 as f64, - 0x4879b9709143e_u64 as f64, - 0x181585d2833e8_u64 as f64, - 0xa029b85045b68_u64 as f64, - 0x030644e72e131_u64 as f64, -]; - pub const MASK51: u64 = 2_u64.pow(51) - 1; -// -- [FP SIMD CONSTANTS] -// -------------------------------------------------------------------------- - pub const RHO_1: [u64; 5] = [ 0x05cc89dc987a4, 0x64e24f262c77a, @@ -57,20 +47,3 @@ pub const RHO_4: [u64; 5] = [ pub const C1: f64 = pow_2(103); pub const C2: f64 = pow_2(103) + pow_2(52) + pow_2(51); pub const C3: f64 = pow_2(52) + pow_2(51); - -const fn pow_2(n: u32) -> f64 { - assert!(n <= 1023); - // Unfortunately we can't use f64::powi in const fn yet - // This is a workaround that creates the bit pattern directly - let exp = (n as u64 + 1023) << 52; - f64::from_bits(exp) -} - -// BOUNDS -/// Upper bound of 2**256-2p -pub const OUTPUT_MAX: [u64; 4] = [ - 0x783c14d81ffffffe, - 0xaf982f6f0c8d1edd, - 0x8f5f7492fcfd4f45, - 0x9f37631a3d9cbfac, -]; diff --git a/skyscraper/block-multiplier/src/constants_rtz.rs b/skyscraper/block-multiplier/src/constants_rtz.rs new file mode 100644 index 00000000..2d8cbe29 --- /dev/null +++ b/skyscraper/block-multiplier/src/constants_rtz.rs @@ -0,0 +1,71 @@ +use crate::pow_2; + +pub const U52_NP0: u64 = 0x1f593efffffff; +pub const U52_R2: [u64; 5] = [ + 0x0b852d16da6f5, + 0xc621620cddce3, + 0xaf1b95343ffb6, + 0xc3c15e103e7c2, + 0x00281528fa122, +]; + +pub const U52_P: [u64; 5] = [ + 0x1f593f0000001, + 0x4879b9709143e, + 0x181585d2833e8, + 0xa029b85045b68, + 0x030644e72e131, +]; + +pub const U52_2P: [u64; 5] = [ + 0x3eb27e0000002, + 0x90f372e12287c, + 0x302b0ba5067d0, + 0x405370a08b6d0, + 0x060c89ce5c263, +]; + +pub const F52_P: [f64; 5] = [ + 0x1f593f0000001_u64 as f64, + 0x4879b9709143e_u64 as f64, + 0x181585d2833e8_u64 as f64, + 0xa029b85045b68_u64 as f64, + 0x030644e72e131_u64 as f64, +]; + +pub const MASK52: u64 = 2_u64.pow(52) - 1; + +pub const RHO_1: [u64; 5] = [ + 0x82e644ee4c3d2, + 0xf93893c98b1de, + 0xd46fe04d0a4c7, + 0x8f0aad55e2a1f, + 0x005ed0447de83, +]; + +pub const RHO_2: [u64; 5] = [ + 0x74eccce9a797a, + 0x16ddcc30bd8a4, + 0x49ecd3539499e, + 0xb23a6fcc592b8, + 0x00e3bd49f6ee5, +]; + +pub const RHO_3: [u64; 5] = [ + 0x0e8c656567d77, + 0x430d05713ae61, + 0xea3ba6b167128, + 0xa7dae55c5a296, + 0x01b4afd513572, +]; + +pub const RHO_4: [u64; 5] = [ + 0x22e2400e2f27d, + 0x323b46ea19686, + 0xe6c43f0df672d, + 0x7824014c39e8b, + 0x00c6b48afe1b8, +]; + +pub const C1: f64 = pow_2(104); // 2.0^104 +pub const C2: f64 = pow_2(104) + pow_2(52); // 2.0^104 + 2.0^52 diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index b1a19da3..0e858619 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -11,16 +11,17 @@ mod aarch64; #[cfg(target_arch = "aarch64")] mod block_simd; #[cfg(target_arch = "aarch64")] -mod portable_simd; +mod portable_simd_rtz; #[cfg(target_arch = "aarch64")] -mod simd_utils; +mod simd_rtz_utils; // pub mod block_simd_wasm; pub mod constants; -pub mod constants_wasm; -pub mod portable_simd_wasm; +pub mod constants_rne; +pub mod constants_rtz; +pub mod portable_simd_rne; mod scalar; -pub mod simd_utils_wasm; +pub mod simd_rne_utils; #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; mod utils; @@ -34,5 +35,13 @@ pub use crate::{ montgomery_square_log_interleaved_4, }, block_simd::{block_mul, block_sqr}, - portable_simd::{simd_mul, simd_sqr}, + portable_simd_rtz::{simd_mul, simd_sqr}, }; + +const fn pow_2(n: u32) -> f64 { + assert!(n <= 1023); + // Unfortunately we can't use f64::powi in const fn yet + // This is a workaround that creates the bit pattern directly + let exp = (n as u64 + 1023) << 52; + f64::from_bits(exp) +} diff --git a/skyscraper/block-multiplier/src/portable_simd_wasm.rs b/skyscraper/block-multiplier/src/portable_simd_rne.rs similarity index 99% rename from skyscraper/block-multiplier/src/portable_simd_wasm.rs rename to skyscraper/block-multiplier/src/portable_simd_rne.rs index baa78202..2e804e66 100644 --- a/skyscraper/block-multiplier/src/portable_simd_wasm.rs +++ b/skyscraper/block-multiplier/src/portable_simd_rne.rs @@ -1,7 +1,7 @@ use { crate::{ - constants_wasm::*, - simd_utils_wasm::{ + constants_rne::*, + simd_rne_utils::{ addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_shr_1_simd, u256_to_u255_simd, @@ -323,7 +323,7 @@ pub fn simd_mul( mod tests { use { super::*, - crate::{simd_utils_wasm::u255_to_u256_simd, test_utils::ark_ff_reference}, + crate::{simd_rne_utils::u255_to_u256_simd, test_utils::ark_ff_reference}, ark_bn254::Fr, ark_ff::{BigInt, PrimeField}, proptest::{ diff --git a/skyscraper/block-multiplier/src/portable_simd.rs b/skyscraper/block-multiplier/src/portable_simd_rtz.rs similarity index 98% rename from skyscraper/block-multiplier/src/portable_simd.rs rename to skyscraper/block-multiplier/src/portable_simd_rtz.rs index 5881d8bf..af5d156b 100644 --- a/skyscraper/block-multiplier/src/portable_simd.rs +++ b/skyscraper/block-multiplier/src/portable_simd_rtz.rs @@ -1,7 +1,9 @@ +// Montgomery multiplier +// Requires RTZ use { crate::{ - constants::*, - simd_utils::{ + constants_rtz::*, + simd_rtz_utils::{ addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, }, @@ -11,11 +13,16 @@ use { ops::BitAnd, simd::{num::SimdFloat, Simd}, }, + fp_rounding::{RoundingGuard, Zero}, std::simd::StdFloat, }; #[inline] -pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { +pub fn simd_sqr( + _rtz: &RoundingGuard, + v0_a: [u64; 4], + v1_a: [u64; 4], +) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u260_shl2_simd(transpose_u256_to_simd([v0_a, v1_a])); let mut t: [Simd; 10] = [Simd::splat(0); 10]; @@ -195,6 +202,7 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { #[inline] pub fn simd_mul( + _rtz: &RoundingGuard, v0_a: [u64; 4], v0_b: [u64; 4], v1_a: [u64; 4], @@ -399,7 +407,7 @@ mod tests { unsafe { with_rounding_mode((), |rtz : &fp_rounding::RoundingGuard, _| { - let (ab, bc) = simd_mul(a, b, b,c); + let (ab, bc) = simd_mul(&rtz, a, b, b,c); let ab_ref = ark_ff_reference(a, b); let bc_ref = ark_ff_reference(b, c); let ab = Fr::new(BigInt(ab)); diff --git a/skyscraper/block-multiplier/src/simd_utils_wasm.rs b/skyscraper/block-multiplier/src/simd_rne_utils.rs similarity index 96% rename from skyscraper/block-multiplier/src/simd_utils_wasm.rs rename to skyscraper/block-multiplier/src/simd_rne_utils.rs index b15674e8..adc4cd39 100644 --- a/skyscraper/block-multiplier/src/simd_utils_wasm.rs +++ b/skyscraper/block-multiplier/src/simd_rne_utils.rs @@ -1,5 +1,5 @@ use { - crate::constants_wasm::{C1, C2, C3, MASK51, U51_P}, + crate::constants_rne::{C1, C2, C3, MASK51, U51_P}, core::{ array, ops::BitAnd, @@ -210,10 +210,10 @@ mod tests { use std::simd::Simd; fn u255_to_u256(u: [u64; 5]) -> [u64; 4] { - crate::simd_utils_wasm::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + crate::simd_rne_utils::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) } fn u256_to_u255(u: [u64; 4]) -> [u64; 5] { - crate::simd_utils_wasm::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + crate::simd_rne_utils::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) } #[kani::proof] diff --git a/skyscraper/block-multiplier/src/simd_utils.rs b/skyscraper/block-multiplier/src/simd_rtz_utils.rs similarity index 98% rename from skyscraper/block-multiplier/src/simd_utils.rs rename to skyscraper/block-multiplier/src/simd_rtz_utils.rs index 9ce3b4f6..21fb6f04 100644 --- a/skyscraper/block-multiplier/src/simd_utils.rs +++ b/skyscraper/block-multiplier/src/simd_rtz_utils.rs @@ -1,5 +1,5 @@ use { - crate::constants::{C1, C2, MASK52, U52_2P}, + crate::constants_rtz::{C1, C2, MASK52, U52_2P}, core::{ arch::aarch64::vcvtq_f64_u64, array, From fabda22a94c487bf7eb7f7cf37940d06373f7d07 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 14:04:02 +0800 Subject: [PATCH 31/37] block-multiplier: rne organisation --- skyscraper/block-multiplier/benches/bench.rs | 6 +++--- skyscraper/block-multiplier/src/lib.rs | 4 +--- .../src/{constants_rne.rs => rne/constants.rs} | 0 skyscraper/block-multiplier/src/rne/mod.rs | 5 +++++ .../src/{portable_simd_rne.rs => rne/portable_simd.rs} | 8 ++++---- .../src/{simd_rne_utils.rs => rne/simd_utils.rs} | 2 +- 6 files changed, 14 insertions(+), 11 deletions(-) rename skyscraper/block-multiplier/src/{constants_rne.rs => rne/constants.rs} (100%) create mode 100644 skyscraper/block-multiplier/src/rne/mod.rs rename skyscraper/block-multiplier/src/{portable_simd_rne.rs => rne/portable_simd.rs} (99%) rename skyscraper/block-multiplier/src/{simd_rne_utils.rs => rne/simd_utils.rs} (99%) diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index 0a8d3173..25020d6e 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -37,7 +37,7 @@ mod mul { //.counter(ItemsCount::new(2usize)) .with_inputs(|| rng().random()) .bench_local_values(|(a, b, c, d)| { - block_multiplier::portable_simd_rne::simd_mul(a, b, c, d) + block_multiplier::rne::portable_simd::simd_mul(a, b, c, d) }); } @@ -123,7 +123,7 @@ mod mul { // #[divan::bench_group] mod sqr { - use {super::*, ark_ff::Field, block_multiplier::portable_simd_rne}; + use {super::*, ark_ff::Field, block_multiplier::rne}; #[divan::bench] fn scalar_sqr(bencher: Bencher) { @@ -138,7 +138,7 @@ mod sqr { bencher //.counter(ItemsCount::new(1usize)) .with_inputs(|| rng().random()) - .bench_local_values(|(a, b)| portable_simd_rne::simd_sqr(a, b)); + .bench_local_values(|(a, b)| rne::simd_sqr(a, b)); } #[divan::bench] diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index 0e858619..f63d8489 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -17,11 +17,9 @@ mod simd_rtz_utils; // pub mod block_simd_wasm; pub mod constants; -pub mod constants_rne; pub mod constants_rtz; -pub mod portable_simd_rne; +pub mod rne; mod scalar; -pub mod simd_rne_utils; #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; mod utils; diff --git a/skyscraper/block-multiplier/src/constants_rne.rs b/skyscraper/block-multiplier/src/rne/constants.rs similarity index 100% rename from skyscraper/block-multiplier/src/constants_rne.rs rename to skyscraper/block-multiplier/src/rne/constants.rs diff --git a/skyscraper/block-multiplier/src/rne/mod.rs b/skyscraper/block-multiplier/src/rne/mod.rs new file mode 100644 index 00000000..b66b1b03 --- /dev/null +++ b/skyscraper/block-multiplier/src/rne/mod.rs @@ -0,0 +1,5 @@ +pub mod constants; +pub mod portable_simd; +pub mod simd_utils; + +pub use {constants::*, portable_simd::*, simd_utils::*}; diff --git a/skyscraper/block-multiplier/src/portable_simd_rne.rs b/skyscraper/block-multiplier/src/rne/portable_simd.rs similarity index 99% rename from skyscraper/block-multiplier/src/portable_simd_rne.rs rename to skyscraper/block-multiplier/src/rne/portable_simd.rs index 2e804e66..0586c9b7 100644 --- a/skyscraper/block-multiplier/src/portable_simd_rne.rs +++ b/skyscraper/block-multiplier/src/rne/portable_simd.rs @@ -1,7 +1,7 @@ use { - crate::{ - constants_rne::*, - simd_rne_utils::{ + crate::rne::{ + constants::*, + simd_utils::{ addv_simd, fma, i2f, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u255_to_u256_shr_1_simd, u256_to_u255_simd, @@ -323,7 +323,7 @@ pub fn simd_mul( mod tests { use { super::*, - crate::{simd_rne_utils::u255_to_u256_simd, test_utils::ark_ff_reference}, + crate::{rne::simd_utils::u255_to_u256_simd, test_utils::ark_ff_reference}, ark_bn254::Fr, ark_ff::{BigInt, PrimeField}, proptest::{ diff --git a/skyscraper/block-multiplier/src/simd_rne_utils.rs b/skyscraper/block-multiplier/src/rne/simd_utils.rs similarity index 99% rename from skyscraper/block-multiplier/src/simd_rne_utils.rs rename to skyscraper/block-multiplier/src/rne/simd_utils.rs index adc4cd39..44d32d20 100644 --- a/skyscraper/block-multiplier/src/simd_rne_utils.rs +++ b/skyscraper/block-multiplier/src/rne/simd_utils.rs @@ -1,5 +1,5 @@ use { - crate::constants_rne::{C1, C2, C3, MASK51, U51_P}, + crate::rne::constants::{C1, C2, C3, MASK51, U51_P}, core::{ array, ops::BitAnd, From d1479f7f8eea0ab4e9f669b75cc9f7507276b040 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 14:14:36 +0800 Subject: [PATCH 32/37] block-multiplier: rtz organisation --- skyscraper/block-multiplier/benches/bench.rs | 13 +++++----- skyscraper/block-multiplier/src/lib.rs | 25 ++++++------------- .../src/{ => rtz}/block_simd.rs | 10 +++++--- .../{constants_rtz.rs => rtz/constants.rs} | 0 skyscraper/block-multiplier/src/rtz/mod.rs | 6 +++++ .../portable_simd.rs} | 6 ++--- .../{simd_rtz_utils.rs => rtz/simd_utils.rs} | 2 +- 7 files changed, 31 insertions(+), 31 deletions(-) rename skyscraper/block-multiplier/src/{ => rtz}/block_simd.rs (98%) rename skyscraper/block-multiplier/src/{constants_rtz.rs => rtz/constants.rs} (100%) create mode 100644 skyscraper/block-multiplier/src/rtz/mod.rs rename skyscraper/block-multiplier/src/{portable_simd_rtz.rs => rtz/portable_simd.rs} (99%) rename skyscraper/block-multiplier/src/{simd_rtz_utils.rs => rtz/simd_utils.rs} (98%) diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/block-multiplier/benches/bench.rs index 25020d6e..fd1268f7 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/block-multiplier/benches/bench.rs @@ -50,12 +50,12 @@ mod mul { }; #[divan::bench] - fn simd_mul_52b(bencher: Bencher) { + fn simd_mul_rtz(bencher: Bencher) { let bencher = bencher.with_inputs(|| rng().random()); unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c, d)| { - block_multiplier::simd_mul(mode_guard, a, b, c, d) + block_multiplier::rtz::simd_mul(mode_guard, a, b, c, d) }); }); } @@ -69,7 +69,7 @@ mod mul { unsafe { with_rounding_mode((), |guard, _| { bencher.bench_local_values(|(a, b, c, d, e, f)| { - block_multiplier::block_mul(guard, a, b, c, d, e, f) + block_multiplier::rtz::block_mul(guard, a, b, c, d, e, f) }); }); } @@ -233,8 +233,9 @@ mod sqr { let bencher = bencher.with_inputs(|| rng().random()); unsafe { with_rounding_mode((), |mode_guard, _| { - bencher - .bench_local_values(|(a, b)| block_multiplier::simd_sqr(mode_guard, a, b)); + bencher.bench_local_values(|(a, b)| { + block_multiplier::rtz::simd_sqr(mode_guard, a, b) + }); }); } } @@ -247,7 +248,7 @@ mod sqr { unsafe { with_rounding_mode((), |guard, _| { bencher.bench_local_values(|(a, b, c)| { - block_multiplier::block_sqr(guard, a, b, c) + block_multiplier::rtz::block_sqr(guard, a, b, c) }); }); } diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/block-multiplier/src/lib.rs index f63d8489..b8c33b08 100644 --- a/skyscraper/block-multiplier/src/lib.rs +++ b/skyscraper/block-multiplier/src/lib.rs @@ -9,32 +9,23 @@ mod aarch64; // These can be made to work on x86, // but for now it uses an ARM NEON intrinsic. #[cfg(target_arch = "aarch64")] -mod block_simd; -#[cfg(target_arch = "aarch64")] -mod portable_simd_rtz; -#[cfg(target_arch = "aarch64")] -mod simd_rtz_utils; +pub mod rtz; -// pub mod block_simd_wasm; pub mod constants; -pub mod constants_rtz; pub mod rne; mod scalar; +mod utils; + #[cfg(not(target_arch = "wasm32"))] // Proptest not supported on WASI mod test_utils; -mod utils; -pub use crate::scalar::{scalar_mul, scalar_sqr}; #[cfg(target_arch = "aarch64")] -pub use crate::{ - aarch64::{ - montgomery_interleaved_3, montgomery_interleaved_4, montgomery_square_interleaved_3, - montgomery_square_interleaved_4, montgomery_square_log_interleaved_3, - montgomery_square_log_interleaved_4, - }, - block_simd::{block_mul, block_sqr}, - portable_simd_rtz::{simd_mul, simd_sqr}, +pub use crate::aarch64::{ + montgomery_interleaved_3, montgomery_interleaved_4, montgomery_square_interleaved_3, + montgomery_square_interleaved_4, montgomery_square_log_interleaved_3, + montgomery_square_log_interleaved_4, }; +pub use crate::scalar::{scalar_mul, scalar_sqr}; const fn pow_2(n: u32) -> f64 { assert!(n <= 1023); diff --git a/skyscraper/block-multiplier/src/block_simd.rs b/skyscraper/block-multiplier/src/rtz/block_simd.rs similarity index 98% rename from skyscraper/block-multiplier/src/block_simd.rs rename to skyscraper/block-multiplier/src/rtz/block_simd.rs index 2364cc11..b261cb45 100644 --- a/skyscraper/block-multiplier/src/block_simd.rs +++ b/skyscraper/block-multiplier/src/rtz/block_simd.rs @@ -1,10 +1,12 @@ use { crate::{ constants::*, - constants_rtz::*, - simd_rtz_utils::{ - addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, - transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, + rtz::{ + constants::*, + simd_utils::{ + addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, + transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, + }, }, subarray, utils::{addv, carrying_mul_add, reduce_ct}, diff --git a/skyscraper/block-multiplier/src/constants_rtz.rs b/skyscraper/block-multiplier/src/rtz/constants.rs similarity index 100% rename from skyscraper/block-multiplier/src/constants_rtz.rs rename to skyscraper/block-multiplier/src/rtz/constants.rs diff --git a/skyscraper/block-multiplier/src/rtz/mod.rs b/skyscraper/block-multiplier/src/rtz/mod.rs new file mode 100644 index 00000000..8f8dc1a0 --- /dev/null +++ b/skyscraper/block-multiplier/src/rtz/mod.rs @@ -0,0 +1,6 @@ +pub mod block_simd; +pub mod constants; +pub mod portable_simd; +pub mod simd_utils; + +pub use {block_simd::*, constants::*, portable_simd::*, simd_utils::*}; diff --git a/skyscraper/block-multiplier/src/portable_simd_rtz.rs b/skyscraper/block-multiplier/src/rtz/portable_simd.rs similarity index 99% rename from skyscraper/block-multiplier/src/portable_simd_rtz.rs rename to skyscraper/block-multiplier/src/rtz/portable_simd.rs index af5d156b..1907a2b0 100644 --- a/skyscraper/block-multiplier/src/portable_simd_rtz.rs +++ b/skyscraper/block-multiplier/src/rtz/portable_simd.rs @@ -1,9 +1,9 @@ // Montgomery multiplier // Requires RTZ use { - crate::{ - constants_rtz::*, - simd_rtz_utils::{ + crate::rtz::{ + constants::*, + simd_utils::{ addv_simd, make_initial, reduce_ct_simd, smult_noinit_simd, transpose_simd_to_u256, transpose_u256_to_simd, u256_to_u260_shl2_simd, u260_to_u256_simd, }, diff --git a/skyscraper/block-multiplier/src/simd_rtz_utils.rs b/skyscraper/block-multiplier/src/rtz/simd_utils.rs similarity index 98% rename from skyscraper/block-multiplier/src/simd_rtz_utils.rs rename to skyscraper/block-multiplier/src/rtz/simd_utils.rs index 21fb6f04..144951ff 100644 --- a/skyscraper/block-multiplier/src/simd_rtz_utils.rs +++ b/skyscraper/block-multiplier/src/rtz/simd_utils.rs @@ -1,5 +1,5 @@ use { - crate::constants_rtz::{C1, C2, MASK52, U52_2P}, + crate::rtz::constants::{C1, C2, MASK52, U52_2P}, core::{ arch::aarch64::vcvtq_f64_u64, array, From ebc5d7849c6882c46bf3de483d94453368167055 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Fri, 23 Jan 2026 14:22:04 +0800 Subject: [PATCH 33/37] block-multiplier -> bn254-multiplier --- .gitignore | 2 +- Cargo.toml | 8 +++--- .../proptest-regressions/scalar.txt | 8 ------ .../.gitignore | 4 +-- .../Cargo.toml | 2 +- .../README.md | 6 ++-- .../src/constants.rs | 0 .../src/lib.rs | 0 .../src/load_store.rs | 0 .../src/main.rs | 2 +- .../src/scalar.rs | 0 .../src/simd.rs | 0 .../Cargo.toml | 4 +-- .../benches/bench.rs | 28 +++++++++---------- .../build.rs | 2 +- .../src/aarch64/generate_montgomery_table.py | 0 .../src/aarch64/mod.rs | 0 .../src/aarch64/montgomery_interleaved_3.s | 0 .../src/aarch64/montgomery_interleaved_4.s | 0 .../aarch64/montgomery_square_interleaved_3.s | 0 .../aarch64/montgomery_square_interleaved_4.s | 0 .../montgomery_square_log_interleaved_3.s | 0 .../montgomery_square_log_interleaved_4.s | 0 .../src/constants.rs | 0 .../src/lib.rs | 0 .../src/rne/constants.rs | 0 .../src/rne/mod.rs | 0 .../src/rne/portable_simd.rs | 0 .../src/rne/simd_utils.rs | 0 .../src/rtz/block_simd.rs | 0 .../src/rtz/constants.rs | 0 .../src/rtz/mod.rs | 0 .../src/rtz/portable_simd.rs | 0 .../src/rtz/simd_utils.rs | 0 .../src/scalar.rs | 0 .../src/test_utils.rs | 0 .../src/utils.rs | 2 +- skyscraper/core/Cargo.toml | 2 +- skyscraper/core/benches/bench.rs | 2 +- skyscraper/core/src/block3.rs | 2 +- skyscraper/core/src/block4.rs | 2 +- skyscraper/core/src/simple.rs | 2 +- skyscraper/core/src/v1.rs | 2 +- 43 files changed, 36 insertions(+), 44 deletions(-) delete mode 100644 skyscraper/block-multiplier/proptest-regressions/scalar.txt rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/.gitignore (63%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/Cargo.toml (88%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/README.md (71%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/constants.rs (100%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/lib.rs (100%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/load_store.rs (100%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/main.rs (97%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/scalar.rs (100%) rename skyscraper/{block-multiplier-codegen => bn254-multiplier-codegen}/src/simd.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/Cargo.toml (91%) rename skyscraper/{block-multiplier => bn254-multiplier}/benches/bench.rs (89%) rename skyscraper/{block-multiplier => bn254-multiplier}/build.rs (97%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/generate_montgomery_table.py (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/mod.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_interleaved_3.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_interleaved_4.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_square_interleaved_3.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_square_interleaved_4.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_square_log_interleaved_3.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/aarch64/montgomery_square_log_interleaved_4.s (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/constants.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/lib.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rne/constants.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rne/mod.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rne/portable_simd.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rne/simd_utils.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rtz/block_simd.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rtz/constants.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rtz/mod.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rtz/portable_simd.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/rtz/simd_utils.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/scalar.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/test_utils.rs (100%) rename skyscraper/{block-multiplier => bn254-multiplier}/src/utils.rs (98%) diff --git a/.gitignore b/.gitignore index f770c0ae..165e92b5 100644 --- a/.gitignore +++ b/.gitignore @@ -43,4 +43,4 @@ Cargo.lock # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -circuit_stats_examples/ \ No newline at end of file +circuit_stats_examples/ diff --git a/Cargo.toml b/Cargo.toml index 0d130371..e7b31656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,8 +3,8 @@ resolver = "2" members = [ "skyscraper/fp-rounding", "skyscraper/hla", - "skyscraper/block-multiplier", - "skyscraper/block-multiplier-codegen", + "skyscraper/bn254-multiplier", + "skyscraper/bn254-multiplier-codegen", "skyscraper/core", "provekit/common", "provekit/r1cs-compiler", @@ -73,8 +73,8 @@ opt-level = 3 [workspace.dependencies] # Workspace members - Skyscraper -block-multiplier = { path = "skyscraper/block-multiplier" } -block-multiplier-codegen = { path = "skyscraper/block-multiplier-codegen" } +bn254-multiplier = { path = "skyscraper/bn254-multiplier" } +bn254-multiplier-codegen = { path = "skyscraper/bn254-multiplier-codegen" } fp-rounding = { path = "skyscraper/fp-rounding" } hla = { path = "skyscraper/hla" } skyscraper = { path = "skyscraper/core" } diff --git a/skyscraper/block-multiplier/proptest-regressions/scalar.txt b/skyscraper/block-multiplier/proptest-regressions/scalar.txt deleted file mode 100644 index 4715d78f..00000000 --- a/skyscraper/block-multiplier/proptest-regressions/scalar.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Seeds for failure cases proptest has generated in the past. It is -# automatically read and these particular cases re-run before any -# novel cases are generated. -# -# It is recommended to check this file in to source control so that -# everyone who runs the test benefits from these saved cases. -cc 46acc9f3c07fefb126b59a0edec37c56f92c16c1468989ed132bf42ef54ffe86 # shrinks to l = [0, 0, 0, 1], r = [0, 0, 0, 1] -cc e629632cdf5eb4aefd4fdb2da29bdbd7b2a177a69dd74f99f70683f11c942da7 # shrinks to l = [0, 887, 0, 15778841185528309819], r = [458854615557053794, 8784556235901218364, 1751211468174275388, 16873806747226852460] diff --git a/skyscraper/block-multiplier-codegen/.gitignore b/skyscraper/bn254-multiplier-codegen/.gitignore similarity index 63% rename from skyscraper/block-multiplier-codegen/.gitignore rename to skyscraper/bn254-multiplier-codegen/.gitignore index ab9cdb40..8e3e5af3 100644 --- a/skyscraper/block-multiplier-codegen/.gitignore +++ b/skyscraper/bn254-multiplier-codegen/.gitignore @@ -1,2 +1,2 @@ -# We don't include the inline rust generated files as they will be part of block-multiplier-sys -asm/ \ No newline at end of file +# We don't include the inline rust generated files as they will be part of bn254-multiplier-sys +asm/ diff --git a/skyscraper/block-multiplier-codegen/Cargo.toml b/skyscraper/bn254-multiplier-codegen/Cargo.toml similarity index 88% rename from skyscraper/block-multiplier-codegen/Cargo.toml rename to skyscraper/bn254-multiplier-codegen/Cargo.toml index 946f023d..d8a7b8f1 100644 --- a/skyscraper/block-multiplier-codegen/Cargo.toml +++ b/skyscraper/bn254-multiplier-codegen/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "block-multiplier-codegen" +name = "bn254-multiplier-codegen" version = "0.1.0" edition.workspace = true rust-version.workspace = true diff --git a/skyscraper/block-multiplier-codegen/README.md b/skyscraper/bn254-multiplier-codegen/README.md similarity index 71% rename from skyscraper/block-multiplier-codegen/README.md rename to skyscraper/bn254-multiplier-codegen/README.md index f929636d..270d99d1 100644 --- a/skyscraper/block-multiplier-codegen/README.md +++ b/skyscraper/bn254-multiplier-codegen/README.md @@ -6,12 +6,12 @@ This crate contains a binary that generates optimized assembly code for block mu 1. **Run the binary:** ```bash - cargo run --package block-multiplier-codegen + cargo run --package bn254-multiplier-codegen ``` This will execute the `main` function in `src/main.rs`. 2. **Generated File:** The binary will generate an assembly file named `asm/montgomery_interleaved.s` within this crate's directory. -3. **Integrate into `block-multiplier-sys`:** - Copy the contents of the generated `asm/montgomery_interleaved.s` file. Paste this assembly code into the appropriate location within the `block-multiplier-sys` crate, likely inside a specific function designed to use this inline assembly. \ No newline at end of file +3. **Integrate into `bn254-multiplier-sys`:** + Copy the contents of the generated `asm/montgomery_interleaved.s` file. Paste this assembly code into the appropriate location within the `bn254-multiplier-sys` crate, likely inside a specific function designed to use this inline assembly. diff --git a/skyscraper/block-multiplier-codegen/src/constants.rs b/skyscraper/bn254-multiplier-codegen/src/constants.rs similarity index 100% rename from skyscraper/block-multiplier-codegen/src/constants.rs rename to skyscraper/bn254-multiplier-codegen/src/constants.rs diff --git a/skyscraper/block-multiplier-codegen/src/lib.rs b/skyscraper/bn254-multiplier-codegen/src/lib.rs similarity index 100% rename from skyscraper/block-multiplier-codegen/src/lib.rs rename to skyscraper/bn254-multiplier-codegen/src/lib.rs diff --git a/skyscraper/block-multiplier-codegen/src/load_store.rs b/skyscraper/bn254-multiplier-codegen/src/load_store.rs similarity index 100% rename from skyscraper/block-multiplier-codegen/src/load_store.rs rename to skyscraper/bn254-multiplier-codegen/src/load_store.rs diff --git a/skyscraper/block-multiplier-codegen/src/main.rs b/skyscraper/bn254-multiplier-codegen/src/main.rs similarity index 97% rename from skyscraper/block-multiplier-codegen/src/main.rs rename to skyscraper/bn254-multiplier-codegen/src/main.rs index 7437e321..b467bbfa 100644 --- a/skyscraper/block-multiplier-codegen/src/main.rs +++ b/skyscraper/bn254-multiplier-codegen/src/main.rs @@ -1,5 +1,5 @@ use { - block_multiplier_codegen::{scalar, simd}, + bn254_multiplier_codegen::{scalar, simd}, hla::builder::{build_includable, Interleaving}, }; diff --git a/skyscraper/block-multiplier-codegen/src/scalar.rs b/skyscraper/bn254-multiplier-codegen/src/scalar.rs similarity index 100% rename from skyscraper/block-multiplier-codegen/src/scalar.rs rename to skyscraper/bn254-multiplier-codegen/src/scalar.rs diff --git a/skyscraper/block-multiplier-codegen/src/simd.rs b/skyscraper/bn254-multiplier-codegen/src/simd.rs similarity index 100% rename from skyscraper/block-multiplier-codegen/src/simd.rs rename to skyscraper/bn254-multiplier-codegen/src/simd.rs diff --git a/skyscraper/block-multiplier/Cargo.toml b/skyscraper/bn254-multiplier/Cargo.toml similarity index 91% rename from skyscraper/block-multiplier/Cargo.toml rename to skyscraper/bn254-multiplier/Cargo.toml index 3960da90..ddd49133 100644 --- a/skyscraper/block-multiplier/Cargo.toml +++ b/skyscraper/bn254-multiplier/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "block-multiplier" +name = "bn254-multiplier" version = "0.1.0" edition.workspace = true rust-version.workspace = true @@ -31,7 +31,7 @@ proptest.workspace = true [build-dependencies] # Workspace crates -block-multiplier-codegen.workspace = true +bn254-multiplier-codegen.workspace = true hla.workspace = true [lints] diff --git a/skyscraper/block-multiplier/benches/bench.rs b/skyscraper/bn254-multiplier/benches/bench.rs similarity index 89% rename from skyscraper/block-multiplier/benches/bench.rs rename to skyscraper/bn254-multiplier/benches/bench.rs index fd1268f7..7d27d256 100644 --- a/skyscraper/block-multiplier/benches/bench.rs +++ b/skyscraper/bn254-multiplier/benches/bench.rs @@ -14,7 +14,7 @@ mod mul { bencher //.counter(ItemsCount::new(1usize)) .with_inputs(|| rng().random()) - .bench_local_values(|(a, b)| block_multiplier::scalar_mul(a, b)); + .bench_local_values(|(a, b)| bn254_multiplier::scalar_mul(a, b)); } #[divan::bench] @@ -37,7 +37,7 @@ mod mul { //.counter(ItemsCount::new(2usize)) .with_inputs(|| rng().random()) .bench_local_values(|(a, b, c, d)| { - block_multiplier::rne::portable_simd::simd_mul(a, b, c, d) + bn254_multiplier::rne::portable_simd::simd_mul(a, b, c, d) }); } @@ -55,7 +55,7 @@ mod mul { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c, d)| { - block_multiplier::rtz::simd_mul(mode_guard, a, b, c, d) + bn254_multiplier::rtz::simd_mul(mode_guard, a, b, c, d) }); }); } @@ -69,7 +69,7 @@ mod mul { unsafe { with_rounding_mode((), |guard, _| { bencher.bench_local_values(|(a, b, c, d, e, f)| { - block_multiplier::rtz::block_mul(guard, a, b, c, d, e, f) + bn254_multiplier::rtz::block_mul(guard, a, b, c, d, e, f) }); }); } @@ -90,7 +90,7 @@ mod mul { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c, d)| { - block_multiplier::montgomery_interleaved_3(mode_guard, a, b, c, d) + bn254_multiplier::montgomery_interleaved_3(mode_guard, a, b, c, d) }); }); } @@ -113,7 +113,7 @@ mod mul { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c, d, e, f)| { - block_multiplier::montgomery_interleaved_4(mode_guard, a, b, c, d, e, f) + bn254_multiplier::montgomery_interleaved_4(mode_guard, a, b, c, d, e, f) }); }); } @@ -123,14 +123,14 @@ mod mul { // #[divan::bench_group] mod sqr { - use {super::*, ark_ff::Field, block_multiplier::rne}; + use {super::*, ark_ff::Field, bn254_multiplier::rne}; #[divan::bench] fn scalar_sqr(bencher: Bencher) { bencher //.counter(ItemsCount::new(1usize)) .with_inputs(|| rng().random()) - .bench_local_values(block_multiplier::scalar_sqr); + .bench_local_values(bn254_multiplier::scalar_sqr); } #[divan::bench] @@ -169,7 +169,7 @@ mod sqr { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b)| { - block_multiplier::montgomery_square_log_interleaved_3(mode_guard, a, b) + bn254_multiplier::montgomery_square_log_interleaved_3(mode_guard, a, b) }); }); } @@ -187,7 +187,7 @@ mod sqr { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c)| { - block_multiplier::montgomery_square_log_interleaved_4(mode_guard, a, b, c) + bn254_multiplier::montgomery_square_log_interleaved_4(mode_guard, a, b, c) }); }); } @@ -204,7 +204,7 @@ mod sqr { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b)| { - block_multiplier::montgomery_square_interleaved_3(mode_guard, a, b) + bn254_multiplier::montgomery_square_interleaved_3(mode_guard, a, b) }); }); } @@ -222,7 +222,7 @@ mod sqr { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b, c)| { - block_multiplier::montgomery_square_interleaved_4(mode_guard, a, b, c) + bn254_multiplier::montgomery_square_interleaved_4(mode_guard, a, b, c) }); }); } @@ -234,7 +234,7 @@ mod sqr { unsafe { with_rounding_mode((), |mode_guard, _| { bencher.bench_local_values(|(a, b)| { - block_multiplier::rtz::simd_sqr(mode_guard, a, b) + bn254_multiplier::rtz::simd_sqr(mode_guard, a, b) }); }); } @@ -248,7 +248,7 @@ mod sqr { unsafe { with_rounding_mode((), |guard, _| { bencher.bench_local_values(|(a, b, c)| { - block_multiplier::rtz::block_sqr(guard, a, b, c) + bn254_multiplier::rtz::block_sqr(guard, a, b, c) }); }); } diff --git a/skyscraper/block-multiplier/build.rs b/skyscraper/bn254-multiplier/build.rs similarity index 97% rename from skyscraper/block-multiplier/build.rs rename to skyscraper/bn254-multiplier/build.rs index 7623a247..8d2137a5 100644 --- a/skyscraper/block-multiplier/build.rs +++ b/skyscraper/bn254-multiplier/build.rs @@ -1,5 +1,5 @@ use { - block_multiplier_codegen::{scalar, simd}, + bn254_multiplier_codegen::{scalar, simd}, hla::builder::{build_includable, Interleaving}, std::path::Path, }; diff --git a/skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py b/skyscraper/bn254-multiplier/src/aarch64/generate_montgomery_table.py similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/generate_montgomery_table.py rename to skyscraper/bn254-multiplier/src/aarch64/generate_montgomery_table.py diff --git a/skyscraper/block-multiplier/src/aarch64/mod.rs b/skyscraper/bn254-multiplier/src/aarch64/mod.rs similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/mod.rs rename to skyscraper/bn254-multiplier/src/aarch64/mod.rs diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_interleaved_3.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_interleaved_3.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_interleaved_3.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_interleaved_3.s diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_interleaved_4.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_interleaved_4.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_interleaved_4.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_interleaved_4.s diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_square_interleaved_3.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_square_interleaved_3.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_square_interleaved_3.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_square_interleaved_3.s diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_square_interleaved_4.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_square_interleaved_4.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_square_interleaved_4.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_square_interleaved_4.s diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_square_log_interleaved_3.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_square_log_interleaved_3.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_square_log_interleaved_3.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_square_log_interleaved_3.s diff --git a/skyscraper/block-multiplier/src/aarch64/montgomery_square_log_interleaved_4.s b/skyscraper/bn254-multiplier/src/aarch64/montgomery_square_log_interleaved_4.s similarity index 100% rename from skyscraper/block-multiplier/src/aarch64/montgomery_square_log_interleaved_4.s rename to skyscraper/bn254-multiplier/src/aarch64/montgomery_square_log_interleaved_4.s diff --git a/skyscraper/block-multiplier/src/constants.rs b/skyscraper/bn254-multiplier/src/constants.rs similarity index 100% rename from skyscraper/block-multiplier/src/constants.rs rename to skyscraper/bn254-multiplier/src/constants.rs diff --git a/skyscraper/block-multiplier/src/lib.rs b/skyscraper/bn254-multiplier/src/lib.rs similarity index 100% rename from skyscraper/block-multiplier/src/lib.rs rename to skyscraper/bn254-multiplier/src/lib.rs diff --git a/skyscraper/block-multiplier/src/rne/constants.rs b/skyscraper/bn254-multiplier/src/rne/constants.rs similarity index 100% rename from skyscraper/block-multiplier/src/rne/constants.rs rename to skyscraper/bn254-multiplier/src/rne/constants.rs diff --git a/skyscraper/block-multiplier/src/rne/mod.rs b/skyscraper/bn254-multiplier/src/rne/mod.rs similarity index 100% rename from skyscraper/block-multiplier/src/rne/mod.rs rename to skyscraper/bn254-multiplier/src/rne/mod.rs diff --git a/skyscraper/block-multiplier/src/rne/portable_simd.rs b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs similarity index 100% rename from skyscraper/block-multiplier/src/rne/portable_simd.rs rename to skyscraper/bn254-multiplier/src/rne/portable_simd.rs diff --git a/skyscraper/block-multiplier/src/rne/simd_utils.rs b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs similarity index 100% rename from skyscraper/block-multiplier/src/rne/simd_utils.rs rename to skyscraper/bn254-multiplier/src/rne/simd_utils.rs diff --git a/skyscraper/block-multiplier/src/rtz/block_simd.rs b/skyscraper/bn254-multiplier/src/rtz/block_simd.rs similarity index 100% rename from skyscraper/block-multiplier/src/rtz/block_simd.rs rename to skyscraper/bn254-multiplier/src/rtz/block_simd.rs diff --git a/skyscraper/block-multiplier/src/rtz/constants.rs b/skyscraper/bn254-multiplier/src/rtz/constants.rs similarity index 100% rename from skyscraper/block-multiplier/src/rtz/constants.rs rename to skyscraper/bn254-multiplier/src/rtz/constants.rs diff --git a/skyscraper/block-multiplier/src/rtz/mod.rs b/skyscraper/bn254-multiplier/src/rtz/mod.rs similarity index 100% rename from skyscraper/block-multiplier/src/rtz/mod.rs rename to skyscraper/bn254-multiplier/src/rtz/mod.rs diff --git a/skyscraper/block-multiplier/src/rtz/portable_simd.rs b/skyscraper/bn254-multiplier/src/rtz/portable_simd.rs similarity index 100% rename from skyscraper/block-multiplier/src/rtz/portable_simd.rs rename to skyscraper/bn254-multiplier/src/rtz/portable_simd.rs diff --git a/skyscraper/block-multiplier/src/rtz/simd_utils.rs b/skyscraper/bn254-multiplier/src/rtz/simd_utils.rs similarity index 100% rename from skyscraper/block-multiplier/src/rtz/simd_utils.rs rename to skyscraper/bn254-multiplier/src/rtz/simd_utils.rs diff --git a/skyscraper/block-multiplier/src/scalar.rs b/skyscraper/bn254-multiplier/src/scalar.rs similarity index 100% rename from skyscraper/block-multiplier/src/scalar.rs rename to skyscraper/bn254-multiplier/src/scalar.rs diff --git a/skyscraper/block-multiplier/src/test_utils.rs b/skyscraper/bn254-multiplier/src/test_utils.rs similarity index 100% rename from skyscraper/block-multiplier/src/test_utils.rs rename to skyscraper/bn254-multiplier/src/test_utils.rs diff --git a/skyscraper/block-multiplier/src/utils.rs b/skyscraper/bn254-multiplier/src/utils.rs similarity index 98% rename from skyscraper/block-multiplier/src/utils.rs rename to skyscraper/bn254-multiplier/src/utils.rs index 88a14022..ee3ac57b 100644 --- a/skyscraper/block-multiplier/src/utils.rs +++ b/skyscraper/bn254-multiplier/src/utils.rs @@ -14,7 +14,7 @@ use crate::constants::U64_2P; /// # Example /// /// ``` -/// use block_multiplier::subarray; +/// use bn254_multiplier::subarray; /// let array = [1, 2, 3, 4, 5]; /// let sub = subarray!(array, 1, 3); // Creates [2, 3, 4] /// ``` diff --git a/skyscraper/core/Cargo.toml b/skyscraper/core/Cargo.toml index aa14dee4..cbbc5f92 100644 --- a/skyscraper/core/Cargo.toml +++ b/skyscraper/core/Cargo.toml @@ -10,7 +10,7 @@ repository.workspace = true [dependencies] # Workspace crates -block-multiplier.workspace = true +bn254-multiplier.workspace = true # Cryptography and proof systems ark-bn254.workspace = true diff --git a/skyscraper/core/benches/bench.rs b/skyscraper/core/benches/bench.rs index a5537148..bf37a2de 100644 --- a/skyscraper/core/benches/bench.rs +++ b/skyscraper/core/benches/bench.rs @@ -185,7 +185,7 @@ mod parts { use skyscraper::reduce::reduce_partial; bencher .with_inputs(|| reduce_partial(array::from_fn(|_| rng().random()))) - .bench_values(block_multiplier::scalar_sqr) + .bench_values(bn254_multiplier::scalar_sqr) } } diff --git a/skyscraper/core/src/block3.rs b/skyscraper/core/src/block3.rs index 285dd521..81974244 100644 --- a/skyscraper/core/src/block3.rs +++ b/skyscraper/core/src/block3.rs @@ -21,7 +21,7 @@ fn compress(guard: &RoundingGuard, input: [[[u64; 4]; 2]; 3]) -> [[u64; 4] fn square(guard: &RoundingGuard, n: [[u64; 4]; 3]) -> [[u64; 4]; 3] { let [a, b, c] = n; let v = array::from_fn(|i| std::simd::u64x2::from_array([b[i], c[i]])); - let (a, v) = block_multiplier::montgomery_square_log_interleaved_3(guard, a, v); + let (a, v) = bn254_multiplier::montgomery_square_log_interleaved_3(guard, a, v); let b = v.map(|e| e[0]); let c = v.map(|e| e[1]); [a, b, c] diff --git a/skyscraper/core/src/block4.rs b/skyscraper/core/src/block4.rs index 5ac239b1..24a388d5 100644 --- a/skyscraper/core/src/block4.rs +++ b/skyscraper/core/src/block4.rs @@ -21,7 +21,7 @@ fn compress(guard: &RoundingGuard, input: [[[u64; 4]; 2]; 4]) -> [[u64; 4] fn square(guard: &RoundingGuard, n: [[u64; 4]; 4]) -> [[u64; 4]; 4] { let [a, b, c, d] = n; let v = array::from_fn(|i| std::simd::u64x2::from_array([c[i], d[i]])); - let (a, b, v) = block_multiplier::montgomery_square_log_interleaved_4(guard, a, b, v); + let (a, b, v) = bn254_multiplier::montgomery_square_log_interleaved_4(guard, a, b, v); let c = v.map(|e| e[0]); let d = v.map(|e| e[1]); [a, b, c, d] diff --git a/skyscraper/core/src/simple.rs b/skyscraper/core/src/simple.rs index c1e530bb..f822c6ad 100644 --- a/skyscraper/core/src/simple.rs +++ b/skyscraper/core/src/simple.rs @@ -1,4 +1,4 @@ -use {crate::generic, block_multiplier::scalar_sqr as square}; +use {crate::generic, bn254_multiplier::scalar_sqr as square}; pub fn compress_many(messages: &[u8], hashes: &mut [u8]) { generic::compress_many( diff --git a/skyscraper/core/src/v1.rs b/skyscraper/core/src/v1.rs index 7f31f1cc..512d2bd1 100644 --- a/skyscraper/core/src/v1.rs +++ b/skyscraper/core/src/v1.rs @@ -5,7 +5,7 @@ use { generic, reduce::{reduce, reduce_partial, reduce_partial_add_rc}, }, - block_multiplier::scalar_sqr as square, + bn254_multiplier::scalar_sqr as square, }; pub fn compress_many(messages: &[u8], hashes: &mut [u8]) { From 586d8971c3912c48b8cb8aa4f6d712d41683a07a Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 26 Jan 2026 12:33:28 +0800 Subject: [PATCH 34/37] b51: inline multimul, fix kani paths, make i2f generic --- .../bn254-multiplier/src/rne/portable_simd.rs | 101 ++++++++---------- .../bn254-multiplier/src/rne/simd_utils.rs | 15 +-- .../bn254-multiplier/src/rtz/portable_simd.rs | 2 - 3 files changed, 55 insertions(+), 63 deletions(-) diff --git a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs index 0586c9b7..94aeb03b 100644 --- a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs +++ b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs @@ -95,12 +95,46 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { (v[0], v[1]) } +/// Move redundant carries from lower limbs to the higher limbs such that all +/// limbs except the last one is 51 bits. The most significant limb can be +/// larger than 51 bits as the input can be bigger 2^255-1. #[inline(always)] -/// i64 signifies redundant carry form -/// t initialise with right for multiplication test -/// compare with school multiplication on 51 bits. This does not require having -/// to move over carries -fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd; 5]) { +fn redundant_carry(t: [Simd; N]) -> [Simd; N] { + let mut borrow = Simd::splat(0); + let mut res = [Simd::splat(0); N]; + for i in 0..t.len() - 1 { + let tmp = t[i] + borrow; + res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); + borrow = tmp >> 51; + } + + res[N - 1] = (t[N - 1] + borrow).cast(); + res +} + +#[inline(always)] +/// Montgomery multiplier +pub fn simd_mul( + v0_a: [u64; 4], + v0_b: [u64; 4], + v1_a: [u64; 4], + v1_b: [u64; 4], +) -> ([u64; 4], [u64; 4]) { + let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); + let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); + + let mut t: [Simd<_, 2>; 10] = [Simd::splat(0); 10]; + t[0] = Simd::splat(make_initial(1, 0)); + t[9] = Simd::splat(make_initial(0, 6)); + t[1] = Simd::splat(make_initial(2, 1)); + t[8] = Simd::splat(make_initial(6, 7)); + t[2] = Simd::splat(make_initial(3, 2)); + t[7] = Simd::splat(make_initial(7, 8)); + t[3] = Simd::splat(make_initial(4, 3)); + t[6] = Simd::splat(make_initial(8, 9)); + t[4] = Simd::splat(make_initial(10, 4)); + t[5] = Simd::splat(make_initial(9, 10)); + let avi: Simd = i2f(v0_a[0]); let bvj: Simd = i2f(v0_b[0]); let p_hi = fma(avi, bvj, Simd::splat(C1)); @@ -235,46 +269,6 @@ fn multimul(t: &mut [Simd; 10], v0_a: [Simd; 5], v0_b: [Simd(t: [Simd; N]) -> [Simd; N] { - let mut borrow = Simd::splat(0); - let mut res = [Simd::splat(0); N]; - for i in 0..t.len() - 1 { - let tmp = t[i] + borrow; - res[i] = (tmp.cast()).bitand(Simd::splat(MASK51)); - borrow = tmp >> 51; - } - // Last limb should not be truncated to 51 bits. As the input value can be - // bigger than 2^255 bits. In that sense the upper limb has no redundant carry. - res[N - 1] = (t[N - 1] + borrow).cast(); - res -} - -#[inline(always)] -pub fn simd_mul( - v0_a: [u64; 4], - v0_b: [u64; 4], - v1_a: [u64; 4], - v1_b: [u64; 4], -) -> ([u64; 4], [u64; 4]) { - let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); - let v0_b = u256_to_u255_simd(transpose_u256_to_simd([v0_b, v1_b])); - - let mut t: [Simd<_, 2>; 10] = [Simd::splat(0); 10]; - t[0] = Simd::splat(make_initial(1, 0)); - t[9] = Simd::splat(make_initial(0, 6)); - t[1] = Simd::splat(make_initial(2, 1)); - t[8] = Simd::splat(make_initial(6, 7)); - t[2] = Simd::splat(make_initial(3, 2)); - t[7] = Simd::splat(make_initial(7, 8)); - t[3] = Simd::splat(make_initial(4, 3)); - t[6] = Simd::splat(make_initial(8, 9)); - t[4] = Simd::splat(make_initial(10, 4)); - t[5] = Simd::splat(make_initial(9, 10)); - - multimul(&mut t, v0_a, v0_b); // sign extend redundant carries t[1] += t[0] >> 51; @@ -337,18 +331,21 @@ mod tests { proptest!(|( a in limbs5_51(), b in limbs5_51(), - // c in limbs5_51(), + c in limbs5_51(), )| { let a: [Simd;_] = a.map(Simd::splat); let b: [Simd;_] = b.map(Simd::splat); + let c: [Simd;_] = c.map(Simd::splat); let a = u255_to_u256_simd(a).map(|x|x[0]); let b = u255_to_u256_simd(b).map(|x|x[0]); - let (ab, _bc) = simd_mul(a, b,a,b); + let c = u255_to_u256_simd(c).map(|x|x[0]); + let (ab, bc) = simd_mul(a, b,b,c); let ab_ref = ark_ff_reference(a, b); - // let bc_ref = ark_ff_reference(b, c); + let bc_ref = ark_ff_reference(b, c); let ab = Fr::new(BigInt(ab)); - // let bc = Fr::new(BigInt(bc)); + let bc = Fr::new(BigInt(bc)); prop_assert_eq!(ab_ref, ab, "mismatch: l = {:X}, b = {:X}", ab_ref.into_bigint(), ab.into_bigint()); + prop_assert_eq!(bc_ref, bc, "mismatch: l = {:X}, b = {:X}", bc_ref.into_bigint(), bc.into_bigint()); }) } @@ -357,7 +354,6 @@ mod tests { proptest!(|( a in limbs5_51(), b in limbs5_51(), - // c in limbs5_51(), )| { let a: [Simd;_] = a.map(Simd::splat); let b: [Simd;_] = b.map(Simd::splat); @@ -370,12 +366,7 @@ mod tests { } fn limb51() -> impl Strategy { - // Either of these is fine: - // 1) Range 0u64..(1u64 << 51) - - // 2) Or mask (sometimes faster) - // any::().prop_map(|x| x & LIMB_MASK) } fn limbs5_51() -> impl Strategy { diff --git a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs index 44d32d20..b8a2b3c7 100644 --- a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs +++ b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs @@ -15,13 +15,16 @@ use { // -- [SIMD UTILS] // --------------------------------------------------------------------------------- #[inline(always)] -/// On WASSM there is no single specialised instruction to cast an integer to a +/// On WASM there is no single specialised instruction to cast an integer to a /// float. Since we are only interested in 52 bits, we can emulate it with fewer /// instructions. /// /// Warning: due to Rust's limitations this can not be a const function. /// Therefore check your dependency path as this will not be optimised out. -pub fn i2f(a: Simd) -> Simd { +pub fn i2f(a: Simd) -> Simd +where + LaneCount: SupportedLaneCount, +{ // This function has not target gating as we want to verify this function with // kani and proptest on a different platform than wasm @@ -30,8 +33,8 @@ pub fn i2f(a: Simd) -> Simd { // to convert a to it's floating point number we subtract this again. This way // we only pay for the conversion of the lower bits and not the full 64 bits. let exponent = Simd::splat(0x433 << 52); - let a: Simd = Simd::::from_bits(a | exponent); - let b: Simd = Simd::::from_bits(exponent); + let a: Simd = Simd::::from_bits(a | exponent); + let b: Simd = Simd::::from_bits(exponent); a - b } @@ -210,10 +213,10 @@ mod tests { use std::simd::Simd; fn u255_to_u256(u: [u64; 5]) -> [u64; 4] { - crate::simd_rne_utils::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + crate::rne::simd_utils::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) } fn u256_to_u255(u: [u64; 4]) -> [u64; 5] { - crate::simd_rne_utils::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) + crate::rne::simd_utils::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) } #[kani::proof] diff --git a/skyscraper/bn254-multiplier/src/rtz/portable_simd.rs b/skyscraper/bn254-multiplier/src/rtz/portable_simd.rs index 1907a2b0..a41c77de 100644 --- a/skyscraper/bn254-multiplier/src/rtz/portable_simd.rs +++ b/skyscraper/bn254-multiplier/src/rtz/portable_simd.rs @@ -1,5 +1,3 @@ -// Montgomery multiplier -// Requires RTZ use { crate::rtz::{ constants::*, From fee0d5ea63b5ee189ffb9d32dc587536e3f36d73 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 26 Jan 2026 12:57:16 +0800 Subject: [PATCH 35/37] b51: documentation --- .../bn254-multiplier/src/rne/constants.rs | 6 ++++ skyscraper/bn254-multiplier/src/rne/mod.rs | 24 +++++++++++++ .../bn254-multiplier/src/rne/portable_simd.rs | 27 +++++++------- .../bn254-multiplier/src/rne/simd_utils.rs | 36 ++++++++++--------- 4 files changed, 63 insertions(+), 30 deletions(-) diff --git a/skyscraper/bn254-multiplier/src/rne/constants.rs b/skyscraper/bn254-multiplier/src/rne/constants.rs index 47ade0b3..6f320cf5 100644 --- a/skyscraper/bn254-multiplier/src/rne/constants.rs +++ b/skyscraper/bn254-multiplier/src/rne/constants.rs @@ -1,7 +1,11 @@ +//! Constants for RNE Montgomery multiplication over the BN254 scalar field. + use crate::pow_2; +/// Montgomery reduction constant: `-p⁻¹ mod 2⁵¹` pub const U51_NP0: u64 = 0x1f593efffffff; +/// The BN254 scalar field prime in 51-bit limb representation. pub const U51_P: [u64; 5] = [ 0x1f593f0000001, 0x10f372e12287c, @@ -10,8 +14,10 @@ pub const U51_P: [u64; 5] = [ 0x30644e72e131a, ]; +/// Bit mask for 51-bit limbs. pub const MASK51: u64 = 2_u64.pow(51) - 1; +/// Reduction constants: `RHO_i = 2^(51*i) * 2^255 mod p` in 51-bit limbs. pub const RHO_1: [u64; 5] = [ 0x05cc89dc987a4, 0x64e24f262c77a, diff --git a/skyscraper/bn254-multiplier/src/rne/mod.rs b/skyscraper/bn254-multiplier/src/rne/mod.rs index b66b1b03..415090bd 100644 --- a/skyscraper/bn254-multiplier/src/rne/mod.rs +++ b/skyscraper/bn254-multiplier/src/rne/mod.rs @@ -1,3 +1,27 @@ +//! # RNE - Round-to-Nearest-Even Montgomery Multiplication +//! +//! This module implements Montgomery multiplication over the BN254 scalar field +//! using floating-point arithmetic with round-to-nearest-even (RNE) rounding +//! mode. +//! +//! ## Why Floating-Point? +//! +//! On WASM and ARM Cortex, integer multiplication has lower throughput +//! than floating-point FMA (fused multiply-add). By encoding +//! 51-bit limbs into the mantissa of f64 values we can perform integer +//! multiplication using FMA. +//! +//! ## Representation +//! +//! Field elements are stored in a 5-limb redundant form with 51 bits per limb +//! (5 × 51 = 255 bits), allowing representation of values up to 2²⁵⁵ - 1. +//! +//! ## References +//! +//! Variation of "Faster Modular Exponentiation using Double Precision Floating +//! Point Arithmetic on the GPU, 2018 IEEE 25th Symposium on Computer Arithmetic +//! (ARITH) by Emmart, Zheng and Weems; which uses RTZ. + pub mod constants; pub mod portable_simd; pub mod simd_utils; diff --git a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs index 94aeb03b..4aa7fd9f 100644 --- a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs +++ b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs @@ -1,3 +1,8 @@ +//! Portable SIMD Montgomery multiplication and squaring. +//! +//! Processes two independent field multiplications in parallel using 2-lane +//! SIMD. + use { crate::rne::{ constants::*, @@ -14,6 +19,8 @@ use { std::simd::num::{SimdInt, SimdUint}, }; +/// Two parallel Montgomery squarings: `(v0², v1²)`. +/// input must fit in 2^255-1; no runtime checking #[inline] pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { let v0_a = u256_to_u255_simd(transpose_u256_to_simd([v0_a, v1_a])); @@ -31,8 +38,8 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { } } - // On most instruction sets SIMD shift left is more expensive than SIMD - // addition. While for scalar they tend to cost the same. + // Most shifting operations are more expensive addition thus for multiplying by + // 2 we use addition. for i in 1..=8 { t[i] += t[i]; } @@ -75,20 +82,19 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - // The upper bits of s will not affect the lower 51 bits of the product so we - // defer the and'ing. + // The upper bits of s will not affect the lower 51 bits of the product and + // therefore we only have to bitmask once. let m = (s[0].cast() * Simd::splat(U51_NP0)).bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U51_P); let mut addi = addv_simd(s, mp); - // Move over carries before dropping last limb + // Apply carries before dropping the last limb addi[1] += addi[0] >> 51; let addi = [addi[1], addi[2], addi[3], addi[4], addi[5]]; // 1 bit reduction to go from R^-255 to R^-256. reduce_ct does the preparation // and the final shift is done as part of the conversion back to u256 let reduced = reduce_ct_simd(addi); - // Are the following two shifts fused? let reduced = redundant_carry(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); @@ -112,8 +118,9 @@ fn redundant_carry(t: [Simd; N]) -> [Simd; N] { res } +/// Two parallel Montgomery multiplications: `(v0_a*v0_b, v1_a*v1_b)`. +/// input must fit in 2^255-1; no runtime checking #[inline(always)] -/// Montgomery multiplier pub fn simd_mul( v0_a: [u64; 4], v0_b: [u64; 4], @@ -276,8 +283,6 @@ pub fn simd_mul( t[3] += t[2] >> 51; t[4] += t[3] >> 51; - // lower 51 bits will have the right value as the carry part is either 0 or a - // multiple of -2^51 -> which prevents carry bits to leak into the lower part. let r0 = smult_noinit_simd(t[0].cast().bitand(Simd::splat(MASK51)), RHO_4); let r1 = smult_noinit_simd(t[1].cast().bitand(Simd::splat(MASK51)), RHO_3); let r2 = smult_noinit_simd(t[2].cast().bitand(Simd::splat(MASK51)), RHO_2); @@ -292,20 +297,16 @@ pub fn simd_mul( r0[5] + r1[5] + r2[5] + r3[5] + t[9], ]; - // The upper bits of s will not affect the lower 51 bits of the product so we - // defer the and'ing. let m = (s[0].cast() * Simd::splat(U51_NP0)).bitand(Simd::splat(MASK51)); let mp = smult_noinit_simd(m, U51_P); let mut addi = addv_simd(s, mp); - // Move over carries before dropping last limb addi[1] += addi[0] >> 51; let addi = [addi[1], addi[2], addi[3], addi[4], addi[5]]; // 1 bit reduction to go from R^-255 to R^-256. reduce_ct does the preparation // and the final shift is done as part of the conversion back to u256 let reduced = reduce_ct_simd(addi); - // Are the following two shifts fused? let reduced = redundant_carry(reduced); let u256_result = u255_to_u256_shr_1_simd(reduced); let v = transpose_simd_to_u256(u256_result); diff --git a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs index b8a2b3c7..c66786be 100644 --- a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs +++ b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs @@ -1,3 +1,5 @@ +//! SIMD utilities for RNE Montgomery multiplication. + use { crate::rne::constants::{C1, C2, C3, MASK51, U51_P}, core::{ @@ -11,9 +13,6 @@ use { }, std::simd::{LaneCount, SupportedLaneCount}, }; - -// -- [SIMD UTILS] -// --------------------------------------------------------------------------------- #[inline(always)] /// On WASM there is no single specialised instruction to cast an integer to a /// float. Since we are only interested in 52 bits, we can emulate it with fewer @@ -25,7 +24,7 @@ pub fn i2f(a: Simd) -> Simd where LaneCount: SupportedLaneCount, { - // This function has not target gating as we want to verify this function with + // This function has no target gating as we want to verify this function with // kani and proptest on a different platform than wasm // By adding 2^52 represented as float (0x1p52) -> 0x433 << 52, we align the @@ -38,6 +37,7 @@ where a - b } +/// Fused multiply-add: `a * b + c`. #[inline(always)] pub fn fma(a: Simd, b: Simd, c: Simd) -> Simd { #[cfg(not(target_arch = "wasm32"))] @@ -53,6 +53,10 @@ pub fn fma(a: Simd, b: Simd, c: Simd) -> Simd { } } +/// Computes bias compensation for accumulator limbs. +/// +/// - `low_count`: number of p_lo contributions +/// - `high_count`: number of p_hi contributions #[inline(always)] pub const fn make_initial(low_count: u64, high_count: u64) -> i64 { let val = high_count @@ -61,9 +65,9 @@ pub const fn make_initial(low_count: u64, high_count: u64) -> i64 { -(val as i64) } +/// Transpose two 4-limb values into 4 SIMD vectors. #[inline(always)] pub fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { - // This does not issue multiple ldp and zip which might be marginally faster. [ Simd::from_array([limbs[0][0], limbs[1][0]]), Simd::from_array([limbs[0][1], limbs[1][1]]), @@ -72,6 +76,7 @@ pub fn transpose_u256_to_simd(limbs: [[u64; 4]; 2]) -> [Simd; 4] { ] } +/// Transpose 4 SIMD vectors back to two 4-limb values. #[inline(always)] pub fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { let tmp0 = limbs[0].to_array(); @@ -83,16 +88,14 @@ pub fn transpose_simd_to_u256(limbs: [Simd; 4]) -> [[u64; 4]; 2] { ]] } +/// Convert 4×64-bit to 5×51-bit limb representation. +/// Input must fit in 255 bits; no runtime checking. #[inline(always)] -/// Safety: If the input is too large for the conversion the top bit will be -/// discarded. In debug mode it will throw an error. pub fn u256_to_u255_simd(limbs: [Simd; 4]) -> [Simd; 5] where LaneCount: SupportedLaneCount, { let [l0, l1, l2, l3] = limbs; - // Check whether the remainder of l3 fits in 51 bits -> does the input fit in - // 255 bits. [ (l0) & Simd::splat(MASK51), ((l0 >> 51) | (l1 << 13)) & Simd::splat(MASK51), @@ -102,6 +105,7 @@ where ] } +/// Convert 5×51-bit back to 4×64-bit limb representation. #[inline(always)] pub fn u255_to_u256_simd(limbs: [Simd; 5]) -> [Simd; 4] where @@ -116,6 +120,7 @@ where ] } +/// Convert 5×51-bit to 4×64-bit with simultaneous division by 2. #[inline(always)] pub fn u255_to_u256_shr_1_simd(limbs: [Simd; 5]) -> [Simd; 4] where @@ -130,9 +135,9 @@ where ] } +/// Multiply SIMD scalar by 5-limb constant using FMA splitting. +/// Returns 6-limb result in redundant signed form. #[inline(always)] -// TODO check whether as f64 get's properly optimised away -// won't be able to tell using just assembly view pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { let mut t = [Simd::splat(0); 6]; let s: Simd = i2f(s); @@ -165,13 +170,9 @@ pub fn smult_noinit_simd(s: Simd, v: [u64; 5]) -> [Simd; 6] { t } +/// Constant-time conditional add of p to prepare for final bit reduction by +/// making the result even. #[inline(always)] -/// Resolve the carry bits in the upper parts 13b and prepare result for final -/// shift by adding p if the result is odd. -/// The final division will be taken care off by the bit packing -/// technically converts from a i64 representation to a u64 representation -/// drops off the lowest limb which got zerood out, but it still contains -/// carries as it is in redundant form pub fn reduce_ct_simd(a: [Simd; 5]) -> [Simd; 5] { let mut c = [Simd::splat(0); 5]; let tmp = a[0]; @@ -196,6 +197,7 @@ pub fn reduce_ct_simd(a: [Simd; 5]) -> [Simd; 5] { c } +/// Element-wise vector addition in redundant form. #[inline(always)] pub fn addv_simd( va: [Simd; N], From 70c18ff85f5b57453ef6a67c698e7b1cfb86930f Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 26 Jan 2026 16:23:41 +0800 Subject: [PATCH 36/37] b51: i2f kani --- .../bn254-multiplier/src/rne/portable_simd.rs | 5 +++- .../bn254-multiplier/src/rne/simd_utils.rs | 27 ++++++++++++------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs index 4aa7fd9f..dcaeaa52 100644 --- a/skyscraper/bn254-multiplier/src/rne/portable_simd.rs +++ b/skyscraper/bn254-multiplier/src/rne/portable_simd.rs @@ -105,7 +105,10 @@ pub fn simd_sqr(v0_a: [u64; 4], v1_a: [u64; 4]) -> ([u64; 4], [u64; 4]) { /// limbs except the last one is 51 bits. The most significant limb can be /// larger than 51 bits as the input can be bigger 2^255-1. #[inline(always)] -fn redundant_carry(t: [Simd; N]) -> [Simd; N] { +fn redundant_carry(t: [Simd; N]) -> [Simd; N] +where + std::simd::LaneCount: std::simd::SupportedLaneCount, +{ let mut borrow = Simd::splat(0); let mut res = [Simd::splat(0); N]; for i in 0..t.len() - 1 { diff --git a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs index c66786be..e637cd55 100644 --- a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs +++ b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs @@ -212,14 +212,10 @@ pub fn addv_simd( #[cfg(kani)] mod tests { - use std::simd::Simd; - - fn u255_to_u256(u: [u64; 5]) -> [u64; 4] { - crate::rne::simd_utils::u255_to_u256_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) - } - fn u256_to_u255(u: [u64; 4]) -> [u64; 5] { - crate::rne::simd_utils::u256_to_u255_simd::<1>(u.map(Simd::splat)).map(|v| v[0]) - } + use { + crate::rne::simd_utils::{i2f, u255_to_u256_simd, u256_to_u255_simd}, + std::simd::Simd, + }; #[kani::proof] fn u256_to_u255_kani_roundtrip() { @@ -229,6 +225,19 @@ mod tests { kani::any(), kani::any::() & 0x7fffffffffffffff, ]; - assert_eq!(u, u255_to_u256(u256_to_u255(u))) + let u255 = u256_to_u255_simd::<1>(u.map(Simd::splat)); + let roundtrip = u255_to_u256_simd::<1>(u255).map(|v| v[0]); + assert_eq!(u, roundtrip) + } + + /// Verify that i2f correctly converts integers in the valid range [0, 2^52). + #[kani::proof] + fn i2f_kani_correctness() { + let val: u64 = kani::any(); + kani::assume(val < (1u64 << 52)); + + let result = i2f(Simd::from_array([val])); + + assert_eq!(result[0], val as f64); } } From 62f391d2dcb65eab4dfd5894e4beadd05ec38384 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Mon, 26 Jan 2026 16:41:39 +0800 Subject: [PATCH 37/37] fixup! b51: i2f kani --- skyscraper/bn254-multiplier/src/rne/simd_utils.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs index e637cd55..b0054b08 100644 --- a/skyscraper/bn254-multiplier/src/rne/simd_utils.rs +++ b/skyscraper/bn254-multiplier/src/rne/simd_utils.rs @@ -230,7 +230,8 @@ mod tests { assert_eq!(u, roundtrip) } - /// Verify that i2f correctly converts integers in the valid range [0, 2^52). + /// Verify that i2f correctly converts integers in the valid range [0, + /// 2^52). #[kani::proof] fn i2f_kani_correctness() { let val: u64 = kani::any();