From a420f547cdc2f0dbf4fdeaf92dbb262fd5dc727e Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Fri, 5 Dec 2025 18:09:11 -0500 Subject: [PATCH 1/7] draft macros --- Cargo.lock | 10 ++ Cargo.toml | 11 +- extensions/ecc/curve-macros/Cargo.toml | 16 ++ extensions/ecc/curve-macros/src/lib.rs | 235 +++++++++++++++++++++++++ extensions/ecc/guest/Cargo.toml | 1 + extensions/ecc/guest/src/lib.rs | 3 + extensions/ecc/transpiler/src/lib.rs | 21 +++ 7 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 extensions/ecc/curve-macros/Cargo.toml create mode 100644 extensions/ecc/curve-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 79efdb11af..9953d1e35a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6499,6 +6499,15 @@ dependencies = [ "strum 0.26.3", ] +[[package]] +name = "openvm-ecc-curve-macros" +version = "1.4.1" +dependencies = [ + "openvm-macros-common", + "quote", + "syn 2.0.106", +] + [[package]] name = "openvm-ecc-guest" version = "1.4.2-rc.1" @@ -6511,6 +6520,7 @@ dependencies = [ "openvm", "openvm-algebra-guest", "openvm-custom-insn", + "openvm-ecc-curve-macros", "openvm-ecc-sw-macros", "openvm-rv32im-guest", "serde", diff --git a/Cargo.toml b/Cargo.toml index aad568e77c..5ce606b898 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ members = [ "extensions/ecc/transpiler", "extensions/ecc/guest", "extensions/ecc/sw-macros", + "extensions/ecc/curve-macros", "extensions/ecc/tests", "extensions/pairing/circuit", "extensions/pairing/guest", @@ -168,6 +169,7 @@ openvm-ecc-circuit = { path = "extensions/ecc/circuit", default-features = false openvm-ecc-transpiler = { path = "extensions/ecc/transpiler", default-features = false } openvm-ecc-guest = { path = "extensions/ecc/guest", default-features = false } openvm-ecc-sw-macros = { path = "extensions/ecc/sw-macros", default-features = false } +openvm-ecc-curve-macros = { path = "extensions/ecc/curve-macros", default-features = false } openvm-pairing-circuit = { path = "extensions/pairing/circuit", default-features = false } openvm-pairing-transpiler = { path = "extensions/pairing/transpiler", default-features = false } openvm-pairing-guest = { path = "extensions/pairing/guest", default-features = false } @@ -189,7 +191,10 @@ p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539b p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } zkhash = { git = "https://github.com/HorizenLabs/poseidon2.git", rev = "bb476b9" } -snark-verifier-sdk = { version = "0.2.0", default-features = false, features = ["loader_halo2", "halo2-axiom"] } +snark-verifier-sdk = { version = "0.2.0", default-features = false, features = [ + "loader_halo2", + "halo2-axiom", +] } snark-verifier = { version = "0.2.0", default-features = false } halo2curves-axiom = { git = "https://github.com/axiom-crypto/halo2curves.git", tag = "v0.7.2" } @@ -204,7 +209,9 @@ clap = "4.5.23" toml = "0.8.14" lazy_static = "1.5.0" derive-new = "0.6.0" -derive_more = { version = "1.0.0", features = ["display"], default-features = false } +derive_more = { version = "1.0.0", features = [ + "display", +], default-features = false } derivative = "2.2.0" strum_macros = "0.26.4" strum = { version = "0.26.3", features = ["derive"] } diff --git a/extensions/ecc/curve-macros/Cargo.toml b/extensions/ecc/curve-macros/Cargo.toml new file mode 100644 index 0000000000..898a77e00c --- /dev/null +++ b/extensions/ecc/curve-macros/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "openvm-ecc-curve-macros" +description = "OpenVM elliptic curve macros for intrinsic curves" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +openvm-macros-common = { workspace = true, default-features = false } + +[lib] +proc-macro = true diff --git a/extensions/ecc/curve-macros/src/lib.rs b/extensions/ecc/curve-macros/src/lib.rs new file mode 100644 index 0000000000..313494e135 --- /dev/null +++ b/extensions/ecc/curve-macros/src/lib.rs @@ -0,0 +1,235 @@ +extern crate proc_macro; + +use openvm_macros_common::MacroArgs; +use proc_macro::TokenStream; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, ExprPath, LitStr, Token, +}; + +/// This macro generates the code to setup the intrinsic curve for a given point type and scalar +/// type. Also it places the curve parameters into a special static variable to be later extracted +/// from the ELF and used by the VM. Usage: +/// ``` +/// curve_declare! { +/// [TODO] +/// } +/// ``` +/// +/// For this macro to work, you must import the `openvm_ecc_guest` crate. +#[proc_macro] +pub fn curve_declare(input: TokenStream) -> TokenStream { + let MacroArgs { items } = parse_macro_input!(input as MacroArgs); + + let mut output = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for item in items.into_iter() { + let struct_name_str = item.name.to_string(); + let struct_name = syn::Ident::new(&struct_name_str, span.into()); + let mut point_type: Option = None; + let mut scalar_type: Option = None; + for param in item.params { + match param.name.to_string().as_str() { + // Note that point_type and scalar_type must be valid types + "point_type" => { + if let syn::Expr::Path(ExprPath { path, .. }) = param.value { + point_type = Some(path) + } else { + return syn::Error::new_spanned(param.value, "Expected a type") + .to_compile_error() + .into(); + } + } + "scalar_type" => { + if let syn::Expr::Path(ExprPath { path, .. }) = param.value { + scalar_type = Some(path) + } else { + return syn::Error::new_spanned(param.value, "Expected a type") + .to_compile_error() + .into(); + } + } + _ => { + panic!("Unknown parameter {}", param.name); + } + } + } + + let point_type = point_type.expect("point_type parameter is required"); + let scalar_type = scalar_type.expect("scalar_type parameter is required"); + + macro_rules! create_extern_func { + ($name:ident) => { + let $name = syn::Ident::new( + &format!("{}_{}", stringify!($name), struct_name_str), + span.into(), + ); + }; + } + create_extern_func!(curve_ec_mul_extern_func); + create_extern_func!(curve_setup_extern_func); + + let result = TokenStream::from(quote::quote_spanned! { span.into() => + extern "C" { + fn #curve_ec_mul_extern_func(rd: usize, rs1: usize, rs2: usize); + fn #curve_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8); + } + + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] + #[repr(C)] + pub struct #struct_name; + #[allow(non_upper_case_globals)] + + impl ::openvm_ecc_guest::weierstrass::IntrinsicCurve for #struct_name { + type Scalar = #scalar_type; + type Point = #point_type; + + #[inline(always)] + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point + where + for<'a> &'a Self::Point: core::ops::Add<&'a Self::Point, Output = Self::Point>, + { + #[cfg(not(target_os = "zkvm"))] + { + // heuristic + if coeffs.len() < 25 { + let table = ::openvm_ecc_guest::weierstrass::CachedMulTable::::new_with_prime_order(bases, 4); + table.windowed_mul(coeffs) + } else { + ::openvm_ecc_guest::msm(coeffs, bases) + } + } + #[cfg(target_os = "zkvm")] + { + if CHECK_SETUP { + Self::set_up_once(); + } + let mut acc = ::IDENTITY; + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + unsafe { + let mut uninit: core::mem::MaybeUninit = + core::mem::MaybeUninit::uninit(); + #curve_ec_mul_extern_func( + uninit.as_mut_ptr() as usize, + coeff as *const Self::Scalar as usize, + base as *const Self::Point as usize, + ); + acc.add_assign(&uninit.assume_init()); + } + } + acc + } + } + + // Helper function to call the setup instruction on first use + #[inline(always)] + #[cfg(target_os = "zkvm")] + fn set_up_once() { + static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); + + is_setup.get_or_init(|| { + let scalar_modulus_bytes = ::MODULUS; + let point_modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; + let p1 = scalar_modulus_bytes.as_ref(); + let p2 = [point_modulus_bytes.as_ref(), point_modulus_bytes.as_ref()].concat(); + let mut uninit: core::mem::MaybeUninit<[Self::Scalar, Self::Point]> = core::mem::MaybeUninit::uninit(); + + unsafe { #curve_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } + ::set_up_once(); + ::set_up_once(); + true + }); + } + + #[inline(always)] + #[cfg(not(target_os = "zkvm"))] + fn set_up_once() { + // No-op for non-ZKVM targets + } + } + }); + output.push(result); + } + + TokenStream::from_iter(output) +} + +struct CurveDefine { + items: Vec, +} + +impl Parse for CurveDefine { + fn parse(input: ParseStream) -> syn::Result { + let items = input.parse_terminated(::parse, Token![,])?; + Ok(Self { + items: items.into_iter().map(|e| e.value()).collect(), + }) + } +} + +#[proc_macro] +pub fn curve_init(input: TokenStream) -> TokenStream { + let CurveDefine { items } = parse_macro_input!(input as CurveDefine); + + let mut externs = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for (curve_idx, struct_id) in items.into_iter().enumerate() { + // Unique identifier shared by curve_declare! and curve_init! used for naming the extern + // funcs. Currently it's just the struct type name. + let ec_mul_extern_func = syn::Ident::new( + &format!("curve_ec_mul_extern_func_{}", struct_id), + span.into(), + ); + let setup_extern_func = syn::Ident::new( + &format!("curve_setup_extern_func_{}", struct_id), + span.into(), + ); + + externs.push(quote::quote_spanned! { span.into() => + #[no_mangle] + extern "C" fn #ec_mul_extern_func(rd: usize, rs1: usize, rs2: usize) { + openvm::platform::custom_insn_r!( + opcode = openvm_ecc_guest::OPCODE, + funct3 = openvm_ecc_guest::SW_FUNCT3 as usize, + funct7 = openvm_ecc_guest::SwBaseFunct7::SwEcMul as usize + + #curve_idx + * (openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), + rd = In rd, + rs1 = In rs1, + rs2 = In rs2 + ); + } + + #[no_mangle] + extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) { + #[cfg(target_os = "zkvm")] + { + openvm::platform::custom_insn_r!( + opcode = ::openvm_ecc_guest::OPCODE, + funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize, + funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetupMul as usize + + #curve_idx + * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), + rd = In uninit, + rs1 = In p1, + rs2 = In p2 + ); + } + } + }); + } + + TokenStream::from(quote::quote_spanned! { span.into() => + #[allow(non_snake_case)] + #[cfg(target_os = "zkvm")] + mod openvm_intrinsics_ffi_2 { + use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7}; + + #(#externs)* + } + }) +} diff --git a/extensions/ecc/guest/Cargo.toml b/extensions/ecc/guest/Cargo.toml index e5251eb366..c4ca2ee9f8 100644 --- a/extensions/ecc/guest/Cargo.toml +++ b/extensions/ecc/guest/Cargo.toml @@ -17,6 +17,7 @@ openvm-custom-insn = { workspace = true } openvm-rv32im-guest = { workspace = true } openvm-algebra-guest = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-curve-macros = { workspace = true } once_cell = { workspace = true, features = ["race", "alloc"] } # Used for `halo2curves` feature diff --git a/extensions/ecc/guest/src/lib.rs b/extensions/ecc/guest/src/lib.rs index c7a9851cfd..8173659355 100644 --- a/extensions/ecc/guest/src/lib.rs +++ b/extensions/ecc/guest/src/lib.rs @@ -5,6 +5,7 @@ extern crate alloc; pub use once_cell; pub use openvm_algebra_guest as algebra; +pub use openvm_ecc_curve_macros as curve_macros; pub use openvm_ecc_sw_macros as sw_macros; use strum_macros::FromRepr; @@ -32,6 +33,8 @@ pub enum SwBaseFunct7 { SwAddNe = 0, SwDouble, SwSetup, + SwEcMul, + SwSetupMul, } impl SwBaseFunct7 { diff --git a/extensions/ecc/transpiler/src/lib.rs b/extensions/ecc/transpiler/src/lib.rs index 462e95dbdd..577162b93f 100644 --- a/extensions/ecc/transpiler/src/lib.rs +++ b/extensions/ecc/transpiler/src/lib.rs @@ -19,6 +19,8 @@ pub enum Rv32WeierstrassOpcode { SETUP_EC_ADD_NE, EC_DOUBLE, SETUP_EC_DOUBLE, + EC_MUL, + SETUP_EC_MUL, } #[derive(Default)] @@ -65,6 +67,22 @@ impl TranspilerExtension for EccTranspilerExtension { F::ZERO, F::ZERO, )) + } else if base_funct7 == SwBaseFunct7::SwEcMul as u8 { + Some(Instruction::new( + VmOpcode::from_usize( + Rv32WeierstrassOpcode::SETUP_EC_MUL + .global_opcode() + .as_usize() + + curve_idx_shift, + ), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2), + F::ONE, // d_as = 1 + F::TWO, // e_as = 2 + F::ZERO, + F::ZERO, + )) } else { let global_opcode = match SwBaseFunct7::from_repr(base_funct7) { Some(SwBaseFunct7::SwAddNe) => { @@ -76,6 +94,9 @@ impl TranspilerExtension for EccTranspilerExtension { Rv32WeierstrassOpcode::EC_DOUBLE as usize + Rv32WeierstrassOpcode::CLASS_OFFSET } + Some(SwBaseFunct7::SwEcMul) => { + Rv32WeierstrassOpcode::EC_MUL as usize + Rv32WeierstrassOpcode::CLASS_OFFSET + } _ => unimplemented!(), }; let global_opcode = global_opcode + curve_idx_shift; From ef1ca7c3a7bb4c8adac520d380cc76e69715fd15 Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Fri, 5 Dec 2025 18:46:13 -0500 Subject: [PATCH 2/7] impl guest --- Cargo.lock | 3 +++ extensions/ecc/curve-macros/src/lib.rs | 2 +- extensions/ecc/guest/src/ecdsa.rs | 4 ++-- extensions/ecc/guest/src/weierstrass.rs | 5 +++- guest-libs/k256/Cargo.toml | 1 + guest-libs/k256/src/internal.rs | 27 ++------------------- guest-libs/k256/src/lib.rs | 6 ++--- guest-libs/k256/src/point.rs | 12 +++++----- guest-libs/p256/Cargo.toml | 1 + guest-libs/p256/src/internal.rs | 26 ++------------------- guest-libs/p256/src/lib.rs | 6 ++--- guest-libs/p256/src/point.rs | 12 +++++----- guest-libs/pairing/Cargo.toml | 1 + guest-libs/pairing/src/bls12_381/mod.rs | 13 +++-------- guest-libs/pairing/src/bn254/mod.rs | 31 ++++--------------------- 15 files changed, 43 insertions(+), 107 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9953d1e35a..389eb39826 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5175,6 +5175,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-circuit", "openvm-ecc-circuit", + "openvm-ecc-curve-macros", "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", @@ -6845,6 +6846,7 @@ dependencies = [ "openvm-circuit", "openvm-custom-insn", "openvm-ecc-circuit", + "openvm-ecc-curve-macros", "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", @@ -7372,6 +7374,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-circuit", "openvm-ecc-circuit", + "openvm-ecc-curve-macros", "openvm-ecc-guest", "openvm-ecc-sw-macros", "openvm-ecc-transpiler", diff --git a/extensions/ecc/curve-macros/src/lib.rs b/extensions/ecc/curve-macros/src/lib.rs index 313494e135..daf9a7161f 100644 --- a/extensions/ecc/curve-macros/src/lib.rs +++ b/extensions/ecc/curve-macros/src/lib.rs @@ -134,7 +134,7 @@ pub fn curve_declare(input: TokenStream) -> TokenStream { let point_modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; let p1 = scalar_modulus_bytes.as_ref(); let p2 = [point_modulus_bytes.as_ref(), point_modulus_bytes.as_ref()].concat(); - let mut uninit: core::mem::MaybeUninit<[Self::Scalar, Self::Point]> = core::mem::MaybeUninit::uninit(); + let mut uninit: core::mem::MaybeUninit<(Self::Scalar, Self::Point)> = core::mem::MaybeUninit::uninit(); unsafe { #curve_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } ::set_up_once(); diff --git a/extensions/ecc/guest/src/ecdsa.rs b/extensions/ecc/guest/src/ecdsa.rs index 07fc6d44fc..c05250fb9a 100644 --- a/extensions/ecc/guest/src/ecdsa.rs +++ b/extensions/ecc/guest/src/ecdsa.rs @@ -487,7 +487,7 @@ where let neg_u1 = z.div_unsafe(&r); let u2 = s.div_unsafe(&r); let NEG_G = C::Point::NEG_GENERATOR; - let point = ::msm(&[neg_u1, u2], &[NEG_G, R]); + let point = ::msm::(&[neg_u1, u2], &[NEG_G, R]); let vk = VerifyingKey::from_affine(point)?; Ok(vk) @@ -533,7 +533,7 @@ where let G = C::Point::GENERATOR; // public key let Q = pubkey; - let R = ::msm(&[u1, u2], &[G, Q]); + let R = ::msm::(&[u1, u2], &[G, Q]); // For Coordinate: IntMod, the internal implementation of is_identity will assert x, y // coordinates of R are both reduced. if R.is_identity() { diff --git a/extensions/ecc/guest/src/weierstrass.rs b/extensions/ecc/guest/src/weierstrass.rs index 82d5468b04..0a9748558a 100644 --- a/extensions/ecc/guest/src/weierstrass.rs +++ b/extensions/ecc/guest/src/weierstrass.rs @@ -136,7 +136,10 @@ pub trait IntrinsicCurve { /// Multi-scalar multiplication. /// The implementation may be specialized to use properties of the curve /// (e.g., if the curve order is prime). - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; + + /// Setup the curve. + fn set_up_once(); } // MSM using preprocessed table (windowed method) diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 11b2abccfa..0a5c6eb23c 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -17,6 +17,7 @@ openvm-algebra-guest = { workspace = true } openvm-algebra-moduli-macros = { workspace = true } openvm-ecc-guest = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-curve-macros = { workspace = true } once_cell = { workspace = true, optional = true } elliptic-curve = { workspace = true } diff --git a/guest-libs/k256/src/internal.rs b/guest-libs/k256/src/internal.rs index b8f8857dc9..ed40273dbe 100644 --- a/guest-libs/k256/src/internal.rs +++ b/guest-libs/k256/src/internal.rs @@ -1,16 +1,11 @@ -use core::ops::{Add, Neg}; +use core::ops::Neg; use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, Group}; use openvm_ecc_sw_macros::sw_declare; -use crate::Secp256k1; - // --- Define the OpenVM modular arithmetic and ecc types --- const CURVE_B: Secp256k1Coord = Secp256k1Coord::from_const_bytes(seven_le()); @@ -52,24 +47,6 @@ impl CyclicGroup for Secp256k1Point { }; } -impl IntrinsicCurve for Secp256k1 { - type Scalar = Secp256k1Scalar; - type Point = Secp256k1Point; - - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point - where - for<'a> &'a Self::Point: Add<&'a Self::Point, Output = Self::Point>, - { - // heuristic - if coeffs.len() < 25 { - let table = CachedMulTable::::new_with_prime_order(bases, 4); - table.windowed_mul(coeffs) - } else { - openvm_ecc_guest::msm(coeffs, bases) - } - } -} - // --- Implement helpful methods mimicking the structs in k256 --- impl Secp256k1Point { diff --git a/guest-libs/k256/src/lib.rs b/guest-libs/k256/src/lib.rs index 992fd802cb..85abc19f46 100644 --- a/guest-libs/k256/src/lib.rs +++ b/guest-libs/k256/src/lib.rs @@ -21,9 +21,9 @@ pub use internal::{ Secp256k1Point as ProjectivePoint, Secp256k1Scalar as Scalar, Secp256k1Scalar, }; -// -- Define the ZST for implementing the elliptic curve traits -- -#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] -pub struct Secp256k1; +openvm_ecc_curve_macros::curve_declare! { + Secp256k1 { point_type = Secp256k1Point, scalar_type = Secp256k1Scalar }, +} // --- Implement the Curve trait on Secp256k1 --- diff --git a/guest-libs/k256/src/point.rs b/guest-libs/k256/src/point.rs index 7ab76c1126..69f7d8e2fc 100644 --- a/guest-libs/k256/src/point.rs +++ b/guest-libs/k256/src/point.rs @@ -103,7 +103,7 @@ impl Mul for Secp256k1Point { type Output = Secp256k1Point; fn mul(self, other: Secp256k1Scalar) -> Secp256k1Point { - Secp256k1::msm(&[other], &[self]) + Secp256k1::msm::(&[other], &[self]) } } @@ -111,7 +111,7 @@ impl Mul<&Secp256k1Scalar> for &Secp256k1Point { type Output = Secp256k1Point; fn mul(self, other: &Secp256k1Scalar) -> Secp256k1Point { - Secp256k1::msm(&[*other], &[*self]) + Secp256k1::msm::(&[*other], &[*self]) } } @@ -119,19 +119,19 @@ impl Mul<&Secp256k1Scalar> for Secp256k1Point { type Output = Secp256k1Point; fn mul(self, other: &Secp256k1Scalar) -> Secp256k1Point { - Secp256k1::msm(&[*other], &[self]) + Secp256k1::msm::(&[*other], &[self]) } } impl MulAssign for Secp256k1Point { fn mul_assign(&mut self, rhs: Secp256k1Scalar) { - *self = Secp256k1::msm(&[rhs], &[*self]); + *self = Secp256k1::msm::(&[rhs], &[*self]); } } impl MulAssign<&Secp256k1Scalar> for Secp256k1Point { fn mul_assign(&mut self, rhs: &Secp256k1Scalar) { - *self = Secp256k1::msm(&[*rhs], &[*self]); + *self = Secp256k1::msm::(&[*rhs], &[*self]); } } @@ -170,7 +170,7 @@ impl elliptic_curve::group::Curve for Secp256k1Point { impl LinearCombination for Secp256k1Point { fn lincomb(x: &Self, k: &Self::Scalar, y: &Self, l: &Self::Scalar) -> Self { - Secp256k1::msm(&[*k, *l], &[*x, *y]) + Secp256k1::msm::(&[*k, *l], &[*x, *y]) } } diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 5bcb547846..12916e0a05 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -17,6 +17,7 @@ openvm-algebra-guest = { workspace = true } openvm-algebra-moduli-macros = { workspace = true } openvm-ecc-guest = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-curve-macros = { workspace = true } elliptic-curve = { workspace = true, features = ["hazmat", "sec1"] } ecdsa-core = { version = "0.16.9", package = "ecdsa", optional = true, default-features = false, features = [ diff --git a/guest-libs/p256/src/internal.rs b/guest-libs/p256/src/internal.rs index b98c401c8c..75d19d55d2 100644 --- a/guest-libs/p256/src/internal.rs +++ b/guest-libs/p256/src/internal.rs @@ -1,16 +1,11 @@ -use core::ops::{Add, Neg}; +use core::ops::Neg; use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, Group}; use openvm_ecc_sw_macros::sw_declare; -use crate::NistP256; - // --- Define the OpenVM modular arithmetic and ecc types --- moduli_declare! { @@ -53,23 +48,6 @@ impl CyclicGroup for P256Point { }; } -impl IntrinsicCurve for NistP256 { - type Scalar = P256Scalar; - type Point = P256Point; - - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point - where - for<'a> &'a Self::Point: Add<&'a Self::Point, Output = Self::Point>, - { - if coeffs.len() < 25 { - let table = CachedMulTable::::new_with_prime_order(bases, 4); - table.windowed_mul(coeffs) - } else { - openvm_ecc_guest::msm(coeffs, bases) - } - } -} - // --- Implement helpful methods mimicking the structs in p256 --- impl P256Point { diff --git a/guest-libs/p256/src/lib.rs b/guest-libs/p256/src/lib.rs index a3492a12f7..7f09f1d3a6 100644 --- a/guest-libs/p256/src/lib.rs +++ b/guest-libs/p256/src/lib.rs @@ -19,9 +19,9 @@ pub mod ecdsa; // Needs to be public so that the `sw_init` macro can access it pub use internal::{P256Coord, P256Point, P256Scalar}; -// -- Define the ZST for implementing the elliptic curve traits -- -#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] -pub struct NistP256; +openvm_ecc_curve_macros::curve_declare! { + NistP256 { point_type = P256Point, scalar_type = P256Scalar }, +} // --- Implement the Curve trait on P256 --- diff --git a/guest-libs/p256/src/point.rs b/guest-libs/p256/src/point.rs index bddd0ba924..e10955e389 100644 --- a/guest-libs/p256/src/point.rs +++ b/guest-libs/p256/src/point.rs @@ -99,7 +99,7 @@ impl Mul for P256Point { type Output = P256Point; fn mul(self, other: P256Scalar) -> P256Point { - NistP256::msm(&[other], &[self]) + NistP256::msm::(&[other], &[self]) } } @@ -107,7 +107,7 @@ impl Mul<&P256Scalar> for &P256Point { type Output = P256Point; fn mul(self, other: &P256Scalar) -> P256Point { - NistP256::msm(&[*other], &[*self]) + NistP256::msm::(&[*other], &[*self]) } } @@ -115,19 +115,19 @@ impl Mul<&P256Scalar> for P256Point { type Output = P256Point; fn mul(self, other: &P256Scalar) -> P256Point { - NistP256::msm(&[*other], &[self]) + NistP256::msm::(&[*other], &[self]) } } impl MulAssign for P256Point { fn mul_assign(&mut self, rhs: P256Scalar) { - *self = NistP256::msm(&[rhs], &[*self]); + *self = NistP256::msm::(&[rhs], &[*self]); } } impl MulAssign<&P256Scalar> for P256Point { fn mul_assign(&mut self, rhs: &P256Scalar) { - *self = NistP256::msm(&[*rhs], &[*self]); + *self = NistP256::msm::(&[*rhs], &[*self]); } } @@ -166,7 +166,7 @@ impl elliptic_curve::group::Curve for P256Point { impl LinearCombination for P256Point { fn lincomb(x: &Self, k: &Self::Scalar, y: &Self, l: &Self::Scalar) -> Self { - NistP256::msm(&[*k, *l], &[*x, *y]) + NistP256::msm::(&[*k, *l], &[*x, *y]) } } diff --git a/guest-libs/pairing/Cargo.toml b/guest-libs/pairing/Cargo.toml index 774c61c968..e5e89cd079 100644 --- a/guest-libs/pairing/Cargo.toml +++ b/guest-libs/pairing/Cargo.toml @@ -18,6 +18,7 @@ openvm-algebra-guest = { workspace = true } openvm-algebra-moduli-macros = { workspace = true } openvm-ecc-guest = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-curve-macros = { workspace = true } openvm-algebra-complex-macros = { workspace = true } openvm-custom-insn = { workspace = true } openvm-rv32im-guest = { workspace = true } diff --git a/guest-libs/pairing/src/bls12_381/mod.rs b/guest-libs/pairing/src/bls12_381/mod.rs index 0a7c150e1c..f6b9f4ebab 100644 --- a/guest-libs/pairing/src/bls12_381/mod.rs +++ b/guest-libs/pairing/src/bls12_381/mod.rs @@ -4,7 +4,7 @@ use core::ops::Neg; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{weierstrass::IntrinsicCurve, CyclicGroup, Group}; +use openvm_ecc_guest::{CyclicGroup, Group}; mod fp12; mod fp2; @@ -64,15 +64,8 @@ impl CyclicGroup for G1Affine { }; } -pub struct Bls12_381; - -impl IntrinsicCurve for Bls12_381 { - type Scalar = Scalar; - type Point = G1Affine; - - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point { - openvm_ecc_guest::msm(coeffs, bases) - } +openvm_ecc_curve_macros::curve_declare! { + Bls12_381 { point_type = G1Affine, scalar_type = Scalar }, } // Define a G2Affine struct that implements curve operations using `Fp2` intrinsics diff --git a/guest-libs/pairing/src/bn254/mod.rs b/guest-libs/pairing/src/bn254/mod.rs index 8384b8b3e8..cdb09a9860 100644 --- a/guest-libs/pairing/src/bn254/mod.rs +++ b/guest-libs/pairing/src/bn254/mod.rs @@ -1,14 +1,11 @@ extern crate alloc; -use core::ops::{Add, Neg}; +use core::ops::Neg; use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve}, - CyclicGroup, Group, -}; +use openvm_ecc_guest::{CyclicGroup, Group}; use openvm_ecc_sw_macros::sw_declare; use openvm_pairing_guest::pairing::PairingIntrinsics; @@ -90,7 +87,9 @@ mod g2 { } } -pub struct Bn254; +openvm_ecc_curve_macros::curve_declare! { + Bn254 { point_type = G1Affine, scalar_type = Scalar }, +} impl Bn254 { // Same as the values from halo2curves_shims @@ -140,26 +139,6 @@ impl Bn254 { ); } -impl IntrinsicCurve for Bn254 { - type Scalar = Scalar; - type Point = G1Affine; - - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point - where - for<'a> &'a Self::Point: Add<&'a Self::Point, Output = Self::Point>, - { - // heuristic - if coeffs.len() < 25 { - // BN254(Fp) is of prime order by Weil conjecture: - // - let table = CachedMulTable::::new_with_prime_order(bases, 4); - table.windowed_mul(coeffs) - } else { - openvm_ecc_guest::msm(coeffs, bases) - } - } -} - impl PairingIntrinsics for Bn254 { type Fp = Fp; type Fp2 = Fp2; From e25ddbd08e002c43fa711446f2dbc9bc56684f4e Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Sat, 6 Dec 2025 03:30:47 -0500 Subject: [PATCH 3/7] impl chip --- Cargo.lock | 3 +- Cargo.toml | 1 + benchmarks/execute/src/execute-verifier.rs | 2 +- benchmarks/prove/src/util.rs | 2 +- crates/sdk/src/config/global.rs | 2 + crates/vm/src/arch/integration_api.rs | 120 +++- extensions/ecc/circuit/Cargo.toml | 7 +- extensions/ecc/circuit/src/extension/mod.rs | 5 +- .../ecc/circuit/src/extension/weierstrass.rs | 102 +++- .../circuit/src/weierstrass_chip/curves.rs | 124 +++- .../ecc/circuit/src/weierstrass_chip/mod.rs | 2 + .../src/weierstrass_chip/mul/execution.rs | 466 +++++++++++++++ .../circuit/src/weierstrass_chip/mul/mod.rs | 170 ++++++ extensions/pairing/circuit/src/config.rs | 5 +- .../pairing/circuit/src/pairing_extension.rs | 7 +- extensions/pairing/guest/src/bls12_381/mod.rs | 1 + extensions/pairing/guest/src/bn254/mod.rs | 1 + extensions/rv32-adapters/src/ec_mul.rs | 563 ++++++++++++++++++ extensions/rv32-adapters/src/lib.rs | 2 + guest-libs/k256/tests/lib.rs | 5 +- guest-libs/p256/tests/lib.rs | 5 +- 21 files changed, 1562 insertions(+), 33 deletions(-) create mode 100644 extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs create mode 100644 extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs create mode 100644 extensions/rv32-adapters/src/ec_mul.rs diff --git a/Cargo.lock b/Cargo.lock index 389eb39826..5108c52eed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6494,6 +6494,7 @@ dependencies = [ "openvm-rv32-adapters", "openvm-stark-backend", "openvm-stark-sdk", + "pasta_curves 0.5.1", "rand 0.8.5", "serde", "serde_with", @@ -6502,7 +6503,7 @@ dependencies = [ [[package]] name = "openvm-ecc-curve-macros" -version = "1.4.1" +version = "1.4.2-rc.1" dependencies = [ "openvm-macros-common", "quote", diff --git a/Cargo.toml b/Cargo.toml index 5ce606b898..2ef587c331 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -258,6 +258,7 @@ num-traits = { version = "0.2.19", default-features = false } ff = { version = "0.13.1", default-features = false } sha2 = { version = "0.10", default-features = false } blstrs = { version = "0.7.1", default-features = true } +pasta_curves = { version = "0.5.1", default-features = true } # specific to CUDA and GPU cuda-runtime-sys = "0.3.0-alpha.1" diff --git a/benchmarks/execute/src/execute-verifier.rs b/benchmarks/execute/src/execute-verifier.rs index ce0472ed49..089d7a0771 100644 --- a/benchmarks/execute/src/execute-verifier.rs +++ b/benchmarks/execute/src/execute-verifier.rs @@ -17,7 +17,7 @@ use std::fs; -use clap::{arg, Parser, ValueEnum}; +use clap::{Parser, ValueEnum}; use eyre::Result; use openvm_benchmarks_utils::get_fixtures_dir; use openvm_circuit::arch::{ diff --git a/benchmarks/prove/src/util.rs b/benchmarks/prove/src/util.rs index ac231ba7c9..7d2b900959 100644 --- a/benchmarks/prove/src/util.rs +++ b/benchmarks/prove/src/util.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use clap::{command, Parser}; +use clap::Parser; use eyre::Result; use openvm_benchmarks_utils::{build_elf, get_programs_dir}; use openvm_circuit::{ diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 9699b1ed34..95afd0b4da 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -520,6 +520,8 @@ impl InitFileGenerator for SdkVmConfigInner { if let Some(ecc_config) = &self.ecc { contents.push_str(&ecc_config.generate_sw_init()); contents.push('\n'); + contents.push_str(&ecc_config.generate_curve_init()); + contents.push('\n'); } Some(contents) diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index 1105cb40a8..9521ebb37a 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -324,6 +324,31 @@ impl< type ProcessedInstruction = MinimalInstruction; } +pub struct Rv32EcMulAdapterInterface< + T, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>(PhantomData); + +impl< + T, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > VmAdapterInterface + for Rv32EcMulAdapterInterface +{ + type Reads = ( + [[T; SCALAR_SIZE]; BLOCKS_PER_SCALAR], + [[T; POINT_SIZE]; BLOCKS_PER_POINT], + ); + type Writes = [[T; POINT_SIZE]; BLOCKS_PER_POINT]; + type ProcessedInstruction = MinimalInstruction; +} + /// Similar to `BasicAdapterInterface`, but it flattens the reads and writes into a single flat /// array for each pub struct FlatInterface( @@ -628,6 +653,76 @@ mod conversions { } } + // AdapterAirContext: Rv32EcMulAdapterInterface -> DynAdapterInterface + impl< + T, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > + From< + AdapterAirContext< + T, + Rv32EcMulAdapterInterface< + T, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >, + >, + > for AdapterAirContext> + { + fn from( + ctx: AdapterAirContext< + T, + Rv32EcMulAdapterInterface< + T, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >, + >, + ) -> Self { + AdapterAirContext { + to_pc: ctx.to_pc, + reads: ctx.reads.into(), + writes: ctx.writes.into(), + instruction: ctx.instruction.into(), + } + } + } + + impl< + T, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > From>> + for AdapterAirContext< + T, + Rv32EcMulAdapterInterface< + T, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >, + > + { + fn from(ctx: AdapterAirContext>) -> Self { + AdapterAirContext { + to_pc: ctx.to_pc, + reads: ctx.reads.into(), + writes: ctx.writes.into(), + instruction: ctx.instruction.into(), + } + } + } + impl From> for DynArray { fn from(v: Vec) -> Self { Self(v) @@ -677,26 +772,29 @@ mod conversions { } } - impl From<([[T; N]; M1], [[T; N]; M2])> - for DynArray + // From implementation for tuples with potentially different inner array sizes + impl + From<([[T; N1]; M1], [[T; N2]; M2])> for DynArray { - fn from(v: ([[T; N]; M1], [[T; N]; M2])) -> Self { - let vec = - v.0.into_iter() - .flatten() - .chain(v.1.into_iter().flatten()) - .collect(); + fn from(v: ([[T; N1]; M1], [[T; N2]; M2])) -> Self { + let mut vec = Vec::new(); + for block in v.0 { + vec.extend(block); + } + for block in v.1 { + vec.extend(block); + } Self(vec) } } - impl From> - for ([[T; N]; M1], [[T; N]; M2]) + impl From> + for ([[T; N1]; M1], [[T; N2]; M2]) { fn from(v: DynArray) -> Self { assert_eq!( v.0.len(), - N * (M1 + M2), + N1 * M1 + N2 * M2, "Incorrect vector length {}", v.0.len() ); diff --git a/extensions/ecc/circuit/Cargo.toml b/extensions/ecc/circuit/Cargo.toml index ec60b81901..e0d60d4338 100644 --- a/extensions/ecc/circuit/Cargo.toml +++ b/extensions/ecc/circuit/Cargo.toml @@ -36,6 +36,7 @@ cfg-if = { workspace = true } halo2curves-axiom = { workspace = true } blstrs = { workspace = true } +pasta_curves = { workspace = true } [dev-dependencies] openvm-pairing-guest = { workspace = true, features = ["halo2curves"] } @@ -48,7 +49,11 @@ lazy_static = { workspace = true } [features] default = [] tco = ["openvm-algebra-circuit/tco"] -aot = ["openvm-circuit/aot", "openvm-algebra-circuit/aot", "halo2curves-axiom/asm"] +aot = [ + "openvm-circuit/aot", + "openvm-algebra-circuit/aot", + "halo2curves-axiom/asm", +] cuda = [ "dep:openvm-cuda-common", "dep:openvm-cuda-backend", diff --git a/extensions/ecc/circuit/src/extension/mod.rs b/extensions/ecc/circuit/src/extension/mod.rs index 03fa1a4c76..b7534eb4f6 100644 --- a/extensions/ecc/circuit/src/extension/mod.rs +++ b/extensions/ecc/circuit/src/extension/mod.rs @@ -63,9 +63,10 @@ impl Rv32WeierstrassConfig { impl InitFileGenerator for Rv32WeierstrassConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( - "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", + "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", self.modular.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.weierstrass.generate_sw_init(), + self.weierstrass.generate_curve_init() )) } } diff --git a/extensions/ecc/circuit/src/extension/weierstrass.rs b/extensions/ecc/circuit/src/extension/weierstrass.rs index 435b5184ef..9b8dcd71ec 100644 --- a/extensions/ecc/circuit/src/extension/weierstrass.rs +++ b/extensions/ecc/circuit/src/extension/weierstrass.rs @@ -36,7 +36,8 @@ use strum::EnumCount; use crate::{ get_ec_addne_air, get_ec_addne_chip, get_ec_addne_step, get_ec_double_air, get_ec_double_chip, - get_ec_double_step, EcAddNeExecutor, EcDoubleExecutor, EccCpuProverExt, WeierstrassAir, + get_ec_double_step, get_ec_mul_air, get_ec_mul_chip, get_ec_mul_step, EcAddNeExecutor, + EcDoubleExecutor, EcMulExecutor, EccCpuProverExt, WeierstrassAir, WeierstrassEcMulAir, }; #[serde_as] @@ -44,6 +45,8 @@ use crate::{ pub struct CurveConfig { /// The name of the curve struct as defined by moduli_declare. pub struct_name: String, + /// The name of the curve as defined by curve_declare. + pub curve_name: String, /// The coordinate modulus of the curve. #[serde_as(as = "DisplayFromStr")] pub modulus: BigUint, @@ -60,6 +63,7 @@ pub struct CurveConfig { pub static SECP256K1_CONFIG: Lazy = Lazy::new(|| CurveConfig { struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(), + curve_name: SECP256K1_CURVE_NAME.to_string(), modulus: SECP256K1_MODULUS.clone(), scalar: SECP256K1_ORDER.clone(), a: BigUint::zero(), @@ -68,6 +72,7 @@ pub static SECP256K1_CONFIG: Lazy = Lazy::new(|| CurveConfig { pub static P256_CONFIG: Lazy = Lazy::new(|| CurveConfig { struct_name: P256_ECC_STRUCT_NAME.to_string(), + curve_name: P256_CURVE_NAME.to_string(), modulus: P256_MODULUS.clone(), scalar: P256_ORDER.clone(), a: BigUint::from_bytes_le(&P256_A), @@ -90,6 +95,15 @@ impl WeierstrassExtension { format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}") } + pub fn generate_curve_init(&self) -> String { + let supported_curves = self + .supported_curves + .iter() + .map(|curve_config| format!("\"{}\"", curve_config.curve_name)) + .collect::>() + .join(", "); + format!("openvm_ecc_guest::curve_macros::curve_init! {{ {supported_curves} }}") + } } #[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] @@ -104,9 +118,11 @@ pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeExecutor<2, 32>), EcDoubleRv32_32(EcDoubleExecutor<2, 32>), + EcMulRv32_32(EcMulExecutor<1, 2, 32, 32>), // 48 limbs prime EcAddNeRv32_48(EcAddNeExecutor<6, 16>), EcDoubleRv32_48(EcDoubleExecutor<6, 16>), + EcMulRv32_48(EcMulExecutor<1, 6, 32, 16>), } impl VmExecutionExtension for WeierstrassExtension { @@ -145,7 +161,7 @@ impl VmExecutionExtension for WeierstrassExtension { )?; let double = get_ec_double_step( - config, + config.clone(), dummy_range_checker_bus, pointer_max_bits, start_offset, @@ -158,6 +174,20 @@ impl VmExecutionExtension for WeierstrassExtension { ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + + let mul = get_ec_mul_step( + config, + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + ); + + inventory.add_executor( + WeierstrassExtensionExecutor::EcMulRv32_32(mul), + ((Rv32WeierstrassOpcode::EC_MUL as usize) + ..=(Rv32WeierstrassOpcode::SETUP_EC_MUL as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; } else if bytes <= 48 { let config = ExprBuilderConfig { modulus: curve.modulus.clone(), @@ -179,7 +209,7 @@ impl VmExecutionExtension for WeierstrassExtension { )?; let double = get_ec_double_step( - config, + config.clone(), dummy_range_checker_bus, pointer_max_bits, start_offset, @@ -192,6 +222,20 @@ impl VmExecutionExtension for WeierstrassExtension { ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize)) .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + + let mul = get_ec_mul_step( + config, + dummy_range_checker_bus, + pointer_max_bits, + start_offset, + ); + + inventory.add_executor( + WeierstrassExtensionExecutor::EcMulRv32_48(mul), + ((Rv32WeierstrassOpcode::EC_MUL as usize) + ..=(Rv32WeierstrassOpcode::SETUP_EC_MUL as usize)) + .map(|x| VmOpcode::from_usize(x + start_offset)), + )?; } else { panic!("Modulus too large"); } @@ -251,7 +295,7 @@ impl VmCircuitExtension for WeierstrassExtension { let double = get_ec_double_air::<2, 32>( exec_bridge, memory_bridge, - config, + config.clone(), range_checker_bus, bitwise_lu, pointer_max_bits, @@ -259,6 +303,17 @@ impl VmCircuitExtension for WeierstrassExtension { curve.a.clone(), ); inventory.add_air(double); + + let mul = get_ec_mul_air::<1, 2, 32, 32>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(mul); } else if bytes <= 48 { let config = ExprBuilderConfig { modulus: curve.modulus.clone(), @@ -280,7 +335,7 @@ impl VmCircuitExtension for WeierstrassExtension { let double = get_ec_double_air::<6, 16>( exec_bridge, memory_bridge, - config, + config.clone(), range_checker_bus, bitwise_lu, pointer_max_bits, @@ -288,6 +343,17 @@ impl VmCircuitExtension for WeierstrassExtension { curve.a.clone(), ); inventory.add_air(double); + + let mul = get_ec_mul_air::<1, 6, 32, 16>( + exec_bridge, + memory_bridge, + config, + range_checker_bus, + bitwise_lu, + pointer_max_bits, + start_offset, + ); + inventory.add_air(mul); } else { panic!("Modulus too large"); } @@ -350,7 +416,7 @@ where inventory.next_air::>()?; let double = get_ec_double_chip::, 2, 32>( - config, + config.clone(), mem_helper.clone(), range_checker.clone(), bitwise_lu.clone(), @@ -358,6 +424,16 @@ where curve.a.clone(), ); inventory.add_executor_chip(double); + + inventory.next_air::>()?; + let mul = get_ec_mul_chip::, 1, 2, 32, 32>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(mul); } else if bytes <= 48 { let config = ExprBuilderConfig { modulus: curve.modulus.clone(), @@ -377,7 +453,7 @@ where inventory.next_air::>()?; let double = get_ec_double_chip::, 6, 16>( - config, + config.clone(), mem_helper.clone(), range_checker.clone(), bitwise_lu.clone(), @@ -385,6 +461,16 @@ where curve.a.clone(), ); inventory.add_executor_chip(double); + + inventory.next_air::>()?; + let mul = get_ec_mul_chip::, 1, 6, 32, 16>( + config, + mem_helper.clone(), + range_checker.clone(), + bitwise_lu.clone(), + pointer_max_bits, + ); + inventory.add_executor_chip(mul); } else { panic!("Modulus too large"); } @@ -421,3 +507,5 @@ const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point"; pub const P256_ECC_STRUCT_NAME: &str = "P256Point"; +pub const SECP256K1_CURVE_NAME: &str = "Secp256k1"; +pub const P256_CURVE_NAME: &str = "NistP256"; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs index b00eee8ce9..9a9c220a04 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs @@ -1,10 +1,16 @@ -use halo2curves_axiom::ff::{Field, PrimeField}; +use core::ops::Mul; + +use halo2curves_axiom::{ + ff::{Field, PrimeField}, + CurveAffineExt, +}; use num_bigint::BigUint; use num_traits::Num; use openvm_algebra_circuit::fields::{ blocks_to_field_element, blocks_to_field_element_bls12_381_coordinate, field_element_to_blocks, field_element_to_blocks_bls12_381_coordinate, FieldType, }; +use pasta_curves::arithmetic::CurveAffine; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CurveType { @@ -94,6 +100,58 @@ pub fn ec_double( + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], + point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], +) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { + match CURVE_TYPE { + x if x == CurveType::K256 as u8 => ec_mul_256bit::< + halo2curves_axiom::secq256k1::Fq, + halo2curves_axiom::secq256k1::Fp, + halo2curves_axiom::secq256k1::Secq256k1, + halo2curves_axiom::secq256k1::Secq256k1Affine, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >(scalar_data, point_data), + x if x == CurveType::P256 as u8 => ec_mul_256bit::< + halo2curves_axiom::secp256r1::Fq, + halo2curves_axiom::secp256r1::Fp, + halo2curves_axiom::secp256r1::Secp256r1, + halo2curves_axiom::secp256r1::Secp256r1Affine, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >(scalar_data, point_data), + x if x == CurveType::BN254 as u8 => ec_mul_256bit::< + halo2curves_axiom::bn256::Fr, + halo2curves_axiom::bn256::Fq, + halo2curves_axiom::bn256::G1, + halo2curves_axiom::bn256::G1Affine, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >(scalar_data, point_data), + x if x == CurveType::BLS12_381 as u8 => { + ec_mul_bls12_381::( + scalar_data, + point_data, + ) + } + _ => panic!("Unsupported curve type: {}", CURVE_TYPE), + } +} + #[inline(always)] fn ec_add_ne_256bit< F: PrimeField, @@ -212,3 +270,67 @@ pub fn ec_double_impl, const NEG_A: u64>(x1: F, y1: F) -> ( (x3, y3) } + +#[inline(always)] +fn ec_mul_256bit< + Fr: PrimeField, + Fq: PrimeField, + CJ: for<'a> Mul<&'a Fr, Output = CJ> + From, + CA: CurveAffine + CurveAffineExt + From, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], + point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], +) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { + // read scalar and point data + let scalar = blocks_to_field_element::(scalar_data.as_flattened()); + let x1 = blocks_to_field_element::(point_data[..BLOCKS_PER_POINT / 2].as_flattened()); + let y1 = blocks_to_field_element::(point_data[BLOCKS_PER_POINT / 2..].as_flattened()); + + let point_jacobian: CJ = CA::from_xy(x1, y1).unwrap().into(); + + let output_affine: CA = (point_jacobian * &scalar).into(); + let (x3, y3) = output_affine.into_coordinates(); + + // write output data to memory + let mut output = [[0u8; POINT_SIZE]; BLOCKS_PER_POINT]; + field_element_to_blocks::(&x3, &mut output[..BLOCKS_PER_POINT / 2]); + field_element_to_blocks::(&y3, &mut output[BLOCKS_PER_POINT / 2..]); + output +} + +#[inline(always)] +fn ec_mul_bls12_381< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], + point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], +) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { + // read scalar and point data + let scalar = blocks_to_field_element::(scalar_data.as_flattened()); + let x1 = blocks_to_field_element_bls12_381_coordinate( + point_data[..BLOCKS_PER_POINT / 2].as_flattened(), + ); + let y1 = blocks_to_field_element_bls12_381_coordinate( + point_data[BLOCKS_PER_POINT / 2..].as_flattened(), + ); + + let point_jacobian: blstrs::G1Projective = + blstrs::G1Affine::from_raw_unchecked(x1, y1, false).into(); + + let output_affine: blstrs::G1Affine = (point_jacobian * scalar).into(); + let x3 = output_affine.x(); + let y3 = output_affine.y(); + + // write output data to memory + let mut output = [[0u8; POINT_SIZE]; BLOCKS_PER_POINT]; + field_element_to_blocks_bls12_381_coordinate(&x3, &mut output[..BLOCKS_PER_POINT / 2]); + field_element_to_blocks_bls12_381_coordinate(&y3, &mut output[BLOCKS_PER_POINT / 2..]); + output +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index cc8e97841e..633f017950 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -1,9 +1,11 @@ mod add_ne; mod curves; mod double; +mod mul; pub use add_ne::*; pub use double::*; +pub use mul::*; #[cfg(test)] mod tests; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs new file mode 100644 index 0000000000..3d8378ba9b --- /dev/null +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs @@ -0,0 +1,466 @@ +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, +}; + +use openvm_circuit::{ + arch::*, + system::memory::{online::GuestMemory, POINTER_MAX_BITS}, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_mod_circuit_builder::FieldExpr; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::EcMulExecutor; +use crate::weierstrass_chip::curves::{ec_mul, get_curve_type, CurveType}; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct EcMulPreCompute<'a> { + expr: &'a FieldExpr, + rs_addrs: [u8; 2], + a: u8, + flag_idx: u8, +} + +impl< + 'a, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > EcMulExecutor +{ + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut EcMulPreCompute<'a>, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + + // TODO: Pre-compute flag_idx + // let needs_setup = self.expr.needs_setup(); + let flag_idx = self.expr.num_flags() as u8; + // if needs_setup { + // // Find which opcode this is in our local_opcode_idx list + // if let Some(opcode_position) = self + // .local_opcode_idx + // .iter() + // .position(|&idx| idx == local_opcode) + // { + // // If this is NOT the last opcode (setup), get the corresponding flag_idx + // if opcode_position < self.opcode_flag_idx.len() { + // flag_idx = self.opcode_flag_idx[opcode_position] as u8; + // } + // } + // } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = EcMulPreCompute { + expr: &self.expr, + rs_addrs, + a: a as u8, + flag_idx, + }; + + let local_opcode = opcode.local_opcode_idx(self.offset); + let is_setup = local_opcode == Rv32WeierstrassOpcode::SETUP_EC_MUL as usize; + + Ok(is_setup) + } +} + +macro_rules! dispatch { + ($execute_impl:ident, $pre_compute:ident, $is_setup:ident) => { + if let Some(curve_type) = { + let modulus = &$pre_compute.expr.builder.prime; + let a_coeff = &$pre_compute.expr.setup_values[0]; + get_curve_type(modulus, a_coeff) + } { + match ($is_setup, curve_type) { + (true, CurveType::K256) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::K256 as u8 }, + true, + >), + (true, CurveType::P256) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::P256 as u8 }, + true, + >), + (true, CurveType::BN254) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::BN254 as u8 }, + true, + >), + (true, CurveType::BLS12_381) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::BLS12_381 as u8 }, + true, + >), + (false, CurveType::K256) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::K256 as u8 }, + false, + >), + (false, CurveType::P256) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::P256 as u8 }, + false, + >), + (false, CurveType::BN254) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::BN254 as u8 }, + false, + >), + (false, CurveType::BLS12_381) => Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { CurveType::BLS12_381 as u8 }, + false, + >), + } + } else if $is_setup { + Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { u8::MAX }, + true, + >) + } else { + Ok($execute_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + { u8::MAX }, + false, + >) + } + }; +} + +impl< + F: PrimeField32, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > InterpreterExecutor + for EcMulExecutor +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::() + } + + #[cfg(not(feature = "tco"))] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcMulPreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_handler, pre_compute, is_setup) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let pre_compute: &mut EcMulPreCompute = data.borrow_mut(); + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + + dispatch!(execute_e1_handler, pre_compute, is_setup) + } +} + +impl< + F: PrimeField32, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > InterpreterMeteredExecutor + for EcMulExecutor +{ + #[inline(always)] + fn metered_pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + #[cfg(not(feature = "tco"))] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + dispatch!(execute_e2_handler, pre_compute_pure, is_setup) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let pre_compute_pure = &mut pre_compute.data; + let is_setup = self.pre_compute_impl(pc, inst, pre_compute_pure)?; + dispatch!(execute_e2_handler, pre_compute_pure, is_setup) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + const CURVE_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: &EcMulPreCompute, + exec_state: &mut VmExecState, +) -> Result<(), ExecutionError> { + let pc = exec_state.pc(); + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values for the scalar and point + let scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR] = + from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[0] + (i * SCALAR_SIZE) as u32)); + let point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT] = + from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[1] + (i * POINT_SIZE) as u32)); + + // TODO: Check this later + // if IS_SETUP { + // let input_scalar_prime = BigUint::from_bytes_le(scalar_data.as_flattened()); + // if input_scalar_prime != pre_compute.expr.builder.prime { + // let err = ExecutionError::Fail { + // pc, + // msg: "EcMul: mismatched scalar prime", + // }; + // return Err(err); + // } + + // let input_point_prime = + // BigUint::from_bytes_le(point_data[..BLOCKS_PER_POINT / 2].as_flattened()); + // if input_point_prime != pre_compute.expr.builder.prime { + // let err = ExecutionError::Fail { + // pc, + // msg: "EcMul: mismatched point prime", + // }; + // return Err(err); + // } + + // let input_a = BigUint::from_bytes_le(point_data[BLOCKS_PER_POINT / 2..].as_flattened()); + // let coeff_a = &pre_compute.expr.setup_values[0]; + // if input_a != *coeff_a { + // let err = ExecutionError::Fail { + // pc, + // msg: "EcMul: mismatched coeff_a", + // }; + // return Err(err); + // } + // } + + // TODO: Check this later + // let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP { + // let scalar_data: DynArray = scalar_data.into(); + // let point_data: DynArray = point_data.into(); + // run_field_expression_precomputed::( + // pre_compute.expr, + // pre_compute.flag_idx as usize, + // &scalar_data.concat(&point_data).0, + // ) + // .into() + // } else { + // ec_mul::( + // scalar_data, + // point_data, + // ) + // }; + let output_data = + ec_mul::( + scalar_data, + point_data, + ); + + let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + POINT_SIZE * BLOCKS_PER_POINT - 1 < (1 << POINTER_MAX_BITS)); + + // Write output data to memory + for (i, block) in output_data.into_iter().enumerate() { + exec_state.vm_write(RV32_MEMORY_AS, rd_val + (i * POINT_SIZE) as u32, &block); + } + + exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP)); + + Ok(()) +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: ExecutionCtxTrait, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + const CURVE_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: *const u8, + exec_state: &mut VmExecState, +) -> Result<(), ExecutionError> { + let pre_compute: &EcMulPreCompute = + std::slice::from_raw_parts(pre_compute, size_of::()).borrow(); + execute_e12_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + CURVE_TYPE, + IS_SETUP, + >(pre_compute, exec_state) +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: MeteredExecutionCtxTrait, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + const CURVE_TYPE: u8, + const IS_SETUP: bool, +>( + pre_compute: *const u8, + exec_state: &mut VmExecState, +) -> Result<(), ExecutionError> { + let e2_pre_compute: &E2PreCompute = + std::slice::from_raw_parts(pre_compute, size_of::>()) + .borrow(); + exec_state + .ctx + .on_height_change(e2_pre_compute.chip_idx as usize, 1); + execute_e12_impl::< + _, + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + CURVE_TYPE, + IS_SETUP, + >(&e2_pre_compute.data, exec_state) +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs new file mode 100644 index 0000000000..21b2126889 --- /dev/null +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs @@ -0,0 +1,170 @@ +use std::{cell::RefCell, rc::Rc}; + +use derive_more::derive::{Deref, DerefMut}; +use openvm_circuit::{ + arch::*, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::PreflightExecutor; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldExpressionExecutor, + FieldExpressionFiller, +}; +use openvm_rv32_adapters::{Rv32EcMulAdapterAir, Rv32EcMulAdapterExecutor, Rv32EcMulAdapterFiller}; + +mod execution; + +/// dummy implementation for now +pub fn ec_mul_expr( + config: ExprBuilderConfig, // The coordinate field. + range_bus: VariableRangeCheckerBus, +) -> FieldExpr { + config.check_valid(); + let builder = ExprBuilder::new(config, range_bus.range_max_bits); + let builder = Rc::new(RefCell::new(builder)); + + // Create inputs + let _scalar = ExprBuilder::new_input(builder.clone()); + let x1 = ExprBuilder::new_input(builder.clone()); + let y1 = ExprBuilder::new_input(builder.clone()); + + // Create dummy outputs: result x and result y + // Note: The actual computation is done natively, but we need these for the AIR structure + let mut x_out = x1.clone(); // Dummy - actual computation happens in native code + x_out.save_output(); + let mut y_out = y1.clone(); // Dummy - actual computation happens in native code + y_out.save_output(); + + let builder = (*builder).borrow().clone(); + FieldExpr::new(builder, range_bus, false) +} + +#[derive(Clone, PreflightExecutor, Deref, DerefMut)] +pub struct EcMulExecutor< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + FieldExpressionExecutor< + Rv32EcMulAdapterExecutor, + >, +); + +pub type WeierstrassEcMulAir< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> = VmAirWrapper< + Rv32EcMulAdapterAir, + FieldExpressionCoreAir, +>; + +pub type WeierstrassEcMulChip< + F, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> = VmChipWrapper< + F, + FieldExpressionFiller< + Rv32EcMulAdapterFiller, + >, +>; + +fn gen_base_expr( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, +) -> (FieldExpr, Vec) { + let expr = ec_mul_expr(config, range_checker_bus); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::EC_MUL as usize, + Rv32WeierstrassOpcode::SETUP_EC_MUL as usize, + ]; + + (expr, local_opcode_idx) +} + +pub fn get_ec_mul_air< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + exec_bridge: ExecutionBridge, + mem_bridge: MemoryBridge, + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + bitwise_lookup_bus: BitwiseOperationLookupBus, + pointer_max_bits: usize, + offset: usize, +) -> WeierstrassEcMulAir { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + WeierstrassEcMulAir::new( + Rv32EcMulAdapterAir::new( + exec_bridge, + mem_bridge, + bitwise_lookup_bus, + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ) +} + +pub fn get_ec_mul_step< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + config: ExprBuilderConfig, + range_checker_bus: VariableRangeCheckerBus, + pointer_max_bits: usize, + offset: usize, +) -> EcMulExecutor { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + EcMulExecutor(FieldExpressionExecutor::new( + Rv32EcMulAdapterExecutor::new(pointer_max_bits), + expr, + offset, + local_opcode_idx, + vec![], + "EcMul", + )) +} + +pub fn get_ec_mul_chip< + F, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +>( + config: ExprBuilderConfig, + mem_helper: SharedMemoryHelper, + range_checker: SharedVariableRangeCheckerChip, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, +) -> WeierstrassEcMulChip { + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus()); + WeierstrassEcMulChip::new( + FieldExpressionFiller::new( + Rv32EcMulAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), + expr, + local_opcode_idx, + vec![], + range_checker, + false, + ), + mem_helper, + ) +} diff --git a/extensions/pairing/circuit/src/config.rs b/extensions/pairing/circuit/src/config.rs index 20ea07186a..87f3f0998e 100644 --- a/extensions/pairing/circuit/src/config.rs +++ b/extensions/pairing/circuit/src/config.rs @@ -62,10 +62,11 @@ impl Rv32PairingConfig { impl InitFileGenerator for Rv32PairingConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( - "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", + "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n{}\n", self.modular.modular.generate_moduli_init(), self.fp2.generate_complex_init(&self.modular.modular), - self.weierstrass.generate_sw_init() + self.weierstrass.generate_sw_init(), + self.weierstrass.generate_curve_init() )) } } diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index 5f5bcc74bf..a7a8a2ac33 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -14,9 +14,10 @@ use openvm_ecc_circuit::CurveConfig; use openvm_instructions::PhantomDiscriminant; use openvm_pairing_guest::{ bls12_381::{ - BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, BLS12_381_ORDER, BLS12_381_XI_ISIZE, + BLS12_381_CURVE_NAME, BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, BLS12_381_ORDER, + BLS12_381_XI_ISIZE, }, - bn254::{BN254_ECC_STRUCT_NAME, BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE}, + bn254::{BN254_CURVE_NAME, BN254_ECC_STRUCT_NAME, BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE}, }; use openvm_pairing_transpiler::PairingPhantom; use openvm_stark_backend::{config::StarkGenericConfig, engine::StarkEngine, p3_field::Field}; @@ -36,6 +37,7 @@ impl PairingCurve { match self { PairingCurve::Bn254 => CurveConfig::new( BN254_ECC_STRUCT_NAME.to_string(), + BN254_CURVE_NAME.to_string(), BN254_MODULUS.clone(), BN254_ORDER.clone(), BigUint::zero(), @@ -43,6 +45,7 @@ impl PairingCurve { ), PairingCurve::Bls12_381 => CurveConfig::new( BLS12_381_ECC_STRUCT_NAME.to_string(), + BLS12_381_CURVE_NAME.to_string(), BLS12_381_MODULUS.clone(), BLS12_381_ORDER.clone(), BigUint::zero(), diff --git a/extensions/pairing/guest/src/bls12_381/mod.rs b/extensions/pairing/guest/src/bls12_381/mod.rs index 08808e10da..445da4fb22 100644 --- a/extensions/pairing/guest/src/bls12_381/mod.rs +++ b/extensions/pairing/guest/src/bls12_381/mod.rs @@ -35,6 +35,7 @@ pub const BLS12_381_PSEUDO_BINARY_ENCODING: [i8; 64] = [ #[cfg(not(target_os = "zkvm"))] // Used in WeierstrassExtension config pub const BLS12_381_ECC_STRUCT_NAME: &str = "Bls12_381G1Affine"; +pub const BLS12_381_CURVE_NAME: &str = "Bls12_381"; #[cfg(not(target_os = "zkvm"))] // Used in Fp2Extension config diff --git a/extensions/pairing/guest/src/bn254/mod.rs b/extensions/pairing/guest/src/bn254/mod.rs index 1e30a51945..2e75fbe933 100644 --- a/extensions/pairing/guest/src/bn254/mod.rs +++ b/extensions/pairing/guest/src/bn254/mod.rs @@ -37,6 +37,7 @@ pub const BN254_PSEUDO_BINARY_ENCODING: [i8; 66] = [ #[cfg(not(target_os = "zkvm"))] // Used in WeierstrassExtension config pub const BN254_ECC_STRUCT_NAME: &str = "Bn254G1Affine"; +pub const BN254_CURVE_NAME: &str = "Bn254"; #[cfg(not(target_os = "zkvm"))] // Used in Fp2Extension config diff --git a/extensions/rv32-adapters/src/ec_mul.rs b/extensions/rv32-adapters/src/ec_mul.rs new file mode 100644 index 0000000000..201e0cb8d3 --- /dev/null +++ b/extensions/rv32-adapters/src/ec_mul.rs @@ -0,0 +1,563 @@ +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, + iter::{once, zip}, +}; + +use itertools::izip; +use openvm_circuit::{ + arch::{ + get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller, + ExecutionBridge, ExecutionState, Rv32EcMulAdapterInterface, VmAdapterAir, + }, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, +}; +use openvm_rv32im_circuit::adapters::{ + abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::BaseAir, + p3_field::{Field, FieldAlgebra, PrimeField32}, +}; + +/// This adapter reads from 2 pointers and writes to 1 pointer. +/// * The data is read from the heap (address space 2), and the pointers are read from registers +/// (address space 1). +#[repr(C)] +#[derive(AlignedBorrow, Debug)] +pub struct Rv32EcMulAdapterCols< + T, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> { + pub from_state: ExecutionState, + + pub rs_ptr: [T; 2], + pub rd_ptr: T, + + pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; 2], + pub rd_val: [T; RV32_REGISTER_NUM_LIMBS], + + pub rs_read_aux: [MemoryReadAuxCols; 2], + pub rd_read_aux: MemoryReadAuxCols, + + pub reads_scalar_aux: [MemoryReadAuxCols; BLOCKS_PER_SCALAR], + pub reads_point_aux: [MemoryReadAuxCols; BLOCKS_PER_POINT], + pub writes_aux: [MemoryWriteAuxCols; BLOCKS_PER_POINT], +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Debug, derive_new::new)] +pub struct Rv32EcMulAdapterAir< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> { + pub(super) execution_bridge: ExecutionBridge, + pub(super) memory_bridge: MemoryBridge, + pub bus: BitwiseOperationLookupBus, + /// The max number of bits for an address in memory + address_bits: usize, +} + +impl< + F: Field, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > BaseAir + for Rv32EcMulAdapterAir +{ + fn width(&self) -> usize { + Rv32EcMulAdapterCols::::width() + } +} + +impl< + AB: InteractionBuilder, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > VmAdapterAir + for Rv32EcMulAdapterAir +{ + type Interface = Rv32EcMulAdapterInterface< + AB::Expr, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >; + + fn eval( + &self, + builder: &mut AB, + local: &[AB::Var], + ctx: AdapterAirContext, + ) { + let cols: &Rv32EcMulAdapterCols< + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + > = local.borrow(); + let timestamp = cols.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::F::from_canonical_usize(timestamp_delta - 1) + }; + + // Read register values for rs, rd + for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once(( + cols.rd_ptr, + cols.rd_val, + &cols.rd_read_aux, + ))) { + self.memory_bridge + .read( + MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr), + val, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits - + // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow + // occurs when computing memory pointers. Since the number of cells accessed with each + // address will be small enough, and combined with the memory argument, it ensures + // that all the cells accessed in the memory are less than 2^addr_bits. + let need_range_check: Vec = cols + .rs_val + .iter() + .chain(std::iter::repeat_n(&cols.rd_val, 2)) + .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]) + .collect(); + + // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain + // the correct amount of bits + let limb_shift = AB::F::from_canonical_usize( + 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits), + ); + + // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS + // thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that + // limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))) + for pair in need_range_check.chunks_exact(2) { + self.bus + .send_range(pair[0] * limb_shift, pair[1] * limb_shift) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + // Compose the u32 register value into single field element, with `abstract_compose` + let rd_val_f: AB::Expr = abstract_compose(cols.rd_val); + let rs_val_f: [AB::Expr; 2] = cols.rs_val.map(abstract_compose); + + let e = AB::F::from_canonical_u32(RV32_MEMORY_AS); + // Reads from heap + // Scalar reads + let scalar_address = &rs_val_f[0]; + for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads_scalar_aux).enumerate() { + self.memory_bridge + .read( + MemoryAddress::new( + e, + scalar_address.clone() + AB::Expr::from_canonical_usize(i * SCALAR_SIZE), + ), + read, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + // Point reads + let point_address = &rs_val_f[1]; + for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads_point_aux).enumerate() { + self.memory_bridge + .read( + MemoryAddress::new( + e, + point_address.clone() + AB::Expr::from_canonical_usize(i * POINT_SIZE), + ), + read, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + // Writes to heap + for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() { + self.memory_bridge + .write( + MemoryAddress::new( + e, + rd_val_f.clone() + AB::Expr::from_canonical_usize(i * POINT_SIZE), + ), + write, + timestamp_pp(), + aux, + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + self.execution_bridge + .execute_and_increment_or_set_pc( + ctx.instruction.opcode, + [ + cols.rd_ptr.into(), + cols.rs_ptr + .first() + .map(|&x| x.into()) + .unwrap_or(AB::Expr::ZERO), + cols.rs_ptr + .get(1) + .map(|&x| x.into()) + .unwrap_or(AB::Expr::ZERO), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + e.into(), + ], + cols.from_state, + AB::F::from_canonical_usize(timestamp_delta), + (DEFAULT_PC_STEP, ctx.to_pc), + ) + .eval(builder, ctx.instruction.is_valid.clone()); + } + + fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { + let cols: &Rv32EcMulAdapterCols< + _, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + > = local.borrow(); + cols.from_state.pc + } +} + +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32EcMulAdapterRecord< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptrs: [u32; 2], + pub rd_ptr: u32, + + pub rs_vals: [u32; 2], + pub rd_val: u32, + + pub rs_read_aux: [MemoryReadAuxRecord; 2], + pub rd_read_aux: MemoryReadAuxRecord, + + pub reads_scalar_aux: [MemoryReadAuxRecord; BLOCKS_PER_SCALAR], + pub reads_point_aux: [MemoryReadAuxRecord; BLOCKS_PER_POINT], + pub writes_aux: [MemoryWriteBytesAuxRecord; BLOCKS_PER_POINT], +} + +#[derive(derive_new::new, Clone, Copy)] +pub struct Rv32EcMulAdapterExecutor< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> { + pointer_max_bits: usize, +} + +#[derive(derive_new::new)] +pub struct Rv32EcMulAdapterFiller< + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, +> { + pointer_max_bits: usize, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + +impl< + F: PrimeField32, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > AdapterTraceExecutor + for Rv32EcMulAdapterExecutor +{ + const WIDTH: usize = Rv32EcMulAdapterCols::< + F, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >::width(); + type ReadData = ( + [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], + [[u8; POINT_SIZE]; BLOCKS_PER_POINT], + ); + type WriteData = [[u8; POINT_SIZE]; BLOCKS_PER_POINT]; + type RecordMut<'a> = &'a mut Rv32EcMulAdapterRecord< + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut &mut Rv32EcMulAdapterRecord< + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >, + ) -> Self::ReadData { + let &Instruction { a, b, c, d, e, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Read register values + record.rs_vals = from_fn(|i| { + record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptrs[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) + }); + + record.rd_ptr = a.as_canonical_u32(); + record.rd_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.rd_read_aux.prev_timestamp, + )); + + // Read memory values + debug_assert!( + (record.rs_vals[0] + (SCALAR_SIZE * BLOCKS_PER_SCALAR - 1) as u32) + < (1 << self.pointer_max_bits) as u32 + ); + let reads_scalar = from_fn(|i| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[0] + (i * SCALAR_SIZE) as u32, + &mut record.reads_scalar_aux[i].prev_timestamp, + ) + }); + debug_assert!( + (record.rs_vals[1] + (POINT_SIZE * BLOCKS_PER_POINT - 1) as u32) + < (1 << self.pointer_max_bits) as u32 + ); + let reads_point = from_fn(|i| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[1] + (i * POINT_SIZE) as u32, + &mut record.reads_point_aux[i].prev_timestamp, + ) + }); + (reads_scalar, reads_point) + } + + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut &mut Rv32EcMulAdapterRecord< + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >, + ) { + debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS); + + debug_assert!( + record.rd_val as usize + POINT_SIZE * BLOCKS_PER_POINT - 1 + < (1 << self.pointer_max_bits) + ); + + #[allow(clippy::needless_range_loop)] + for i in 0..BLOCKS_PER_POINT { + tracing_write( + memory, + RV32_MEMORY_AS, + record.rd_val + (i * POINT_SIZE) as u32, + data[i], + &mut record.writes_aux[i].prev_timestamp, + &mut record.writes_aux[i].prev_data, + ); + } + } +} + +impl< + F: PrimeField32, + const BLOCKS_PER_SCALAR: usize, + const BLOCKS_PER_POINT: usize, + const SCALAR_SIZE: usize, + const POINT_SIZE: usize, + > AdapterTraceFiller + for Rv32EcMulAdapterFiller +{ + const WIDTH: usize = Rv32EcMulAdapterCols::< + F, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + >::width(); + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + // SAFETY: + // - caller ensures `adapter_row` contains a valid record representation that was previously + // written by the executor + let record: &Rv32EcMulAdapterRecord< + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + + let cols: &mut Rv32EcMulAdapterCols< + F, + BLOCKS_PER_SCALAR, + BLOCKS_PER_POINT, + SCALAR_SIZE, + POINT_SIZE, + > = adapter_row.borrow_mut(); + + // Range checks: + // **NOTE**: Must do the range checks before overwriting the records + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); + + let timestamp_delta = 3 + BLOCKS_PER_SCALAR + 2 * BLOCKS_PER_POINT; + let mut timestamp = record.from_timestamp + timestamp_delta as u32; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; + + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + record + .writes_aux + .iter() + .rev() + .zip(cols.writes_aux.iter_mut().rev()) + .for_each(|(write, cols_write)| { + cols_write.set_prev_data(write.prev_data.map(F::from_canonical_u8)); + mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut()); + }); + + record + .reads_point_aux + .iter() + .zip(cols.reads_point_aux.iter_mut()) + .rev() + .for_each(|(read, cols_read)| { + mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); + }); + + record + .reads_scalar_aux + .iter() + .zip(cols.reads_scalar_aux.iter_mut()) + .rev() + .for_each(|(read, cols_read)| { + mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); + }); + + mem_helper.fill( + record.rd_read_aux.prev_timestamp, + timestamp_mm(), + cols.rd_read_aux.as_mut(), + ); + + record + .rs_read_aux + .iter() + .zip(cols.rs_read_aux.iter_mut()) + .rev() + .for_each(|(aux, cols_aux)| { + mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut()); + }); + + cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8); + cols.rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + .for_each(|(cols_val, val)| { + *cols_val = val.to_le_bytes().map(F::from_canonical_u8); + }); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + cols.rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptrs.iter().rev()) + .for_each(|(cols_ptr, ptr)| { + *cols_ptr = F::from_canonical_u32(*ptr); + }); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); + } +} diff --git a/extensions/rv32-adapters/src/lib.rs b/extensions/rv32-adapters/src/lib.rs index 6d884daedf..780b44284e 100644 --- a/extensions/rv32-adapters/src/lib.rs +++ b/extensions/rv32-adapters/src/lib.rs @@ -1,8 +1,10 @@ +mod ec_mul; mod eq_mod; mod heap; mod heap_branch; mod vec_heap; +pub use ec_mul::*; pub use eq_mod::*; pub use heap::*; pub use heap_branch::*; diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 59eb42dde3..32490b2a67 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -143,9 +143,10 @@ mod guest_tests { impl InitFileGenerator for EcdsaConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( - "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", + "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", self.weierstrass.modular.modular.generate_moduli_init(), - self.weierstrass.weierstrass.generate_sw_init() + self.weierstrass.weierstrass.generate_sw_init(), + self.weierstrass.weierstrass.generate_curve_init() )) } } diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 9eaf2b2c74..70527dd651 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -143,9 +143,10 @@ mod guest_tests { impl InitFileGenerator for EcdsaConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( - "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", + "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", self.weierstrass.modular.modular.generate_moduli_init(), - self.weierstrass.weierstrass.generate_sw_init() + self.weierstrass.weierstrass.generate_sw_init(), + self.weierstrass.weierstrass.generate_curve_init() )) } } From 255dc1a2d6d65bd53b433c627dc083acabe3ebfc Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Sat, 6 Dec 2025 05:52:54 -0500 Subject: [PATCH 4/7] run ecrecover --- benchmarks/guest/ecrecover/openvm.toml | 1 + benchmarks/guest/ecrecover/openvm_init.rs | 1 + benchmarks/guest/kitchen-sink/openvm.toml | 7 ++++-- benchmarks/guest/pairing/openvm.toml | 1 + crates/sdk/guest/little/openvm.toml | 4 ++++ crates/sdk/guest/p256/openvm.toml | 4 ++++ crates/sdk/src/config/openvm_standard.toml | 4 ++++ .../ecc/circuit/src/extension/weierstrass.rs | 6 +++++ .../circuit/src/weierstrass_chip/curves.rs | 8 +++---- .../src/weierstrass_chip/mul/execution.rs | 22 +++++-------------- .../circuit/src/weierstrass_chip/mul/mod.rs | 16 +++++++++----- extensions/ecc/curve-macros/src/lib.rs | 12 ++++++---- .../ecc/tests/programs/openvm_k256.toml | 1 + .../tests/programs/openvm_k256_keccak.toml | 1 + .../ecc/tests/programs/openvm_p256.toml | 1 + extensions/ecc/transpiler/src/lib.rs | 2 +- 16 files changed, 58 insertions(+), 33 deletions(-) diff --git a/benchmarks/guest/ecrecover/openvm.toml b/benchmarks/guest/ecrecover/openvm.toml index c1261ee458..750950df8f 100644 --- a/benchmarks/guest/ecrecover/openvm.toml +++ b/benchmarks/guest/ecrecover/openvm.toml @@ -11,6 +11,7 @@ supported_moduli = [ [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" diff --git a/benchmarks/guest/ecrecover/openvm_init.rs b/benchmarks/guest/ecrecover/openvm_init.rs index dc6d4917dd..fb50966aa9 100644 --- a/benchmarks/guest/ecrecover/openvm_init.rs +++ b/benchmarks/guest/ecrecover/openvm_init.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" } +openvm_ecc_guest::curve_macros::curve_init! { "Secp256k1" } diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..bc3f31ba5b 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -19,7 +19,7 @@ supported_moduli = [ "21888242871839275222246405745257275088548364400416034343698204186575808495617", # scalar # bls12_381 "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", # coordinate - "52435875175126190479447740508185965837690552500527637822603658699938581184513", # scalar + "52435875175126190479447740508185965837690552500527637822603658699938581184513", # scalar # 2^61 - 1 "2305843009213693951", "7", @@ -41,6 +41,7 @@ supported_moduli = [ [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" @@ -48,6 +49,7 @@ b = "7" [[app_vm_config.ecc.supported_curves]] struct_name = "P256Point" +curve_name = "NistP256" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" @@ -55,6 +57,7 @@ b = "410583637251521421293261297800472684091144410159937255548352563140394674012 [[app_vm_config.ecc.supported_curves]] struct_name = "Bn254G1Affine" +curve_name = "Bn254" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" a = "0" @@ -62,6 +65,7 @@ b = "3" [[app_vm_config.ecc.supported_curves]] struct_name = "Bls12_381G1Affine" +curve_name = "Bls12_381" modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" scalar = "52435875175126190479447740508185965837690552500527637822603658699938581184513" a = "0" @@ -69,4 +73,3 @@ b = "4" [app_vm_config.pairing] supported_curves = ["Bn254", "Bls12_381"] - diff --git a/benchmarks/guest/pairing/openvm.toml b/benchmarks/guest/pairing/openvm.toml index 321383b8eb..aaee414495 100644 --- a/benchmarks/guest/pairing/openvm.toml +++ b/benchmarks/guest/pairing/openvm.toml @@ -23,6 +23,7 @@ supported_curves = ["Bn254"] # bn254 (alt bn128) [[app_vm_config.ecc.supported_curves]] struct_name = "Bn254G1Affine" +curve_name = "Bn254" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" a = "0" diff --git a/crates/sdk/guest/little/openvm.toml b/crates/sdk/guest/little/openvm.toml index f1f9267191..9d743b8d91 100644 --- a/crates/sdk/guest/little/openvm.toml +++ b/crates/sdk/guest/little/openvm.toml @@ -39,6 +39,7 @@ supported_moduli = [ # bn254 (alt bn128) [[app_vm_config.ecc.supported_curves]] struct_name = "Bn254G1Affine" +curve_name = "Bn254" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" a = "0" @@ -47,6 +48,7 @@ b = "3" # secp256k1 (k256) [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" @@ -55,6 +57,7 @@ b = "7" # secp256r1 (p256) [[app_vm_config.ecc.supported_curves]] struct_name = "P256Point" +curve_name = "NistP256" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" @@ -63,6 +66,7 @@ b = "410583637251521421293261297800472684091144410159937255548352563140394674012 # bls12_381 [[app_vm_config.ecc.supported_curves]] struct_name = "Bls12_381G1Affine" +curve_name = "Bls12_381" modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" scalar = "52435875175126190479447740508185965837690552500527637822603658699938581184513" a = "0" diff --git a/crates/sdk/guest/p256/openvm.toml b/crates/sdk/guest/p256/openvm.toml index f1f9267191..9d743b8d91 100644 --- a/crates/sdk/guest/p256/openvm.toml +++ b/crates/sdk/guest/p256/openvm.toml @@ -39,6 +39,7 @@ supported_moduli = [ # bn254 (alt bn128) [[app_vm_config.ecc.supported_curves]] struct_name = "Bn254G1Affine" +curve_name = "Bn254" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" a = "0" @@ -47,6 +48,7 @@ b = "3" # secp256k1 (k256) [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" @@ -55,6 +57,7 @@ b = "7" # secp256r1 (p256) [[app_vm_config.ecc.supported_curves]] struct_name = "P256Point" +curve_name = "NistP256" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" @@ -63,6 +66,7 @@ b = "410583637251521421293261297800472684091144410159937255548352563140394674012 # bls12_381 [[app_vm_config.ecc.supported_curves]] struct_name = "Bls12_381G1Affine" +curve_name = "Bls12_381" modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" scalar = "52435875175126190479447740508185965837690552500527637822603658699938581184513" a = "0" diff --git a/crates/sdk/src/config/openvm_standard.toml b/crates/sdk/src/config/openvm_standard.toml index f1f9267191..9d743b8d91 100644 --- a/crates/sdk/src/config/openvm_standard.toml +++ b/crates/sdk/src/config/openvm_standard.toml @@ -39,6 +39,7 @@ supported_moduli = [ # bn254 (alt bn128) [[app_vm_config.ecc.supported_curves]] struct_name = "Bn254G1Affine" +curve_name = "Bn254" modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" scalar = "21888242871839275222246405745257275088548364400416034343698204186575808495617" a = "0" @@ -47,6 +48,7 @@ b = "3" # secp256k1 (k256) [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" @@ -55,6 +57,7 @@ b = "7" # secp256r1 (p256) [[app_vm_config.ecc.supported_curves]] struct_name = "P256Point" +curve_name = "NistP256" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" @@ -63,6 +66,7 @@ b = "410583637251521421293261297800472684091144410159937255548352563140394674012 # bls12_381 [[app_vm_config.ecc.supported_curves]] struct_name = "Bls12_381G1Affine" +curve_name = "Bls12_381" modulus = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787" scalar = "52435875175126190479447740508185965837690552500527637822603658699938581184513" a = "0" diff --git a/extensions/ecc/circuit/src/extension/weierstrass.rs b/extensions/ecc/circuit/src/extension/weierstrass.rs index 9b8dcd71ec..b6e392ec52 100644 --- a/extensions/ecc/circuit/src/extension/weierstrass.rs +++ b/extensions/ecc/circuit/src/extension/weierstrass.rs @@ -180,6 +180,7 @@ impl VmExecutionExtension for WeierstrassExtension { dummy_range_checker_bus, pointer_max_bits, start_offset, + curve.a.clone(), ); inventory.add_executor( @@ -228,6 +229,7 @@ impl VmExecutionExtension for WeierstrassExtension { dummy_range_checker_bus, pointer_max_bits, start_offset, + curve.a.clone(), ); inventory.add_executor( @@ -312,6 +314,7 @@ impl VmCircuitExtension for WeierstrassExtension { bitwise_lu, pointer_max_bits, start_offset, + curve.a.clone(), ); inventory.add_air(mul); } else if bytes <= 48 { @@ -352,6 +355,7 @@ impl VmCircuitExtension for WeierstrassExtension { bitwise_lu, pointer_max_bits, start_offset, + curve.a.clone(), ); inventory.add_air(mul); } else { @@ -432,6 +436,7 @@ where range_checker.clone(), bitwise_lu.clone(), pointer_max_bits, + curve.a.clone(), ); inventory.add_executor_chip(mul); } else if bytes <= 48 { @@ -469,6 +474,7 @@ where range_checker.clone(), bitwise_lu.clone(), pointer_max_bits, + curve.a.clone(), ); inventory.add_executor_chip(mul); } else { diff --git a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs index 9a9c220a04..545a69a807 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs @@ -113,10 +113,10 @@ pub fn ec_mul< ) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { match CURVE_TYPE { x if x == CurveType::K256 as u8 => ec_mul_256bit::< - halo2curves_axiom::secq256k1::Fq, - halo2curves_axiom::secq256k1::Fp, - halo2curves_axiom::secq256k1::Secq256k1, - halo2curves_axiom::secq256k1::Secq256k1Affine, + halo2curves_axiom::secp256k1::Fq, + halo2curves_axiom::secp256k1::Fp, + halo2curves_axiom::secp256k1::Secp256k1, + halo2curves_axiom::secp256k1::Secp256k1Affine, BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, SCALAR_SIZE, diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs index 3d8378ba9b..6984865229 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs @@ -369,26 +369,14 @@ unsafe fn execute_e12_impl< // } // TODO: Check this later - // let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP { - // let scalar_data: DynArray = scalar_data.into(); - // let point_data: DynArray = point_data.into(); - // run_field_expression_precomputed::( - // pre_compute.expr, - // pre_compute.flag_idx as usize, - // &scalar_data.concat(&point_data).0, - // ) - // .into() - // } else { - // ec_mul::( - // scalar_data, - // point_data, - // ) - // }; - let output_data = + let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP { + point_data + } else { ec_mul::( scalar_data, point_data, - ); + ) + }; let rd_val = u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); debug_assert!(rd_val as usize + POINT_SIZE * BLOCKS_PER_POINT - 1 < (1 << POINTER_MAX_BITS)); diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs index 21b2126889..843921752b 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs @@ -1,6 +1,7 @@ use std::{cell::RefCell, rc::Rc}; use derive_more::derive::{Deref, DerefMut}; +use num_bigint::BigUint; use openvm_circuit::{ arch::*, system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, @@ -24,6 +25,7 @@ mod execution; pub fn ec_mul_expr( config: ExprBuilderConfig, // The coordinate field. range_bus: VariableRangeCheckerBus, + a_biguint: BigUint, ) -> FieldExpr { config.check_valid(); let builder = ExprBuilder::new(config, range_bus.range_max_bits); @@ -42,7 +44,7 @@ pub fn ec_mul_expr( y_out.save_output(); let builder = (*builder).borrow().clone(); - FieldExpr::new(builder, range_bus, false) + FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint]) } #[derive(Clone, PreflightExecutor, Deref, DerefMut)] @@ -83,8 +85,9 @@ pub type WeierstrassEcMulChip< fn gen_base_expr( config: ExprBuilderConfig, range_checker_bus: VariableRangeCheckerBus, + a_biguint: BigUint, ) -> (FieldExpr, Vec) { - let expr = ec_mul_expr(config, range_checker_bus); + let expr = ec_mul_expr(config, range_checker_bus, a_biguint); let local_opcode_idx = vec![ Rv32WeierstrassOpcode::EC_MUL as usize, @@ -107,8 +110,9 @@ pub fn get_ec_mul_air< bitwise_lookup_bus: BitwiseOperationLookupBus, pointer_max_bits: usize, offset: usize, + a_biguint: BigUint, ) -> WeierstrassEcMulAir { - let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); WeierstrassEcMulAir::new( Rv32EcMulAdapterAir::new( exec_bridge, @@ -130,8 +134,9 @@ pub fn get_ec_mul_step< range_checker_bus: VariableRangeCheckerBus, pointer_max_bits: usize, offset: usize, + a_biguint: BigUint, ) -> EcMulExecutor { - let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus); + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); EcMulExecutor(FieldExpressionExecutor::new( Rv32EcMulAdapterExecutor::new(pointer_max_bits), expr, @@ -154,8 +159,9 @@ pub fn get_ec_mul_chip< range_checker: SharedVariableRangeCheckerChip, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pointer_max_bits: usize, + a_biguint: BigUint, ) -> WeierstrassEcMulChip { - let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus()); + let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus(), a_biguint); WeierstrassEcMulChip::new( FieldExpressionFiller::new( Rv32EcMulAdapterFiller::new(pointer_max_bits, bitwise_lookup_chip), diff --git a/extensions/ecc/curve-macros/src/lib.rs b/extensions/ecc/curve-macros/src/lib.rs index daf9a7161f..0fa41310d2 100644 --- a/extensions/ecc/curve-macros/src/lib.rs +++ b/extensions/ecc/curve-macros/src/lib.rs @@ -103,9 +103,12 @@ pub fn curve_declare(input: TokenStream) -> TokenStream { } #[cfg(target_os = "zkvm")] { + use core::ops::AddAssign; + if CHECK_SETUP { Self::set_up_once(); } + let mut acc = ::IDENTITY; for (coeff, base) in coeffs.iter().zip(bases.iter()) { unsafe { @@ -127,13 +130,16 @@ pub fn curve_declare(input: TokenStream) -> TokenStream { #[inline(always)] #[cfg(target_os = "zkvm")] fn set_up_once() { + use openvm_algebra_guest::IntMod; + static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); is_setup.get_or_init(|| { let scalar_modulus_bytes = ::MODULUS; let point_modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; let p1 = scalar_modulus_bytes.as_ref(); - let p2 = [point_modulus_bytes.as_ref(), point_modulus_bytes.as_ref()].concat(); + let curve_a = ::CURVE_A; + let p2 = [point_modulus_bytes.as_ref(), curve_a.as_le_bytes()].concat(); let mut uninit: core::mem::MaybeUninit<(Self::Scalar, Self::Point)> = core::mem::MaybeUninit::uninit(); unsafe { #curve_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } @@ -226,9 +232,7 @@ pub fn curve_init(input: TokenStream) -> TokenStream { TokenStream::from(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] #[cfg(target_os = "zkvm")] - mod openvm_intrinsics_ffi_2 { - use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7}; - + mod openvm_intrinsics_ffi_3 { #(#externs)* } }) diff --git a/extensions/ecc/tests/programs/openvm_k256.toml b/extensions/ecc/tests/programs/openvm_k256.toml index 571fdb895c..f77a773348 100644 --- a/extensions/ecc/tests/programs/openvm_k256.toml +++ b/extensions/ecc/tests/programs/openvm_k256.toml @@ -10,6 +10,7 @@ supported_moduli = [ [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" diff --git a/extensions/ecc/tests/programs/openvm_k256_keccak.toml b/extensions/ecc/tests/programs/openvm_k256_keccak.toml index c1261ee458..750950df8f 100644 --- a/extensions/ecc/tests/programs/openvm_k256_keccak.toml +++ b/extensions/ecc/tests/programs/openvm_k256_keccak.toml @@ -11,6 +11,7 @@ supported_moduli = [ [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" diff --git a/extensions/ecc/tests/programs/openvm_p256.toml b/extensions/ecc/tests/programs/openvm_p256.toml index 0035cd83da..a249edcffb 100644 --- a/extensions/ecc/tests/programs/openvm_p256.toml +++ b/extensions/ecc/tests/programs/openvm_p256.toml @@ -9,6 +9,7 @@ supported_moduli = [ [[app_vm_config.ecc.supported_curves]] struct_name = "P256Point" +curve_name = "NistP256" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" diff --git a/extensions/ecc/transpiler/src/lib.rs b/extensions/ecc/transpiler/src/lib.rs index 577162b93f..397272a6a9 100644 --- a/extensions/ecc/transpiler/src/lib.rs +++ b/extensions/ecc/transpiler/src/lib.rs @@ -67,7 +67,7 @@ impl TranspilerExtension for EccTranspilerExtension { F::ZERO, F::ZERO, )) - } else if base_funct7 == SwBaseFunct7::SwEcMul as u8 { + } else if base_funct7 == SwBaseFunct7::SwSetupMul as u8 { Some(Instruction::new( VmOpcode::from_usize( Rv32WeierstrassOpcode::SETUP_EC_MUL From e21ccf75dfac3c9a821c211c27a5831ef35d36b7 Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Sat, 6 Dec 2025 07:13:15 -0500 Subject: [PATCH 5/7] swap params order --- crates/vm/src/arch/integration_api.rs | 32 ++--- .../ecc/circuit/src/extension/weierstrass.rs | 16 +-- .../circuit/src/weierstrass_chip/curves.rs | 42 +++--- .../src/weierstrass_chip/mul/execution.rs | 88 ++++++------ .../circuit/src/weierstrass_chip/mul/mod.rs | 38 ++--- extensions/ecc/curve-macros/src/lib.rs | 10 +- extensions/rv32-adapters/src/ec_mul.rs | 134 +++++++++--------- 7 files changed, 180 insertions(+), 180 deletions(-) diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index 9521ebb37a..5ea338bec9 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -326,24 +326,24 @@ impl< pub struct Rv32EcMulAdapterInterface< T, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >(PhantomData); impl< T, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > VmAdapterInterface - for Rv32EcMulAdapterInterface + for Rv32EcMulAdapterInterface { type Reads = ( - [[T; SCALAR_SIZE]; BLOCKS_PER_SCALAR], [[T; POINT_SIZE]; BLOCKS_PER_POINT], + [[T; SCALAR_SIZE]; BLOCKS_PER_SCALAR], ); type Writes = [[T; POINT_SIZE]; BLOCKS_PER_POINT]; type ProcessedInstruction = MinimalInstruction; @@ -656,20 +656,20 @@ mod conversions { // AdapterAirContext: Rv32EcMulAdapterInterface -> DynAdapterInterface impl< T, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > From< AdapterAirContext< T, Rv32EcMulAdapterInterface< T, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >, >, > for AdapterAirContext> @@ -679,10 +679,10 @@ mod conversions { T, Rv32EcMulAdapterInterface< T, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >, >, ) -> Self { @@ -697,19 +697,19 @@ mod conversions { impl< T, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > From>> for AdapterAirContext< T, Rv32EcMulAdapterInterface< T, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >, > { diff --git a/extensions/ecc/circuit/src/extension/weierstrass.rs b/extensions/ecc/circuit/src/extension/weierstrass.rs index b6e392ec52..20193b0810 100644 --- a/extensions/ecc/circuit/src/extension/weierstrass.rs +++ b/extensions/ecc/circuit/src/extension/weierstrass.rs @@ -118,11 +118,11 @@ pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeExecutor<2, 32>), EcDoubleRv32_32(EcDoubleExecutor<2, 32>), - EcMulRv32_32(EcMulExecutor<1, 2, 32, 32>), + EcMulRv32_32(EcMulExecutor<2, 1, 32, 32>), // 48 limbs prime EcAddNeRv32_48(EcAddNeExecutor<6, 16>), EcDoubleRv32_48(EcDoubleExecutor<6, 16>), - EcMulRv32_48(EcMulExecutor<1, 6, 32, 16>), + EcMulRv32_48(EcMulExecutor<6, 3, 16, 16>), } impl VmExecutionExtension for WeierstrassExtension { @@ -306,7 +306,7 @@ impl VmCircuitExtension for WeierstrassExtension { ); inventory.add_air(double); - let mul = get_ec_mul_air::<1, 2, 32, 32>( + let mul = get_ec_mul_air::<2, 1, 32, 32>( exec_bridge, memory_bridge, config, @@ -347,7 +347,7 @@ impl VmCircuitExtension for WeierstrassExtension { ); inventory.add_air(double); - let mul = get_ec_mul_air::<1, 6, 32, 16>( + let mul = get_ec_mul_air::<6, 3, 16, 16>( exec_bridge, memory_bridge, config, @@ -429,8 +429,8 @@ where ); inventory.add_executor_chip(double); - inventory.next_air::>()?; - let mul = get_ec_mul_chip::, 1, 2, 32, 32>( + inventory.next_air::>()?; + let mul = get_ec_mul_chip::, 2, 1, 32, 32>( config, mem_helper.clone(), range_checker.clone(), @@ -467,8 +467,8 @@ where ); inventory.add_executor_chip(double); - inventory.next_air::>()?; - let mul = get_ec_mul_chip::, 1, 6, 32, 16>( + inventory.next_air::>()?; + let mul = get_ec_mul_chip::, 6, 3, 16, 16>( config, mem_helper.clone(), range_checker.clone(), diff --git a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs index 545a69a807..5dfbd6e948 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/curves.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/curves.rs @@ -103,13 +103,13 @@ pub fn ec_double( - scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], ) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { match CURVE_TYPE { x if x == CurveType::K256 as u8 => ec_mul_256bit::< @@ -117,35 +117,35 @@ pub fn ec_mul< halo2curves_axiom::secp256k1::Fp, halo2curves_axiom::secp256k1::Secp256k1, halo2curves_axiom::secp256k1::Secp256k1Affine, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, - >(scalar_data, point_data), + SCALAR_SIZE, + >(point_data, scalar_data), x if x == CurveType::P256 as u8 => ec_mul_256bit::< halo2curves_axiom::secp256r1::Fq, halo2curves_axiom::secp256r1::Fp, halo2curves_axiom::secp256r1::Secp256r1, halo2curves_axiom::secp256r1::Secp256r1Affine, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, - >(scalar_data, point_data), + SCALAR_SIZE, + >(point_data, scalar_data), x if x == CurveType::BN254 as u8 => ec_mul_256bit::< halo2curves_axiom::bn256::Fr, halo2curves_axiom::bn256::Fq, halo2curves_axiom::bn256::G1, halo2curves_axiom::bn256::G1Affine, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, - >(scalar_data, point_data), + SCALAR_SIZE, + >(point_data, scalar_data), x if x == CurveType::BLS12_381 as u8 => { - ec_mul_bls12_381::( - scalar_data, + ec_mul_bls12_381::( point_data, + scalar_data, ) } _ => panic!("Unsupported curve type: {}", CURVE_TYPE), @@ -277,13 +277,13 @@ fn ec_mul_256bit< Fq: PrimeField, CJ: for<'a> Mul<&'a Fr, Output = CJ> + From, CA: CurveAffine + CurveAffineExt + From, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( - scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], ) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { // read scalar and point data let scalar = blocks_to_field_element::(scalar_data.as_flattened()); @@ -304,16 +304,16 @@ fn ec_mul_256bit< #[inline(always)] fn ec_mul_bls12_381< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( - scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT], + scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], ) -> [[u8; POINT_SIZE]; BLOCKS_PER_POINT] { // read scalar and point data - let scalar = blocks_to_field_element::(scalar_data.as_flattened()); + let scalar = blocks_to_field_element::(&scalar_data.as_flattened()[..32]); let x1 = blocks_to_field_element_bls12_381_coordinate( point_data[..BLOCKS_PER_POINT / 2].as_flattened(), ); diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs index 6984865229..6622f1aa50 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs @@ -31,11 +31,11 @@ struct EcMulPreCompute<'a> { impl< 'a, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, - > EcMulExecutor + const SCALAR_SIZE: usize, + > EcMulExecutor { fn pre_compute_impl( &'a self, @@ -106,80 +106,80 @@ macro_rules! dispatch { (true, CurveType::K256) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::K256 as u8 }, true, >), (true, CurveType::P256) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::P256 as u8 }, true, >), (true, CurveType::BN254) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::BN254 as u8 }, true, >), (true, CurveType::BLS12_381) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::BLS12_381 as u8 }, true, >), (false, CurveType::K256) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::K256 as u8 }, false, >), (false, CurveType::P256) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::P256 as u8 }, false, >), (false, CurveType::BN254) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::BN254 as u8 }, false, >), (false, CurveType::BLS12_381) => Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { CurveType::BLS12_381 as u8 }, false, >), @@ -188,10 +188,10 @@ macro_rules! dispatch { Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { u8::MAX }, true, >) @@ -199,10 +199,10 @@ macro_rules! dispatch { Ok($execute_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, { u8::MAX }, false, >) @@ -212,12 +212,12 @@ macro_rules! dispatch { impl< F: PrimeField32, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > InterpreterExecutor - for EcMulExecutor + for EcMulExecutor { #[inline(always)] fn pre_compute_size(&self) -> usize { @@ -259,12 +259,12 @@ impl< impl< F: PrimeField32, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > InterpreterMeteredExecutor - for EcMulExecutor + for EcMulExecutor { #[inline(always)] fn metered_pre_compute_size(&self) -> usize { @@ -314,10 +314,10 @@ impl< unsafe fn execute_e12_impl< F: PrimeField32, CTX: ExecutionCtxTrait, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, const CURVE_TYPE: u8, const IS_SETUP: bool, >( @@ -331,10 +331,10 @@ unsafe fn execute_e12_impl< .map(|addr| u32::from_le_bytes(exec_state.vm_read(RV32_REGISTER_AS, addr as u32))); // Read memory values for the scalar and point - let scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR] = - from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[0] + (i * SCALAR_SIZE) as u32)); let point_data: [[u8; POINT_SIZE]; BLOCKS_PER_POINT] = - from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[1] + (i * POINT_SIZE) as u32)); + from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[0] + (i * POINT_SIZE) as u32)); + let scalar_data: [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR] = + from_fn(|i| exec_state.vm_read(RV32_MEMORY_AS, rs_vals[1] + (i * SCALAR_SIZE) as u32)); // TODO: Check this later // if IS_SETUP { @@ -372,9 +372,9 @@ unsafe fn execute_e12_impl< let output_data = if CURVE_TYPE == u8::MAX || IS_SETUP { point_data } else { - ec_mul::( - scalar_data, + ec_mul::( point_data, + scalar_data, ) }; @@ -396,10 +396,10 @@ unsafe fn execute_e12_impl< unsafe fn execute_e1_impl< F: PrimeField32, CTX: ExecutionCtxTrait, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, const CURVE_TYPE: u8, const IS_SETUP: bool, >( @@ -411,10 +411,10 @@ unsafe fn execute_e1_impl< execute_e12_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, CURVE_TYPE, IS_SETUP, >(pre_compute, exec_state) @@ -425,10 +425,10 @@ unsafe fn execute_e1_impl< unsafe fn execute_e2_impl< F: PrimeField32, CTX: MeteredExecutionCtxTrait, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, const CURVE_TYPE: u8, const IS_SETUP: bool, >( @@ -444,10 +444,10 @@ unsafe fn execute_e2_impl< execute_e12_impl::< _, _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, CURVE_TYPE, IS_SETUP, >(&e2_pre_compute.data, exec_state) diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs index 843921752b..3041bd10c7 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs @@ -32,9 +32,9 @@ pub fn ec_mul_expr( let builder = Rc::new(RefCell::new(builder)); // Create inputs - let _scalar = ExprBuilder::new_input(builder.clone()); let x1 = ExprBuilder::new_input(builder.clone()); let y1 = ExprBuilder::new_input(builder.clone()); + let _scalar = ExprBuilder::new_input(builder.clone()); // Create dummy outputs: result x and result y // Note: The actual computation is done natively, but we need these for the AIR structure @@ -49,36 +49,36 @@ pub fn ec_mul_expr( #[derive(Clone, PreflightExecutor, Deref, DerefMut)] pub struct EcMulExecutor< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( FieldExpressionExecutor< - Rv32EcMulAdapterExecutor, + Rv32EcMulAdapterExecutor, >, ); pub type WeierstrassEcMulAir< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > = VmAirWrapper< - Rv32EcMulAdapterAir, + Rv32EcMulAdapterAir, FieldExpressionCoreAir, >; pub type WeierstrassEcMulChip< F, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > = VmChipWrapper< F, FieldExpressionFiller< - Rv32EcMulAdapterFiller, + Rv32EcMulAdapterFiller, >, >; @@ -98,10 +98,10 @@ fn gen_base_expr( } pub fn get_ec_mul_air< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( exec_bridge: ExecutionBridge, mem_bridge: MemoryBridge, @@ -111,7 +111,7 @@ pub fn get_ec_mul_air< pointer_max_bits: usize, offset: usize, a_biguint: BigUint, -) -> WeierstrassEcMulAir { +) -> WeierstrassEcMulAir { let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); WeierstrassEcMulAir::new( Rv32EcMulAdapterAir::new( @@ -125,17 +125,17 @@ pub fn get_ec_mul_air< } pub fn get_ec_mul_step< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( config: ExprBuilderConfig, range_checker_bus: VariableRangeCheckerBus, pointer_max_bits: usize, offset: usize, a_biguint: BigUint, -) -> EcMulExecutor { +) -> EcMulExecutor { let (expr, local_opcode_idx) = gen_base_expr(config, range_checker_bus, a_biguint); EcMulExecutor(FieldExpressionExecutor::new( Rv32EcMulAdapterExecutor::new(pointer_max_bits), @@ -149,10 +149,10 @@ pub fn get_ec_mul_step< pub fn get_ec_mul_chip< F, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, >( config: ExprBuilderConfig, mem_helper: SharedMemoryHelper, @@ -160,7 +160,7 @@ pub fn get_ec_mul_chip< bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pointer_max_bits: usize, a_biguint: BigUint, -) -> WeierstrassEcMulChip { +) -> WeierstrassEcMulChip { let (expr, local_opcode_idx) = gen_base_expr(config, range_checker.bus(), a_biguint); WeierstrassEcMulChip::new( FieldExpressionFiller::new( diff --git a/extensions/ecc/curve-macros/src/lib.rs b/extensions/ecc/curve-macros/src/lib.rs index 0fa41310d2..f56e3fd30c 100644 --- a/extensions/ecc/curve-macros/src/lib.rs +++ b/extensions/ecc/curve-macros/src/lib.rs @@ -116,8 +116,8 @@ pub fn curve_declare(input: TokenStream) -> TokenStream { core::mem::MaybeUninit::uninit(); #curve_ec_mul_extern_func( uninit.as_mut_ptr() as usize, - coeff as *const Self::Scalar as usize, base as *const Self::Point as usize, + coeff as *const Self::Scalar as usize, ); acc.add_assign(&uninit.assume_init()); } @@ -135,12 +135,12 @@ pub fn curve_declare(input: TokenStream) -> TokenStream { static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); is_setup.get_or_init(|| { - let scalar_modulus_bytes = ::MODULUS; let point_modulus_bytes = <::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; - let p1 = scalar_modulus_bytes.as_ref(); let curve_a = ::CURVE_A; - let p2 = [point_modulus_bytes.as_ref(), curve_a.as_le_bytes()].concat(); - let mut uninit: core::mem::MaybeUninit<(Self::Scalar, Self::Point)> = core::mem::MaybeUninit::uninit(); + let p1 = [point_modulus_bytes.as_ref(), curve_a.as_le_bytes()].concat(); + let scalar_modulus_bytes = ::MODULUS; + let p2 = scalar_modulus_bytes.as_ref(); + let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); unsafe { #curve_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); } ::set_up_once(); diff --git a/extensions/rv32-adapters/src/ec_mul.rs b/extensions/rv32-adapters/src/ec_mul.rs index 201e0cb8d3..74d345ac55 100644 --- a/extensions/rv32-adapters/src/ec_mul.rs +++ b/extensions/rv32-adapters/src/ec_mul.rs @@ -45,10 +45,10 @@ use openvm_stark_backend::{ #[derive(AlignedBorrow, Debug)] pub struct Rv32EcMulAdapterCols< T, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > { pub from_state: ExecutionState, @@ -61,18 +61,18 @@ pub struct Rv32EcMulAdapterCols< pub rs_read_aux: [MemoryReadAuxCols; 2], pub rd_read_aux: MemoryReadAuxCols, - pub reads_scalar_aux: [MemoryReadAuxCols; BLOCKS_PER_SCALAR], pub reads_point_aux: [MemoryReadAuxCols; BLOCKS_PER_POINT], + pub reads_scalar_aux: [MemoryReadAuxCols; BLOCKS_PER_SCALAR], pub writes_aux: [MemoryWriteAuxCols; BLOCKS_PER_POINT], } #[allow(dead_code)] #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32EcMulAdapterAir< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > { pub(super) execution_bridge: ExecutionBridge, pub(super) memory_bridge: MemoryBridge, @@ -83,33 +83,33 @@ pub struct Rv32EcMulAdapterAir< impl< F: Field, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > BaseAir - for Rv32EcMulAdapterAir + for Rv32EcMulAdapterAir { fn width(&self) -> usize { - Rv32EcMulAdapterCols::::width() + Rv32EcMulAdapterCols::::width() } } impl< AB: InteractionBuilder, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > VmAdapterAir - for Rv32EcMulAdapterAir + for Rv32EcMulAdapterAir { type Interface = Rv32EcMulAdapterInterface< AB::Expr, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >; fn eval( @@ -120,10 +120,10 @@ impl< ) { let cols: &Rv32EcMulAdapterCols< _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, > = local.borrow(); let timestamp = cols.from_state.timestamp; let mut timestamp_delta: usize = 0; @@ -181,14 +181,14 @@ impl< let e = AB::F::from_canonical_u32(RV32_MEMORY_AS); // Reads from heap - // Scalar reads - let scalar_address = &rs_val_f[0]; - for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads_scalar_aux).enumerate() { + // Point reads + let point_address = &rs_val_f[0]; + for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads_point_aux).enumerate() { self.memory_bridge .read( MemoryAddress::new( e, - scalar_address.clone() + AB::Expr::from_canonical_usize(i * SCALAR_SIZE), + point_address.clone() + AB::Expr::from_canonical_usize(i * POINT_SIZE), ), read, timestamp_pp(), @@ -196,14 +196,14 @@ impl< ) .eval(builder, ctx.instruction.is_valid.clone()); } - // Point reads - let point_address = &rs_val_f[1]; - for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads_point_aux).enumerate() { + // Scalar reads + let scalar_address = &rs_val_f[1]; + for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads_scalar_aux).enumerate() { self.memory_bridge .read( MemoryAddress::new( e, - point_address.clone() + AB::Expr::from_canonical_usize(i * POINT_SIZE), + scalar_address.clone() + AB::Expr::from_canonical_usize(i * SCALAR_SIZE), ), read, timestamp_pp(), @@ -253,10 +253,10 @@ impl< fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { let cols: &Rv32EcMulAdapterCols< _, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, > = local.borrow(); cols.from_state.pc } @@ -266,10 +266,10 @@ impl< #[repr(C)] #[derive(AlignedBytesBorrow, Debug)] pub struct Rv32EcMulAdapterRecord< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > { pub from_pc: u32, pub from_timestamp: u32, @@ -283,27 +283,27 @@ pub struct Rv32EcMulAdapterRecord< pub rs_read_aux: [MemoryReadAuxRecord; 2], pub rd_read_aux: MemoryReadAuxRecord, - pub reads_scalar_aux: [MemoryReadAuxRecord; BLOCKS_PER_SCALAR], pub reads_point_aux: [MemoryReadAuxRecord; BLOCKS_PER_POINT], + pub reads_scalar_aux: [MemoryReadAuxRecord; BLOCKS_PER_SCALAR], pub writes_aux: [MemoryWriteBytesAuxRecord; BLOCKS_PER_POINT], } #[derive(derive_new::new, Clone, Copy)] pub struct Rv32EcMulAdapterExecutor< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > { pointer_max_bits: usize, } #[derive(derive_new::new)] pub struct Rv32EcMulAdapterFiller< - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > { pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, @@ -311,30 +311,30 @@ pub struct Rv32EcMulAdapterFiller< impl< F: PrimeField32, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > AdapterTraceExecutor - for Rv32EcMulAdapterExecutor + for Rv32EcMulAdapterExecutor { const WIDTH: usize = Rv32EcMulAdapterCols::< F, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >::width(); type ReadData = ( - [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], [[u8; POINT_SIZE]; BLOCKS_PER_POINT], + [[u8; SCALAR_SIZE]; BLOCKS_PER_SCALAR], ); type WriteData = [[u8; POINT_SIZE]; BLOCKS_PER_POINT]; type RecordMut<'a> = &'a mut Rv32EcMulAdapterRecord< - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >; #[inline(always)] @@ -348,10 +348,10 @@ impl< memory: &mut TracingMemory, instruction: &Instruction, record: &mut &mut Rv32EcMulAdapterRecord< - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >, ) -> Self::ReadData { let &Instruction { a, b, c, d, e, .. } = instruction; @@ -380,30 +380,30 @@ impl< // Read memory values debug_assert!( - (record.rs_vals[0] + (SCALAR_SIZE * BLOCKS_PER_SCALAR - 1) as u32) + (record.rs_vals[0] + (POINT_SIZE * BLOCKS_PER_POINT - 1) as u32) < (1 << self.pointer_max_bits) as u32 ); - let reads_scalar = from_fn(|i| { + let reads_point = from_fn(|i| { tracing_read( memory, RV32_MEMORY_AS, - record.rs_vals[0] + (i * SCALAR_SIZE) as u32, - &mut record.reads_scalar_aux[i].prev_timestamp, + record.rs_vals[0] + (i * POINT_SIZE) as u32, + &mut record.reads_point_aux[i].prev_timestamp, ) }); debug_assert!( - (record.rs_vals[1] + (POINT_SIZE * BLOCKS_PER_POINT - 1) as u32) + (record.rs_vals[1] + (SCALAR_SIZE * BLOCKS_PER_SCALAR - 1) as u32) < (1 << self.pointer_max_bits) as u32 ); - let reads_point = from_fn(|i| { + let reads_scalar = from_fn(|i| { tracing_read( memory, RV32_MEMORY_AS, - record.rs_vals[1] + (i * POINT_SIZE) as u32, - &mut record.reads_point_aux[i].prev_timestamp, + record.rs_vals[1] + (i * SCALAR_SIZE) as u32, + &mut record.reads_scalar_aux[i].prev_timestamp, ) }); - (reads_scalar, reads_point) + (reads_point, reads_scalar) } fn write( @@ -412,10 +412,10 @@ impl< instruction: &Instruction, data: Self::WriteData, record: &mut &mut Rv32EcMulAdapterRecord< - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >, ) { debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS); @@ -441,19 +441,19 @@ impl< impl< F: PrimeField32, - const BLOCKS_PER_SCALAR: usize, const BLOCKS_PER_POINT: usize, - const SCALAR_SIZE: usize, + const BLOCKS_PER_SCALAR: usize, const POINT_SIZE: usize, + const SCALAR_SIZE: usize, > AdapterTraceFiller - for Rv32EcMulAdapterFiller + for Rv32EcMulAdapterFiller { const WIDTH: usize = Rv32EcMulAdapterCols::< F, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, >::width(); fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { @@ -461,18 +461,18 @@ impl< // - caller ensures `adapter_row` contains a valid record representation that was previously // written by the executor let record: &Rv32EcMulAdapterRecord< - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; let cols: &mut Rv32EcMulAdapterCols< F, - BLOCKS_PER_SCALAR, BLOCKS_PER_POINT, - SCALAR_SIZE, + BLOCKS_PER_SCALAR, POINT_SIZE, + SCALAR_SIZE, > = adapter_row.borrow_mut(); // Range checks: @@ -509,18 +509,18 @@ impl< }); record - .reads_point_aux + .reads_scalar_aux .iter() - .zip(cols.reads_point_aux.iter_mut()) + .zip(cols.reads_scalar_aux.iter_mut()) .rev() .for_each(|(read, cols_read)| { mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); }); record - .reads_scalar_aux + .reads_point_aux .iter() - .zip(cols.reads_scalar_aux.iter_mut()) + .zip(cols.reads_point_aux.iter_mut()) .rev() .for_each(|(read, cols_read)| { mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); From 915f15d5b0981e46e8680ab83e4c13a4cdceed19 Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Sat, 6 Dec 2025 13:33:45 -0500 Subject: [PATCH 6/7] fix tests --- benchmarks/execute/src/execute-verifier.rs | 2 +- benchmarks/prove/src/util.rs | 2 +- .../src/weierstrass_chip/mul/execution.rs | 24 +++++++++++++++++++ .../circuit/src/weierstrass_chip/mul/mod.rs | 1 + extensions/ecc/tests/src/lib.rs | 2 ++ guest-libs/pairing/tests/lib.rs | 5 ++-- 6 files changed, 32 insertions(+), 4 deletions(-) diff --git a/benchmarks/execute/src/execute-verifier.rs b/benchmarks/execute/src/execute-verifier.rs index 089d7a0771..ce0472ed49 100644 --- a/benchmarks/execute/src/execute-verifier.rs +++ b/benchmarks/execute/src/execute-verifier.rs @@ -17,7 +17,7 @@ use std::fs; -use clap::{Parser, ValueEnum}; +use clap::{arg, Parser, ValueEnum}; use eyre::Result; use openvm_benchmarks_utils::get_fixtures_dir; use openvm_circuit::arch::{ diff --git a/benchmarks/prove/src/util.rs b/benchmarks/prove/src/util.rs index 7d2b900959..ac231ba7c9 100644 --- a/benchmarks/prove/src/util.rs +++ b/benchmarks/prove/src/util.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use clap::Parser; +use clap::{command, Parser}; use eyre::Result; use openvm_benchmarks_utils::{build_elf, get_programs_dir}; use openvm_circuit::{ diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs index 6622f1aa50..12d8116b3f 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/execution.rs @@ -257,6 +257,18 @@ impl< } } +#[cfg(feature = "aot")] +impl< + F: PrimeField32, + const BLOCKS_PER_POINT: usize, + const BLOCKS_PER_SCALAR: usize, + const POINT_SIZE: usize, + const SCALAR_SIZE: usize, + > AotExecutor + for EcMulExecutor +{ +} + impl< F: PrimeField32, const BLOCKS_PER_POINT: usize, @@ -310,6 +322,18 @@ impl< } } +#[cfg(feature = "aot")] +impl< + F: PrimeField32, + const BLOCKS_PER_POINT: usize, + const BLOCKS_PER_SCALAR: usize, + const POINT_SIZE: usize, + const SCALAR_SIZE: usize, + > AotMeteredExecutor + for EcMulExecutor +{ +} + #[inline(always)] unsafe fn execute_e12_impl< F: PrimeField32, diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs index 3041bd10c7..3a56b7a07a 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mul/mod.rs @@ -97,6 +97,7 @@ fn gen_base_expr( (expr, local_opcode_idx) } +#[allow(clippy::too_many_arguments)] pub fn get_ec_mul_air< const BLOCKS_PER_POINT: usize, const BLOCKS_PER_SCALAR: usize, diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index 14d38e6230..0ecebac989 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -118,6 +118,7 @@ mod tests { let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), CurveConfig { struct_name: "CurvePoint5mod8".to_string(), + curve_name: "Curve5mod8".to_string(), modulus: BigUint::from_str("115792089237316195423570985008687907853269984665640564039457584007913129639501") .unwrap(), // unused, set to 10e9 + 7 @@ -128,6 +129,7 @@ mod tests { }, CurveConfig { struct_name: "CurvePoint1mod4".to_string(), + curve_name: "Curve1mod4".to_string(), modulus: BigUint::from_radix_be(&hex!("ffffffffffffffffffffffffffffffff000000000000000000000001"), 256) .unwrap(), scalar: BigUint::from_radix_be(&hex!("ffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d"), 256) diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 68150d536a..b101eaf1ab 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -495,8 +495,8 @@ mod bls12_381 { }; use openvm_pairing_guest::{ bls12_381::{ - BLS12_381_COMPLEX_STRUCT_NAME, BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, - BLS12_381_ORDER, + BLS12_381_COMPLEX_STRUCT_NAME, BLS12_381_CURVE_NAME, BLS12_381_ECC_STRUCT_NAME, + BLS12_381_MODULUS, BLS12_381_ORDER, }, halo2curves_shims::bls12_381::Bls12_381, pairing::{EvaluatedLine, FinalExp, LineMulMType, MillerStep, MultiMillerLoop}, @@ -541,6 +541,7 @@ mod bls12_381 { fn test_bls_ec() -> Result<()> { let curve = CurveConfig { struct_name: BLS12_381_ECC_STRUCT_NAME.to_string(), + curve_name: BLS12_381_CURVE_NAME.to_string(), modulus: BLS12_381_MODULUS.clone(), scalar: BLS12_381_ORDER.clone(), a: BigUint::ZERO, From 0bb03d999ab0690e8a81790fc60ae7eb5184b4b0 Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Sat, 6 Dec 2025 13:48:01 -0500 Subject: [PATCH 7/7] fix tests --- examples/ecc/openvm.toml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/ecc/openvm.toml b/examples/ecc/openvm.toml index 1dc6cf25f2..b9d7572a7a 100644 --- a/examples/ecc/openvm.toml +++ b/examples/ecc/openvm.toml @@ -2,11 +2,15 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.modular] -supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337"] +supported_moduli = [ + "115792089237316195423570985008687907853269984665640564039457584007908834671663", + "115792089237316195423570985008687907852837564279074904382605163141518161494337", +] [[app_vm_config.ecc.supported_curves]] struct_name = "Secp256k1Point" +curve_name = "Secp256k1" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" a = "0" -b = "7" \ No newline at end of file +b = "7"