From 27cc8e7b1dfb43a2bc7ecc3726c13f049a4ec096 Mon Sep 17 00:00:00 2001 From: tamirhemo Date: Fri, 24 Jan 2025 21:23:33 +0000 Subject: [PATCH] optimize patch --- Cargo.lock | 233 ++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 38 +++++-- src/algorithms/rsa.rs | 59 +++++++---- 3 files changed, 292 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 941d174..5a229f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,6 +188,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "errno" version = "0.3.7" @@ -210,6 +216,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "gcd" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" + [[package]] name = "generic-array" version = "0.14.7" @@ -231,6 +243,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hex-literal" version = "0.4.1" @@ -256,6 +274,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "keccak" version = "0.1.4" @@ -267,9 +294,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ "spin", ] @@ -292,6 +319,16 @@ version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -312,11 +349,10 @@ dependencies = [ [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] @@ -333,14 +369,132 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", ] +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "p3-baby-bear" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "080896e9d09e9761982febafe3b3da5cbf320e32f0c89b6e2e01e875129f4c2d" +dependencies = [ + "num-bigint", + "p3-field", + "p3-mds", + "p3-poseidon2", + "p3-symmetric", + "rand", + "serde", +] + +[[package]] +name = "p3-dft" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "292e97d02d4c38d8b306c2b8c0428bf15f4d32a11a40bcf80018f675bf33267e" +dependencies = [ + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "tracing", +] + +[[package]] +name = "p3-field" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91d8e5f9ede1171adafdb0b6a0df1827fbd4eb6a6217bfa36374e5d86248757" +dependencies = [ + "itertools", + "num-bigint", + "num-traits", + "p3-util", + "rand", + "serde", +] + +[[package]] +name = "p3-matrix" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98bf2c7680b8e906a5e147fe4ceb05a11cc9fa35678aa724333bcb35c72483c1" +dependencies = [ + "itertools", + "p3-field", + "p3-maybe-rayon", + "p3-util", + "rand", + "serde", + "tracing", +] + +[[package]] +name = "p3-maybe-rayon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3925562a4c03183eafc92fd07b19f65ac6cb4b48d68c3920ce58d9bee6efe362" + +[[package]] +name = "p3-mds" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "706cea48976f54702dc68dffa512684c1304d1a3606cadea423cfe0b1ee25134" +dependencies = [ + "itertools", + "p3-dft", + "p3-field", + "p3-matrix", + "p3-symmetric", + "p3-util", + "rand", +] + +[[package]] +name = "p3-poseidon2" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ce5f5ec7f1ba3a233a671621029def7bd416e7c51218c9d1167d21602cf312" +dependencies = [ + "gcd", + "p3-field", + "p3-mds", + "p3-symmetric", + "rand", + "serde", +] + +[[package]] +name = "p3-symmetric" +version = "0.2.0-succinct" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f29dc5bb6c99d3de75869d5c086874b64890280eeb7d3e068955f939e219253" +dependencies = [ + "itertools", + "p3-field", + "serde", +] + +[[package]] +name = "p3-util" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88dd5ca3eb6ff33cb20084778c32a6d68064a1913b4632437408c5a1098408b3" +dependencies = [ + "serde", +] + [[package]] name = "pbkdf2" version = "0.12.2" @@ -360,6 +514,12 @@ dependencies = [ "base64ct", ] +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "pkcs1" version = "0.7.5" @@ -659,18 +819,38 @@ checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "sp1-lib" -version = "4.0.0-rc.2" -source = "git+https://github.com/succinctlabs/sp1.git?branch=n/rsa-hook#a8789a6c62a11bb435c965db774b90075f7e0735" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aac3d3deeed25e9cad80e4275faf5954aa63f213ed3422f0e098dd2d0c1b0c0e" dependencies = [ "bincode", "serde", + "sp1-primitives", +] + +[[package]] +name = "sp1-primitives" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b25a09b455dfae9c688da05718b205e8bd4afab36cd912d54639f5d4035815" +dependencies = [ + "bincode", + "hex", + "lazy_static", + "num-bigint", + "p3-baby-bear", + "p3-field", + "p3-poseidon2", + "p3-symmetric", + "serde", + "sha2", ] [[package]] name = "spin" -version = "0.5.2" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "spki" @@ -712,6 +892,37 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 964c4d9..46ab60c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,16 +13,32 @@ readme = "README.md" rust-version = "1.65" [dependencies] -num-bigint = { version = "0.8.2", features = ["i128", "prime", "zeroize"], default-features = false, package = "num-bigint-dig" } -num-traits = { version= "0.2.9", default-features = false, features = ["libm"] } +num-bigint = { version = "0.8.2", features = [ + "i128", + "prime", + "zeroize", +], default-features = false, package = "num-bigint-dig" } +num-traits = { version = "0.2.9", default-features = false, features = [ + "libm", +] } num-integer = { version = "0.1.39", default-features = false } rand_core = { version = "0.6.4", default-features = false } const-oid = { version = "0.9", default-features = false } subtle = { version = "2.1.1", default-features = false } -digest = { version = "0.10.5", default-features = false, features = ["alloc", "oid"] } -pkcs1 = { version = "0.7.5", default-features = false, features = ["alloc", "pkcs8"] } +digest = { version = "0.10.5", default-features = false, features = [ + "alloc", + "oid", +] } +pkcs1 = { version = "0.7.5", default-features = false, features = [ + "alloc", + "pkcs8", +] } pkcs8 = { version = "0.10.2", default-features = false, features = ["alloc"] } -signature = { version = ">2.0, <2.3", default-features = false , features = ["alloc", "digest", "rand_core"] } +signature = { version = ">2.0, <2.3", default-features = false, features = [ + "alloc", + "digest", + "rand_core", +] } spki = { version = "0.7.3", default-features = false, features = ["alloc"] } zeroize = { version = "1.5", features = ["alloc"] } crypto-bigint = "0.5.5" @@ -30,9 +46,15 @@ cfg-if = "1.0.0" bytemuck = { version = "1.16.1", features = ["derive"] } # optional dependencies -sha1 = { version = "0.10.5", optional = true, default-features = false, features = ["oid"] } -sha2 = { version = "0.10.6", optional = true, default-features = false, features = ["oid"] } -serde = { version = "1.0.184", optional = true, default-features = false, features = ["derive"] } +sha1 = { version = "0.10.5", optional = true, default-features = false, features = [ + "oid", +] } +sha2 = { version = "0.10.6", optional = true, default-features = false, features = [ + "oid", +] } +serde = { version = "1.0.184", optional = true, default-features = false, features = [ + "derive", +] } [target.'cfg(all(target_os = "zkvm", target_vendor = "succinct"))'.dependencies] sp1-lib = "4.0.0" diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index d5642cd..9a31a8a 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -36,9 +36,10 @@ pub fn rsa_encrypt(key: &K, m: &BigUint) -> Result { let e_u2048 = from_biguint_to_u2048(key.e()); let n_u2048 = from_biguint_to_u2048(key.n()); - return Ok(custom_modpow_u2048(&m_u2048, &e_u2048, &n_u2048)); - } + let result = custom_modpow_u2048(&m_u2048, &e_u2048, &n_u2048); + return Ok(result); + } Ok(m.modpow(key.e(), key.n())) } @@ -59,18 +60,35 @@ mod zkvm { return BigUint::zero(); } - let mut result = U2048::ONE; - let modulus_nonzero = NonZero::new(*modulus).unwrap(); // Convert modulus to NonZero - let mut base = base.rem(&modulus_nonzero); + // The most common exponent is 65537, so we optimize for that case, otherwise we use the + // generic square and multiply algorithm. + let result = if (exp == &U2048::from_u64(65537u64)) { + let modulus_nonzero = NonZero::new(*modulus).unwrap(); // Convert modulus to NonZero + let mut base = base.rem(&modulus_nonzero); + let mut result = base; - let mut exp = *exp; - while exp > U2048::ZERO { - if exp.is_odd().into() { - result = mul_mod_u2048(&result, &base, &modulus_nonzero); + // Square 16 times + for i in 0..16 { + result = mul_mod_u2048(&result, &result, &modulus_nonzero); } - exp = exp.shr(1); - base = mul_mod_u2048(&base, &base, &modulus_nonzero); - } + // Multiply by the base + mul_mod_u2048(&result, &base, &modulus_nonzero) + } else { + let mut result = U2048::ONE; + let modulus_nonzero = NonZero::new(*modulus).unwrap(); // Convert modulus to NonZero + let mut base = base.rem(&modulus_nonzero); + + let mut exp = *exp; + while exp > U2048::ZERO { + if exp.is_odd().into() { + result = mul_mod_u2048(&result, &base, &modulus_nonzero); + } + exp = exp.shr(1); + base = mul_mod_u2048(&base, &base, &modulus_nonzero); + } + + result + }; let result_biguint = BigUint::from_bytes_le(&result.to_le_bytes()); result_biguint @@ -108,14 +126,18 @@ mod zkvm { /// and returns a U4096. fn mul_u2048(a_array: U2048, b_array: U2048) -> U4096 { let mut sum = U4096::ZERO; - let a_words = a_array.to_words(); - for i in 0..8 { - let chunk = a_words[i * 8..(i + 1) * 8].try_into().unwrap(); - let a_chunk: U256 = U256::from_words(chunk); - let mut prod = mul_array(a_chunk, b_array); + for (i, chunk) in a_array.as_words().chunks(8).enumerate() { let mut shifted_words = [0u32; 128]; - shifted_words[i * 8..].copy_from_slice(&prod.to_words()[..(128 - 8 * i)]); + let prod_result_ptr = shifted_words[i * 8..].as_mut_ptr(); + unsafe { + sp1_lib::syscall_u256x2048_mul( + chunk.as_ptr() as *const [u32; 8], + b_array.as_words().as_ptr() as *const [u32; 64], + prod_result_ptr as *mut [u32; 64], + prod_result_ptr.add(64) as *mut [u32; 8], + ); + } let shifted_prod = U4096::from_words(shifted_words); sum = sum.wrapping_add(&shifted_prod); } @@ -149,7 +171,6 @@ mod zkvm { } padded_bytes[i] = byte; } - U2048::from_le_slice(&padded_bytes) } }