Skip to content

Commit 0a580ed

Browse files
authored
slh-dsa: SHA-2: PK.Seed state caching (#1116)
This PR addresses issue #1035. It introduces an optimisation described in Section 8.1.6 of the SPHINCS+ specification: caching the intermediate hash state during SHA-2 computations, so that the work of absorbing PK.Seed is not repeated for every hash call. On my laptop, I observe an improvement in signing/verification time of roughly 15-25% (measured with the provided criterion benchmark).
1 parent a8df317 commit 0a580ed

File tree

11 files changed

+305
-238
lines changed

11 files changed

+305
-238
lines changed

slh-dsa/src/fors.rs

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,25 @@ use core::fmt::Debug;
33
use hybrid_array::{Array, ArraySize};
44
use typenum::Unsigned;
55

6-
use crate::{PkSeed, SkSeed, address};
6+
use crate::{SkSeed, address};
77

88
use crate::hypertree::HypertreeParams;
99
use crate::util::base_2b;
1010

11-
#[derive(Clone, Debug, PartialEq, Eq)]
11+
#[derive(Clone, Debug)]
1212
pub(crate) struct ForsMTSig<P: ForsParams> {
1313
sk: Array<u8, P::N>,
1414
auth: Array<Array<u8, P::N>, P::A>,
1515
}
1616

17+
impl<P: ForsParams> PartialEq for ForsMTSig<P> {
18+
fn eq(&self, other: &Self) -> bool {
19+
self.sk == other.sk && self.auth == other.auth
20+
}
21+
}
22+
23+
impl<P: ForsParams> Eq for ForsMTSig<P> {}
24+
1725
impl<P: ForsParams> ForsMTSig<P> {
1826
const SIZE: usize = P::N::USIZE + P::A::USIZE * P::N::USIZE;
1927

@@ -64,9 +72,17 @@ impl<P: ForsParams> TryFrom<&[u8]> for ForsMTSig<P> {
6472
}
6573
}
6674

67-
#[derive(Clone, Debug, PartialEq, Eq)]
75+
#[derive(Clone, Debug)]
6876
pub(crate) struct ForsSignature<P: ForsParams>(Array<ForsMTSig<P>, P::K>);
6977

78+
impl<P: ForsParams> PartialEq for ForsSignature<P> {
79+
fn eq(&self, other: &Self) -> bool {
80+
self.0 == other.0
81+
}
82+
}
83+
84+
impl<P: ForsParams> Eq for ForsSignature<P> {}
85+
7086
impl<P: ForsParams> TryFrom<&[u8]> for ForsSignature<P> {
7187
// TODO - real error type
7288
type Error = ();
@@ -118,73 +134,67 @@ pub(crate) trait ForsParams: HypertreeParams {
118134
type MD: ArraySize; // ceil(K*A/8)
119135

120136
fn fors_sk_gen(
137+
&self,
121138
sk_seed: &SkSeed<Self::N>,
122-
pk_seed: &PkSeed<Self::N>,
123139
adrs: &address::ForsTree,
124140
idx: u32,
125141
) -> Array<u8, Self::N> {
126142
let mut adrs = adrs.prf_adrs();
127143
adrs.tree_index.set(idx);
128-
Self::prf_sk(pk_seed, sk_seed, &adrs)
144+
self.prf_sk(sk_seed, &adrs)
129145
}
130146

131147
fn fors_node(
148+
&self,
132149
sk_seed: &SkSeed<Self::N>,
133150
i: u32,
134151
z: u32,
135-
pk_seed: &PkSeed<Self::N>,
136152
adrs: &address::ForsTree,
137153
) -> Array<u8, Self::N> {
138154
debug_assert!(z <= Self::A::U32);
139155
debug_assert!(i < (Self::K::U32 << (Self::A::U32 - z)));
140156
let mut adrs = adrs.clone(); // TODO: do we really need clone or should we take mut ref?
141157
if z == 0 {
142-
let sk = Self::fors_sk_gen(sk_seed, pk_seed, &adrs, i);
158+
let sk = self.fors_sk_gen(sk_seed, &adrs, i);
143159
adrs.tree_height.set(0);
144160
adrs.tree_index.set(i);
145-
Self::f(pk_seed, &adrs, &sk)
161+
self.f(&adrs, &sk)
146162
} else {
147-
let lnode = Self::fors_node(sk_seed, 2 * i, z - 1, pk_seed, &adrs);
148-
let rnode = Self::fors_node(sk_seed, 2 * i + 1, z - 1, pk_seed, &adrs);
163+
let lnode = self.fors_node(sk_seed, 2 * i, z - 1, &adrs);
164+
let rnode = self.fors_node(sk_seed, 2 * i + 1, z - 1, &adrs);
149165
adrs.tree_height.set(z);
150166
adrs.tree_index.set(i);
151-
Self::h(pk_seed, &adrs, &lnode, &rnode)
167+
self.h(&adrs, &lnode, &rnode)
152168
}
153169
}
154170

155171
fn fors_sign(
172+
&self,
156173
md: &Array<u8, Self::MD>,
157174
sk_seed: &SkSeed<Self::N>,
158-
pk_seed: &PkSeed<Self::N>,
159175
adrs: &address::ForsTree,
160176
) -> ForsSignature<Self> {
161177
let mut sig = ForsSignature::<Self>::default();
162178
let indices = base_2b::<Self::K, Self::A>(md);
163179
for i in 0..Self::K::U32 {
164-
sig.0[i as usize].sk = Self::fors_sk_gen(
180+
sig.0[i as usize].sk = self.fors_sk_gen(
165181
sk_seed,
166-
pk_seed,
167182
adrs,
168183
(i << Self::A::U32) + u32::from(indices[i as usize]),
169184
);
170185
for j in 0..Self::A::U32 {
171186
let s = (indices[i as usize] >> j) ^ 1;
172-
sig.0[i as usize].auth[j as usize] = Self::fors_node(
173-
sk_seed,
174-
(i << (Self::A::U32 - j)) + u32::from(s),
175-
j,
176-
pk_seed,
177-
adrs,
178-
);
187+
sig.0[i as usize].auth[j as usize] =
188+
self.fors_node(sk_seed, (i << (Self::A::U32 - j)) + u32::from(s), j, adrs);
179189
}
180190
}
181191
sig
182192
}
183193

184194
fn fors_pk_from_sig(
195+
&self,
185196
sig: &ForsSignature<Self>,
186197
md: &Array<u8, Self::MD>,
187-
pk_seed: &PkSeed<Self::N>,
188198
adrs: &address::ForsTree,
189199
) -> Array<u8, Self::N> {
190200
let mut adrs = adrs.clone();
@@ -195,27 +205,27 @@ pub(crate) trait ForsParams: HypertreeParams {
195205
adrs.tree_height.set(0);
196206
adrs.tree_index
197207
.set((i << Self::A::U32) + u32::from(indices[i as usize]));
198-
let mut node = Self::f(pk_seed, &adrs, sk);
208+
let mut node = self.f(&adrs, sk);
199209
for j in 0..Self::A::U32 {
200210
adrs.tree_height.set(j + 1);
201211
adrs.tree_index.set(adrs.tree_index.get() >> 1);
202212
if (indices[i as usize] >> j) & 1 == 0 {
203-
node = Self::h(pk_seed, &adrs, &node, &sig.0[i as usize].auth[j as usize]);
213+
node = self.h(&adrs, &node, &sig.0[i as usize].auth[j as usize]);
204214
} else {
205-
node = Self::h(pk_seed, &adrs, &sig.0[i as usize].auth[j as usize], &node);
215+
node = self.h(&adrs, &sig.0[i as usize].auth[j as usize], &node);
206216
}
207217
}
208218
roots[i as usize] = node;
209219
}
210-
Self::t(pk_seed, &adrs.fors_roots(), &roots)
220+
self.t(&adrs.fors_roots(), &roots)
211221
}
212222
}
213223

214224
#[cfg(test)]
215225
mod tests {
216226
use self::address::ForsTree;
217-
use crate::Shake128f;
218227
use crate::util::macros::test_parameter_sets;
228+
use crate::{PkSeed, Shake128f};
219229

220230
use rand::{Rng, RngCore, rng};
221231

@@ -229,9 +239,10 @@ mod tests {
229239

230240
let sk_seed = SkSeed(Array([1; 16]));
231241
let pk_seed = PkSeed(Array([2; 16]));
242+
let fors = Shake128f::new_from_pk_seed(&pk_seed);
232243
let adrs = ForsTree::new(3, 5);
233244
let md = Array([3; 25]);
234-
let sig = <Shake128f as ForsParams>::fors_sign(&md, &sk_seed, &pk_seed, &adrs);
245+
let sig = fors.fors_sign(&md, &sk_seed, &&adrs);
235246

236247
let expected = hex!(
237248
"2cac88fad4eeae791048fe07aa3544a9
@@ -478,6 +489,8 @@ mod tests {
478489

479490
let pk_seed = PkSeed::new(&mut rng);
480491

492+
let fors = Fors::new_from_pk_seed(&pk_seed);
493+
481494
let mut msg = Array::<u8, Fors::MD>::default();
482495
rng.fill_bytes(msg.as_mut_slice());
483496

@@ -492,12 +505,12 @@ mod tests {
492505
let mut pks = Array::<Array<u8, Fors::N>, Fors::K>::default();
493506
for i in 0..Fors::K::U32 {
494507
adrs.tree_index.set(i);
495-
pks[i as usize] = Fors::fors_node(&sk_seed, i, Fors::A::U32, &pk_seed, &adrs);
508+
pks[i as usize] = fors.fors_node(&sk_seed, i, Fors::A::U32, &adrs);
496509
}
497-
let pk = Fors::t(&pk_seed, &adrs.fors_roots(), &pks);
510+
let pk = fors.t(&adrs.fors_roots(), &pks);
498511

499-
let sig = Fors::fors_sign(&msg, &sk_seed, &pk_seed, &adrs);
500-
let pk_recovered = Fors::fors_pk_from_sig(&sig, &msg, &pk_seed, &adrs);
512+
let sig = fors.fors_sign(&msg, &sk_seed, &adrs);
513+
let pk_recovered = fors.fors_pk_from_sig(&sig, &msg, &adrs);
501514
assert_eq!(pk, pk_recovered);
502515
}
503516

@@ -511,6 +524,8 @@ mod tests {
511524

512525
let pk_seed = PkSeed::new(&mut rng);
513526

527+
let fors = Fors::new_from_pk_seed(&pk_seed);
528+
514529
let mut msg = Array::<u8, Fors::MD>::default();
515530
rng.fill_bytes(msg.as_mut_slice());
516531

@@ -525,16 +540,16 @@ mod tests {
525540
let mut pks = Array::<Array<u8, Fors::N>, Fors::K>::default();
526541
for i in 0..Fors::K::U32 {
527542
adrs.tree_index.set(i);
528-
pks[i as usize] = Fors::fors_node(&sk_seed, i, Fors::A::U32, &pk_seed, &adrs);
543+
pks[i as usize] = fors.fors_node(&sk_seed, i, Fors::A::U32, &adrs);
529544
}
530-
let pk = Fors::t(&pk_seed, &adrs.fors_roots(), &pks);
545+
let pk = fors.t(&adrs.fors_roots(), &pks);
531546

532-
let sig = Fors::fors_sign(&msg, &sk_seed, &pk_seed, &adrs);
547+
let sig = fors.fors_sign(&msg, &sk_seed, &adrs);
533548

534549
// Modify the message
535550
msg[0] ^= 0xff; // Invert the first byte of the message
536551

537-
let pk_recovered = Fors::fors_pk_from_sig(&sig, &msg, &pk_seed, &adrs);
552+
let pk_recovered = fors.fors_pk_from_sig(&sig, &msg, &adrs);
538553
assert_ne!(
539554
pk, pk_recovered,
540555
"Signature verification should fail with a modified message"

slh-dsa/src/hashes.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ pub use shake::*;
1515
use crate::{PkSeed, SkPrf, SkSeed, address::Address};
1616

1717
/// A trait specifying the hash functions described in FIPS-205 section 10
18-
pub(crate) trait HashSuite: Sized + Clone + Debug + PartialEq + Eq {
18+
pub(crate) trait HashSuite: Sized + Clone + Debug {
1919
type N: ArraySize + Debug + Clone + PartialEq + Eq;
2020
type M: ArraySize + Debug + Clone + PartialEq + Eq;
2121

22+
/// Instantiates the hash suite.
23+
fn new_from_pk_seed(pk_seed: &PkSeed<Self::N>) -> Self;
24+
2225
/// Pseudorandom function that generates the randomizer for the randomized hashing of the message to be signed.
2326
fn prf_msg(
2427
sk_prf: &SkPrf<Self::N>,
@@ -35,36 +38,28 @@ pub(crate) trait HashSuite: Sized + Clone + Debug + PartialEq + Eq {
3538
) -> Array<u8, Self::M>;
3639

3740
/// PRF that is used to generate the secret values in WOTS+ and FORS private keys.
38-
fn prf_sk(
39-
pk_seed: &PkSeed<Self::N>,
40-
sk_seed: &SkSeed<Self::N>,
41-
adrs: &impl Address,
42-
) -> Array<u8, Self::N>;
41+
fn prf_sk(&self, sk_seed: &SkSeed<Self::N>, adrs: &impl Address) -> Array<u8, Self::N>;
4342

4443
/// A hash function that maps an L*N-byte string to an N-byte string. Used to compress several N-byte values into one (e.g. the FORS roots into the FORS public key).
4544
/// Message length must be a multiple of `N`. Panics otherwise.
4645
fn t<L: ArraySize>(
47-
pk_seed: &PkSeed<Self::N>,
46+
&self,
4847
adrs: &impl Address,
4948
m: &Array<Array<u8, Self::N>, L>,
5049
) -> Array<u8, Self::N>;
5150

5251
/// Specialization of `t` for 2*chunk messages. Used to compute Merkle tree nodes.
5352
/// May be reimplemented for better performance.
5453
fn h(
55-
pk_seed: &PkSeed<Self::N>,
54+
&self,
5655
adrs: &impl Address,
5756
m1: &Array<u8, Self::N>,
5857
m2: &Array<u8, Self::N>,
5958
) -> Array<u8, Self::N>;
6059

6160
/// Hash function that maps an N-byte input to an N-byte output.
6261
/// Used for the WOTS+ chain function
63-
fn f(
64-
pk_seed: &PkSeed<Self::N>,
65-
adrs: &impl Address,
66-
m: &Array<u8, Self::N>,
67-
) -> Array<u8, Self::N>;
62+
fn f(&self, adrs: &impl Address, m: &Array<u8, Self::N>) -> Array<u8, Self::N>;
6863
}
6964

7065
#[cfg(test)]

0 commit comments

Comments
 (0)