diff --git a/include/keyswitch.hpp b/include/keyswitch.hpp index 961a91ad..cd767d55 100644 --- a/include/keyswitch.hpp +++ b/include/keyswitch.hpp @@ -17,8 +17,8 @@ constexpr typename P::domainP::T iksoffsetgen() typename P::domainP::T offset = 0; for (int i = 1; i <= P::t; i++) offset += - (1ULL << P::basebit) / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << P::basebit) / 2 * + (static_cast(1) << (std::numeric_limits::digits - i * P::basebit)); return offset; } @@ -35,7 +35,7 @@ void IdentityKeySwitch(TLWE &res, std::numeric_limits::digits; constexpr typename P::domainP::T roundoffset = (P::basebit * P::t) < domain_digit - ? 1ULL << (domain_digit - (1 + P::basebit * P::t)) + ? static_cast(1) << (domain_digit - (1 + P::basebit * P::t)) : 0; if constexpr (domain_digit == target_digit) res[P::targetP::k * P::targetP::n] = @@ -43,7 +43,7 @@ void IdentityKeySwitch(TLWE &res, else if constexpr (domain_digit > target_digit) res[P::targetP::k * P::targetP::n] = (tlwe[P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[P::targetP::k * P::targetP::n] = @@ -86,7 +86,7 @@ void CatIdentityKeySwitch( std::numeric_limits::digits; constexpr typename P::domainP::T roundoffset = (P::basebit * P::t) < domain_digit - ? 1ULL << (domain_digit - (1 + P::basebit * P::t)) + ? static_cast(1) << (domain_digit - (1 + P::basebit * P::t)) : 0; for (int cat = 0; cat < numcat; cat++) { @@ -96,7 +96,7 @@ void CatIdentityKeySwitch( else if constexpr (domain_digit > target_digit) res[cat][P::targetP::k * P::targetP::n] = (tlwe[cat][P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[cat][P::targetP::k * P::targetP::n] = @@ -138,7 +138,7 @@ void SubsetIdentityKeySwitch(TLWE &res, const SubsetKeySwitchingKey

&ksk) { constexpr typename P::domainP::T prec_offset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); constexpr uint32_t mask = (1U << P::basebit) - 1; res = {}; @@ -154,11 +154,11 @@ void SubsetIdentityKeySwitch(TLWE &res, } else if constexpr (domain_digit > target_digit) { for (int i = 0; i < P::targetP::k * P::targetP::n; i++) - res[i] = (tlwe[i] + (1ULL << (domain_digit - target_digit - 1))) >> + res[i] = (tlwe[i] + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); res[P::targetP::k * P::targetP::n] = (tlwe[P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); } else if constexpr (domain_digit < target_digit) { @@ -189,13 +189,13 @@ void PrivKeySwitch(TRLWE &res, const PrivateKeySwitchingKey

&privksk) { constexpr typename P::domainP::T roundoffset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); // Koga's Optimization constexpr typename P::domainP::T offset = iksoffsetgen

(); - constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1; - constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1); + constexpr typename P::domainP::T mask = (static_cast(1) << P::basebit) - 1; + constexpr typename P::domainP::T halfbase = static_cast(1) << (P::basebit - 1); res = {}; for (int i = 0; i <= P::domainP::k * P::domainP::n; i++) { const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset; @@ -258,7 +258,7 @@ void TLWE2TRLWEIKS(TRLWE &res, const TLWE2TRLWEIKSKey

&iksk) { constexpr typename P::domainP::T prec_offset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); constexpr uint32_t mask = (1U << P::basebit) - 1; res = {}; @@ -270,7 +270,7 @@ void TLWE2TRLWEIKS(TRLWE &res, res[P::targetP::k][0] = tlwe[P::domainP::n]; else if constexpr (domain_digit > target_digit) res[P::targetP::k][0] = (tlwe[P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[P::targetP::k][0] = tlwe[P::domainP::n] diff --git a/include/mulfft.hpp b/include/mulfft.hpp index c25a9826..7caf5dc3 100644 --- a/include/mulfft.hpp +++ b/include/mulfft.hpp @@ -68,10 +68,9 @@ inline void TwistNTT(Polynomial

&res, PolynomialNTT

&a) cuHEpp::TwistNTT( res, a, (*ntttablelvl1)[0], (*ntttwistlvl1)[0]); #endif - else if constexpr (std::is_same_v) { + else if constexpr (std::is_same_v) cuHEpp::TwistNTT( res, a, (*ntttablelvl2)[0], (*ntttwistlvl2)[0]); - } else static_assert(false_v, "Undefined TwistNTT!"); } @@ -89,6 +88,15 @@ inline void TwistFFT(Polynomial

&res, const PolynomialInFD

&a) else if constexpr (std::is_same_v) fftplvl1.execute_direct_torus64(res.data(), a.data()); } + else if constexpr (std::is_same_v) { + // For 128-bit lvl3param with Double Decomposition: + // Output is intermediate result that will be recombined + // Store in low 64 bits - reconstruction handles proper positioning + alignas(64) std::array temp; + fftplvl3.execute_direct_torus64(temp.data(), a.data()); + for (int i = 0; i < P::n; i++) + res[i] = static_cast<__uint128_t>(temp[i]); + } else if constexpr (std::is_same_v) fftplvl2.execute_direct_torus64(res.data(), a.data()); else @@ -143,6 +151,15 @@ inline void TwistIFFT(PolynomialInFD

&res, const Polynomial

&a) if constexpr (std::is_same_v) fftplvl1.execute_reverse_torus64(res.data(), a.data()); } + else if constexpr (std::is_same_v) { + // For 128-bit lvl3param with Double Decomposition: + // Input is always decomposition digits (small integers in low 64 bits) + // Use low 64 bits directly - no shift needed + alignas(64) std::array temp; + for (int i = 0; i < P::n; i++) + temp[i] = static_cast(a[i]); + fftplvl3.execute_reverse_torus64(res.data(), temp.data()); + } else if constexpr (std::is_same_v) fftplvl2.execute_reverse_torus64(res.data(), a.data()); else @@ -301,8 +318,21 @@ inline void PolyMul(Polynomial

&res, const Polynomial

&a, for (int i = 0; i < P::n; i++) ntta[i] *= nttb[i]; TwistNTT

(res, ntta); } + else if constexpr (std::is_same_v) { + // Naive for 128-bit types (FFT/NTT don't support 128-bit precision) + for (int i = 0; i < P::n; i++) { + __uint128_t ri = 0; + for (int j = 0; j <= i; j++) + ri += static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[i - j]); + for (int j = i + 1; j < P::n; j++) + ri -= static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[P::n + i - j]); + res[i] = ri; + } + } else { - // Naieve + // Naive for other types for (int i = 0; i < P::n; i++) { typename P::T ri = 0; for (int j = 0; j <= i; j++) @@ -339,17 +369,33 @@ template inline void PolyMulNaive(Polynomial

&res, const Polynomial

&a, const Polynomial

&b) { - for (int i = 0; i < P::n; i++) { - typename P::T ri = 0; - for (int j = 0; j <= i; j++) - ri += static_cast::type>( - a[j]) * - b[i - j]; - for (int j = i + 1; j < P::n; j++) - ri -= static_cast::type>( - a[j]) * - b[P::n + i - j]; - res[i] = ri; + if constexpr (std::is_same_v) { + for (int i = 0; i < P::n; i++) { + __uint128_t ri = 0; + for (int j = 0; j <= i; j++) + ri += static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[i - j]); + for (int j = i + 1; j < P::n; j++) + ri -= static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[P::n + i - j]); + res[i] = ri; + } + } + else { + for (int i = 0; i < P::n; i++) { + typename P::T ri = 0; + for (int j = 0; j <= i; j++) + ri += + static_cast::type>( + a[j]) * + b[i - j]; + for (int j = i + 1; j < P::n; j++) + ri -= + static_cast::type>( + a[j]) * + b[P::n + i - j]; + res[i] = ri; + } } } diff --git a/include/params.hpp b/include/params.hpp index 56afe2d7..22d10be6 100644 --- a/include/params.hpp +++ b/include/params.hpp @@ -67,6 +67,16 @@ struct lvl02param { #endif }; +struct lvl03param { + using domainP = lvl0param; + using targetP = lvl3param; +#ifdef USE_KEY_BUNDLE + static constexpr uint32_t Addends = 2; +#else + static constexpr uint32_t Addends = 1; +#endif +}; + struct lvlh2param { using domainP = lvlhalfparam; using targetP = lvl2param; @@ -118,6 +128,10 @@ using DecomposedPolynomial = std::array, P::l>; template using DecomposedNoncePolynomial = std::array, P::lₐ>; template +using DecomposedPolynomialDD = std::array, P::l * P::l̅>; +template +using DecomposedNoncePolynomialDD = std::array, P::lₐ * P::l̅ₐ>; +template using DecomposedPolynomialNTT = std::array, P::l>; template using DecomposedNoncePolynomialNTT = std::array, P::lₐ>; @@ -138,17 +152,17 @@ template using TRLWERAINTT = std::array, P::k + 1>; template -using TRGSW = std::array, P::k * P::lₐ + P::l>; +using TRGSW = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using HalfTRGSW = std::array, P::l>; +using HalfTRGSW = std::array, P::l * P::l̅>; template -using TRGSWFFT = aligned_array, P::k * P::lₐ + P::l>; +using TRGSWFFT = aligned_array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using HalfTRGSWFFT = aligned_array, P::l>; +using HalfTRGSWFFT = aligned_array, P::l * P::l̅>; template -using TRGSWNTT = std::array, P::k * P::lₐ + P::l>; +using TRGSWNTT = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using TRGSWRAINTT = std::array, P::k * P::lₐ + P::l>; +using TRGSWRAINTT = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; #ifdef USE_KEY_BUNDLE template diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index 19ffe8a8..a48e0576 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -62,6 +62,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct AHlvl1param { @@ -83,6 +90,11 @@ struct AHlvl1param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; struct lvl2param { @@ -106,6 +118,13 @@ struct lvl2param { static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = static_cast(1ULL << (std::numeric_limits::digits - 4)); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct AHlvl2param { @@ -127,9 +146,53 @@ struct AHlvl2param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; +// lvl3param with 128-bit Torus and non-trivial Double Decomposition +// Double decomposition structure: +// - Primary decomposition (l, Bgbit): Decomposes plaintext by μ in TRGSW gadget +// - Auxiliary decomposition (l̅, B̅gbit): Decomposes TRLWE ciphertext coefficients +// in the external product. Must cover full 128-bit coefficient. +// Constraint for DD algorithm: l*Bgbit + (l̅-1)*B̅gbit ≤ 128 +// Using l=2, Bgbit=16 for primary (32 bits); l̅=4, B̅gbit=32 for auxiliary (128 bits) +// This gives: 32 + 96 = 128 ✓ struct lvl3param { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static const std::uint32_t nbit = 12; // dimension must be a power of 2 for + // ease of polynomial multiplication. + static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096 + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 2; // reduced to fit DD constraint + static constexpr std::uint32_t l = 2; // reduced to fit DD constraint + static constexpr std::uint32_t Bgbit = 16; + static constexpr std::uint32_t Bgₐbit = 16; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -105); // fresh noise + using T = __uint128_t; // Torus representation + static constexpr T μ = static_cast(1) << 125; + static constexpr uint32_t plain_modulusbit = 31; + static constexpr T plain_modulus = static_cast(1) << plain_modulusbit; + static constexpr double Δ = + static_cast(static_cast(1) << (128 - plain_modulusbit - 1)); + // Double Decomposition (bivariate representation) parameters + // Auxiliary decomposition must cover full 128-bit ciphertext coefficients + // l̅ * B̅gbit = 4 * 32 = 128 bits + static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = 4; + static constexpr std::uint32_t B̅gbit = 32; // 2^32 base for auxiliary (covers 128-bit T) + static constexpr std::uint32_t B̅gₐbit = 32; +}; + +struct lvl4param { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; static const std::uint32_t nbit = 13; // dimension must be a power of 2 for @@ -150,6 +213,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // Trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters @@ -239,3 +309,12 @@ struct lvl31param { using domainP = lvl3param; using targetP = lvl1param; }; + +struct lvl41param { + static constexpr std::uint32_t t = 7; // number of addition in keyswitching + static constexpr std::uint32_t basebit = + 2; // how many bit should be encrypted in keyswitching key + static const inline double α = lvl1param::α; // key noise + using domainP = lvl4param; + using targetP = lvl1param; +}; diff --git a/include/params/CGGI16.hpp b/include/params/CGGI16.hpp index c7936c06..c691a95b 100644 --- a/include/params/CGGI16.hpp +++ b/include/params/CGGI16.hpp @@ -59,6 +59,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -80,6 +87,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -104,8 +118,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + struct lvl10param { static constexpr std::uint32_t t = 8; static constexpr std::uint32_t basebit = 2; @@ -191,4 +215,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/CGGI19.hpp b/include/params/CGGI19.hpp index 623282d9..dacf1d41 100644 --- a/include/params/CGGI19.hpp +++ b/include/params/CGGI19.hpp @@ -59,6 +59,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -78,6 +85,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -102,8 +116,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Dummy struct lvl11param { static constexpr std::uint32_t t = 0; // number of addition in keyswitching @@ -190,4 +214,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/compress.hpp b/include/params/compress.hpp index d836ab0f..633889d7 100644 --- a/include/params/compress.hpp +++ b/include/params/compress.hpp @@ -69,6 +69,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -93,6 +100,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = q / 8; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -117,8 +131,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 5; // number of addition in keyswitching @@ -192,4 +216,7 @@ struct lvl31param { 2; // how many bit should be encrypted in keyswitching key using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/concrete.hpp b/include/params/concrete.hpp index 39c28f3c..99d65313 100644 --- a/include/params/concrete.hpp +++ b/include/params/concrete.hpp @@ -69,6 +69,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -94,6 +101,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; #define USE_DIFFERENT_BR_PARAM @@ -118,6 +132,13 @@ struct cblvl2param { static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = static_cast(1ULL << (std::numeric_limits::digits - 4)); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -143,6 +164,11 @@ struct cbAHlvl2param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; struct lvl3param { @@ -166,8 +192,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 5; // number of addition in keyswitching @@ -257,4 +293,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/ternary.hpp b/include/params/ternary.hpp index c976747f..8adbfac8 100644 --- a/include/params/ternary.hpp +++ b/include/params/ternary.hpp @@ -63,6 +63,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -85,6 +92,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = 1LL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl3param { @@ -108,8 +122,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 7; // number of addition in keyswitching @@ -198,4 +222,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/tfhe-rs.hpp b/include/params/tfhe-rs.hpp index 09709e87..cd5f9229 100644 --- a/include/params/tfhe-rs.hpp +++ b/include/params/tfhe-rs.hpp @@ -70,6 +70,13 @@ struct lvl1param { static constexpr double Δ = 2 * static_cast(1ULL << (std::numeric_limits::digits - 1)) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -92,6 +99,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -114,8 +128,18 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 3; // number of addition in keyswitching @@ -204,4 +228,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/tlwe.hpp b/include/tlwe.hpp index f95ee471..6b7dbaf8 100644 --- a/include/tlwe.hpp +++ b/include/tlwe.hpp @@ -13,13 +13,11 @@ template void tlweSymEncrypt(TLWE

&res, const typename P::T p, const double α, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); res = {}; res[P::k * P::n] = ModularGaussian

(p, α); for (int k = 0; k < P::k; k++) for (int i = 0; i < P::n; i++) { - res[k * P::n + i] = Torusdist(generator); + res[k * P::n + i] = UniformTorusRandom

(); res[P::k * P::n] += res[k * P::n + i] * key[k * P::n + i]; } } @@ -122,7 +120,7 @@ typename P::T tlweSymIntDecrypt(const TLWE

&c, const Key

&key) constexpr double Δ = 2 * static_cast( - 1ULL << (std::numeric_limits::digits - 1)) / + static_cast(1) << (std::numeric_limits::digits - 1)) / plain_modulus; const typename P::T phase = tlweSymPhase

(c, key); typename P::T res = static_cast(std::round(phase / Δ)); diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 429d7a4d..257807b9 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -15,7 +15,7 @@ constexpr typename P::T offsetgen() typename P::T offset = 0; for (int i = 1; i <= P::l; i++) offset += P::Bg / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << (std::numeric_limits::digits - i * P::Bgbit)); return offset; } @@ -27,11 +27,11 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, { // https://eprint.iacr.org/2021/1161 constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - P::l * P::Bgbit - + static_cast(1) << (std::numeric_limits::digits - P::l * P::Bgbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgbit) - 1); - constexpr typename P::T Bgl = 1ULL << (P::l * P::Bgbit); + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T Bgl = static_cast(1) << (P::l * P::Bgbit); Polynomial

K; for (int i = 0; i < P::n; i++) { @@ -68,11 +68,11 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, { constexpr typename P::T offset = offsetgen

(); constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - P::l * P::Bgbit - + static_cast(1) << (std::numeric_limits::digits - P::l * P::Bgbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1)); + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); for (int i = 0; i < P::n; i++) { for (int l = 0; l < P::l; l++) @@ -90,7 +90,7 @@ constexpr typename P::T nonceoffsetgen() typename P::T offset = 0; for (int i = 1; i <= P::lₐ; i++) offset += P::Bgₐ / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << (std::numeric_limits::digits - i * P::Bgₐbit)); return offset; } @@ -101,11 +101,11 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, { constexpr typename P::T offset = nonceoffsetgen

(); constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - P::lₐ * P::Bgₐbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgₐbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1)); + static_cast((static_cast(1) << P::Bgₐbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); for (int i = 0; i < P::n; i++) { for (int l = 0; l < P::lₐ; l++) @@ -117,6 +117,185 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, } } +// Double Decomposition (bivariate representation) for external product +// Decomposes each coefficient a into l*l̅ components such that: +// a ≈ Σᵢ Σⱼ aᵢⱼ * Bg^(l-i) * B̅g^(l̅-j) +// When l̅=1 (j=0 only), this reduces to standard decomposition. +template +constexpr typename P::T ddoffsetgen() +{ + typename P::T offset = 0; + for (int i = 1; i <= P::l; i++) + for (int j = 0; j < P::l̅; j++) + offset += (static_cast(P::Bg) / 2) * + (static_cast(1) + << (std::numeric_limits::digits - + i * P::Bgbit - j * P::B̅gbit)); + return offset; +} + +template +inline void DoubleDecomposition(DecomposedPolynomialDD

&decpoly, + const Polynomial

&poly) +{ + constexpr typename P::T offset = ddoffsetgen

(); + // Remaining bits after decomposition + constexpr int remaining_bits = std::numeric_limits::digits - + P::l * P::Bgbit - P::l̅ * P::B̅gbit; + // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1) + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + constexpr typename P::T maskBg = + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); + + for (int n = 0; n < P::n; n++) { + typename P::T a = poly[n] + offset + roundoffset; + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::l̅; j++) { + // Shift to get the (i,j)-th digit in base Bg (after B̅g grouping) + // When l̅=1 (j=0 only), this reduces to standard decomposition + const int shift = std::numeric_limits::digits - + (i + 1) * P::Bgbit - j * P::B̅gbit; + decpoly[i * P::l̅ + j][n] = ((a >> shift) & maskBg) - halfBg; + } + } + } +} + +template +constexpr typename P::T nonceddoffsetgen() +{ + typename P::T offset = 0; + for (int i = 1; i <= P::lₐ; i++) + for (int j = 0; j < P::l̅ₐ; j++) + offset += (static_cast(P::Bgₐ) / 2) * + (static_cast(1) + << (std::numeric_limits::digits - + i * P::Bgₐbit - j * P::B̅gₐbit)); + return offset; +} + +template +inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, + const Polynomial

&poly) +{ + constexpr typename P::T offset = nonceddoffsetgen

(); + // Remaining bits after decomposition + constexpr int remaining_bits = std::numeric_limits::digits - + P::lₐ * P::Bgₐbit - P::l̅ₐ * P::B̅gₐbit; + // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1) + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + constexpr typename P::T maskBg = + static_cast((static_cast(1) << P::Bgₐbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); + + for (int n = 0; n < P::n; n++) { + typename P::T a = poly[n] + offset + roundoffset; + for (int i = 0; i < P::lₐ; i++) { + for (int j = 0; j < P::l̅ₐ; j++) { + // Shift to get the (i,j)-th digit + // When l̅ₐ=1 (j=0 only), this reduces to standard decomposition + const int shift = std::numeric_limits::digits - + (i + 1) * P::Bgₐbit - j * P::B̅gₐbit; + decpoly[i * P::l̅ₐ + j][n] = ((a >> shift) & maskBg) - halfBg; + } + } + } +} + +// TRLWE Decomposition to base B̅g for Double Decomposition +// Decomposes each TRLWE coefficient into l̅ digits in base B̅g +// Used to pre-apply DD to TRGSW rows during encryption +// Returns l̅ TRLWEs where result[j] contains the j-th B̅g digit of each coefficient +template +inline void TRLWEBaseBbarDecompose(std::array, P::l̅> &result, + const TRLWE

&input) +{ + constexpr typename P::T maskB̅g = + static_cast((static_cast(1) << P::B̅gbit) - + 1); + constexpr typename P::T halfB̅g = + static_cast(1) << (P::B̅gbit - 1); + + // Compute offset for signed digit representation + constexpr typename P::T offset = []() { + typename P::T off = 0; + for (int j = 0; j < P::l̅; j++) + off += halfB̅g * (static_cast(1) + << (std::numeric_limits::digits - + (j + 1) * P::B̅gbit)); + return off; + }(); + + // Remaining bits after decomposition + constexpr int remaining_bits = + std::numeric_limits::digits - P::l̅ * P::B̅gbit; + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + typename P::T a = input[k][n] + offset + roundoffset; + for (int j = 0; j < P::l̅; j++) { + // Extract j-th digit from MSB side + const int shift = std::numeric_limits::digits - + (j + 1) * P::B̅gbit; + result[j][k][n] = ((a >> shift) & maskB̅g) - halfB̅g; + } + } + } +} + +// Nonce version of TRLWE Decomposition to base B̅gₐ +template +inline void TRLWEBaseBbarDecomposeNonce(std::array, P::l̅ₐ> &result, + const TRLWE

&input) +{ + constexpr typename P::T maskB̅g = + static_cast((static_cast(1) << P::B̅gₐbit) - + 1); + constexpr typename P::T halfB̅g = + static_cast(1) << (P::B̅gₐbit - 1); + + // Compute offset for signed digit representation + constexpr typename P::T offset = []() { + typename P::T off = 0; + for (int j = 0; j < P::l̅ₐ; j++) + off += halfB̅g * (static_cast(1) + << (std::numeric_limits::digits - + (j + 1) * P::B̅gₐbit)); + return off; + }(); + + // Remaining bits after decomposition + constexpr int remaining_bits = + std::numeric_limits::digits - P::l̅ₐ * P::B̅gₐbit; + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + typename P::T a = input[k][n] + offset + roundoffset; + for (int j = 0; j < P::l̅ₐ; j++) { + // Extract j-th digit from MSB side + const int shift = std::numeric_limits::digits - + (j + 1) * P::B̅gₐbit; + result[j][k][n] = ((a >> shift) & maskB̅g) - halfB̅g; + } + } + } +} + template void Decomposition(DecomposedPolynomialNTT

&decpolyntt, const Polynomial

&poly) @@ -157,42 +336,192 @@ void NonceDecomposition(DecomposedNoncePolynomialRAINTT

&decpolyntt, decpolyntt[i], decpoly[i], (*raintttable)[1], (*raintttwist)[1]); } +// Recombine l̅ TRLWEs from Double Decomposition back to single TRLWE +// result[j] contains j-th B̅g digit; recombine as: res = Σⱼ result[j] * 2^(width - (j+1)*B̅gbit) +template +inline void RecombineTRLWEFromDD(TRLWE

&res, + const std::array, P::l̅> &decomposed) +{ + constexpr int width = std::numeric_limits::digits; + + // Initialize result to zero + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] = 0; + } + } + + // Add all components with appropriate shifts + for (int j = 0; j < P::l̅; j++) { + const int shift = width - (j + 1) * P::B̅gbit; + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] += decomposed[j][k][n] << shift; + } + } + } +} + +// Recombine l̅ₐ TRLWEs from Double Decomposition (nonce version) +template +inline void RecombineTRLWEFromDDNonce( + TRLWE

&res, const std::array, P::l̅ₐ> &decomposed) +{ + constexpr int width = std::numeric_limits::digits; + + // Initialize result to zero + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] = 0; + } + } + + // Add all components with appropriate shifts + for (int j = 0; j < P::l̅ₐ; j++) { + const int shift = width - (j + 1) * P::B̅gₐbit; + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] += decomposed[j][k][n] << shift; + } + } + } +} + +// External product with TRGSWFFT +// Automatically uses Double Decomposition when P::l̅ > 1 template void ExternalProduct(TRLWE

&res, const TRLWE

&trlwe, const TRGSWFFT

&trgswfft) { alignas(64) PolynomialInFD

decpolyfft; - alignas(64) TRLWEInFD

restrlwefft; - { - alignas(64) DecomposedNoncePolynomial

decpoly; - NonceDecomposition

(decpoly, trlwe[0]); - TwistIFFT

(decpolyfft, decpoly[0]); - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); - for (int i = 1; i < P::lₐ; i++) { + + if constexpr (P::l̅ > 1) { + // Double Decomposition: use standard decomposition on input, + // accumulate l̅ separate results, then recombine + // TRGSW rows are organized as: for each ordinary row i, l̅ rows for B̅g digits + + // l̅ separate accumulators in FD domain + alignas(64) std::array, P::l̅> restrlwefft_dd; + + // Initialize all accumulators to zero + for (int j = 0; j < P::l̅; j++) + for (int m = 0; m <= P::k; m++) + for (int n = 0; n < P::n; n++) restrlwefft_dd[j][m][n] = 0.0; + + // Process nonce part with standard decomposition (lₐ levels) + if constexpr (P::l̅ₐ > 1) { + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + // Each decomposition level i multiplies with l̅ₐ TRGSW rows + for (int j = 0; j < P::l̅ₐ; j++) { + const int row_idx = i * P::l̅ₐ + j; + for (int m = 0; m <= P::k; m++) { + if (i == 0 && j == 0) + MulInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + else + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } + } + } + for (int k_idx = 1; k_idx < P::k; k_idx++) { + NonceDecomposition

(decpoly, trlwe[k_idx]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int j = 0; j < P::l̅ₐ; j++) { + const int row_idx = + (i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ; + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } + } + } + } + else { + // l̅ₐ == 1: nonce part has no DD, just standard decomposition + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); + TwistIFFT

(decpolyfft, decpoly[0]); + for (int m = 0; m <= P::k; m++) + MulInFD(restrlwefft_dd[0][m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[0][m], decpolyfft, + trgswfft[i][m]); + } + for (int k_idx = 1; k_idx < P::k; k_idx++) { + NonceDecomposition

(decpoly, trlwe[k_idx]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[0][m], decpolyfft, + trgswfft[i + k_idx * P::lₐ][m]); + } + } + } + + // Process main part with standard decomposition (l levels) + alignas(64) DecomposedPolynomial

decpoly; + Decomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l; i++) { TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + // Each decomposition level i multiplies with l̅ TRGSW rows + for (int j = 0; j < P::l̅; j++) { + const int row_idx = (i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ; + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } } - for (int k = 1; k < P::k; k++) { - NonceDecomposition

(decpoly, trlwe[k]); - for (int i = 0; i < P::lₐ; i++) { + + // FFT back to coefficient domain for each accumulator and recombine + std::array, P::l̅> results_dd; + for (int j = 0; j < P::l̅; j++) + for (int k = 0; k <= P::k; k++) + TwistFFT

(results_dd[j][k], restrlwefft_dd[j][k]); + + // Recombine the l̅ TRLWEs back to single TRLWE + RecombineTRLWEFromDD

(res, results_dd); + } + else { + // Standard decomposition (l̅ == 1) + alignas(64) TRLWEInFD

restrlwefft; + { + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); + TwistIFFT

(decpolyfft, decpoly[0]); + for (int m = 0; m < P::k + 1; m++) + MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ; i++) { TwistIFFT

(decpolyfft, decpoly[i]); for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + k * P::lₐ][m]); + FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + } + for (int k = 1; k < P::k; k++) { + NonceDecomposition

(decpoly, trlwe[k]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + k * P::lₐ][m]); + } } } + alignas(64) DecomposedPolynomial

decpoly; + Decomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + P::k * P::lₐ][m]); + } + for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } - alignas(64) DecomposedPolynomial

decpoly; - Decomposition

(decpoly, trlwe[P::k]); - for (int i = 0; i < P::l; i++) { - TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + P::k * P::lₐ][m]); - } - for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } template @@ -202,18 +531,52 @@ void ExternalProduct(TRLWE

&res, const Polynomial

&poly, alignas(64) DecomposedPolynomial

decpoly; Decomposition

(decpoly, poly); alignas(64) PolynomialInFD

decpolyfft; - // __builtin_prefetch(trgswfft[0].data()); - TwistIFFT

(decpolyfft, decpoly[0]); - alignas(64) TRLWEInFD

restrlwefft; - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, halftrgswfft[0][m]); - for (int i = 1; i < P::l; i++) { - // __builtin_prefetch(trgswfft[i].data()); - TwistIFFT

(decpolyfft, decpoly[i]); + + if constexpr (P::l̅ > 1) { + // DD: use standard decomposition, accumulate l̅ results, recombine + alignas(64) std::array, P::l̅> restrlwefft_dd; + + // Initialize accumulators to zero + for (int j = 0; j < P::l̅; j++) + for (int m = 0; m <= P::k; m++) + for (int n = 0; n < P::n; n++) restrlwefft_dd[j][m][n] = 0.0; + + for (int i = 0; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int j = 0; j < P::l̅; j++) { + const int row_idx = i * P::l̅ + j; + for (int m = 0; m <= P::k; m++) { + if (i == 0 && j == 0) + MulInFD(restrlwefft_dd[j][m], decpolyfft, + halftrgswfft[row_idx][m]); + else + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + halftrgswfft[row_idx][m]); + } + } + } + + // FFT back and recombine + std::array, P::l̅> results_dd; + for (int j = 0; j < P::l̅; j++) + for (int k = 0; k <= P::k; k++) + TwistFFT

(results_dd[j][k], restrlwefft_dd[j][k]); + + RecombineTRLWEFromDD

(res, results_dd); + } + else { + // Standard decomposition (l̅ == 1) + TwistIFFT

(decpolyfft, decpoly[0]); + alignas(64) TRLWEInFD

restrlwefft; for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, halftrgswfft[i][m]); + MulInFD(restrlwefft[m], decpolyfft, halftrgswfft[0][m]); + for (int i = 1; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, halftrgswfft[i][m]); + } + for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } - for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } template @@ -353,7 +716,7 @@ template TRGSWFFT

ApplyFFT2trgsw(const TRGSW

&trgsw) { alignas(64) TRGSWFFT

trgswfft; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(trgswfft[i][j], trgsw[i][j]); return trgswfft; @@ -362,7 +725,7 @@ TRGSWFFT

ApplyFFT2trgsw(const TRGSW

&trgsw) template void ApplyFFT2trgsw(TRGSWFFT

&trgswfft, const TRGSW

&trgsw) { - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(trgswfft[i][j], trgsw[i][j]); } @@ -371,7 +734,7 @@ template HalfTRGSWFFT

ApplyFFT2halftrgsw(const HalfTRGSW

&trgsw) { alignas(64) HalfTRGSWFFT

halftrgswfft; - for (int i = 0; i < P::l; i++) + for (int i = 0; i < P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(halftrgswfft[i][j], trgsw[i][j]); return halftrgswfft; @@ -381,7 +744,7 @@ template TRGSWNTT

ApplyNTT2trgsw(const TRGSW

&trgsw) { TRGSWNTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) TwistINTT

(trgswntt[i][j], trgsw[i][j]); return trgswntt; @@ -392,7 +755,7 @@ TRGSWRAINTT

ApplyRAINTT2trgsw(const TRGSW

&trgsw) { constexpr uint8_t remainder = ((P::nbit - 1) % 3) + 1; TRGSWRAINTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) { raintt::TwistINTT( trgswntt[i][j], trgsw[i][j], (*raintttable)[1], @@ -412,7 +775,7 @@ template TRGSWNTT

TRGSW2NTT(const TRGSW

&trgsw) { TRGSWNTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) { PolynomialNTT

temp; TwistINTT

(temp, trgsw[i][j]); @@ -432,7 +795,7 @@ constexpr std::array hgen() ((i + 1) * P::Bgbit); else for (int i = 0; i < P::l; i++) - h[i] = 1ULL << (std::numeric_limits::digits - + h[i] = static_cast(1) << (std::numeric_limits::digits - (i + 1) * P::Bgbit); return h; } @@ -447,26 +810,67 @@ constexpr std::array noncehgen() ((i + 1) * P::Bgₐbit); else for (int i = 0; i < P::lₐ; i++) - h[i] = 1ULL << (std::numeric_limits::digits - + h[i] = static_cast(1) << (std::numeric_limits::digits - (i + 1) * P::Bgₐbit); return h; } +// Auxiliary h generation for Double Decomposition (bivariate representation) +// h̅[j] values are used to construct gadget values h[i] * h̅[j] = 2^(width - (i+1)*Bgbit - j*B̅gbit) +// For j=0: no auxiliary shift, so h̅[0] = 1 +// For j>0: h̅[j] = 2^(width - j*B̅gbit) which combines with h[i] via modular multiplication +template +constexpr std::array h̅gen() +{ + std::array h̅{}; + h̅[0] = 1; // j=0 means no auxiliary shift + for (int i = 1; i < P::l̅; i++) + h̅[i] = static_cast(1) << (std::numeric_limits::digits - + i * P::B̅gbit); + return h̅; +} + +// Auxiliary h generation for nonce part of TRGSW with Double Decomposition +template +constexpr std::array nonceh̅gen() +{ + std::array h̅{}; + h̅[0] = 1; // j=0 means no auxiliary shift + for (int i = 1; i < P::l̅ₐ; i++) + h̅[i] = static_cast(1) << (std::numeric_limits::digits - + i * P::B̅gₐbit); + return h̅; +} + +// Add gadget values to HalfTRGSW (standard decomposition only) +// For Double Decomposition, use halftrgswSymEncrypt directly template inline void halftrgswhadd(HalfTRGSW

&halftrgsw, const Polynomial

&p) { + static_assert(P::l̅ == 1, + "halftrgswhadd only supports standard decomposition (l̅=1). " + "Use halftrgswSymEncrypt for DD."); constexpr std::array h = hgen

(); for (int i = 0; i < P::l; i++) { for (int j = 0; j < P::n; j++) { - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + halftrgsw[i][P::k][j] += + static_cast(p[j]) * h[i]; } } } +// Add gadget values to TRGSW (standard decomposition only) +// For Double Decomposition, use trgswSymEncrypt directly template inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) { + static_assert(P::l̅ == 1 && P::l̅ₐ == 1, + "trgswhadd only supports standard decomposition (l̅=l̅ₐ=1). " + "Use trgswSymEncrypt for DD."); constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); + + // Nonce part for (int i = 0; i < P::lₐ; i++) { for (int k = 0; k < P::k; k++) { for (int j = 0; j < P::n; j++) { @@ -475,7 +879,8 @@ inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) } } } - constexpr std::array h = hgen

(); + + // Main part for (int i = 0; i < P::l; i++) { for (int j = 0; j < P::n; j++) { trgsw[i + P::k * P::lₐ][P::k][j] += @@ -484,31 +889,196 @@ inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) } } +// Add gadget values for constant 1 to TRGSW (standard decomposition only) +// For Double Decomposition, use trgswSymEncryptOne directly template inline void trgswhoneadd(TRGSW

&trgsw) { + static_assert(P::l̅ == 1 && P::l̅ₐ == 1, + "trgswhoneadd only supports standard decomposition (l̅=l̅ₐ=1). " + "Use trgswSymEncryptOne for DD."); constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); + + // Nonce part for (int i = 0; i < P::lₐ; i++) - for (int k = 0; k < P::k; k++) trgsw[i + k * P::lₐ][k][0] += nonceh[i]; + for (int k = 0; k < P::k; k++) + trgsw[i + k * P::lₐ][k][0] += nonceh[i]; + + // Main part + for (int i = 0; i < P::l; i++) + trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; +} +// Encrypt constant 1 in TRGSW with proper DD support +template +void trgswSymEncryptOneImpl(TRGSW

&trgsw, const NoiseType noise, + const Key

&key) +{ + constexpr std::array nonceh = noncehgen

(); constexpr std::array h = hgen

(); - for (int i = 0; i < P::l; i++) trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + + if constexpr (P::l̅ > 1 || P::l̅ₐ > 1) { + // Double Decomposition path + constexpr int ordinary_rows = P::k * P::lₐ + P::l; + std::array, ordinary_rows> ordinary_trgsw; + for (auto &trlwe : ordinary_trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Add gadget for constant 1 + for (int i = 0; i < P::lₐ; i++) + for (int k_idx = 0; k_idx < P::k; k_idx++) + ordinary_trgsw[i + k_idx * P::lₐ][k_idx][0] += nonceh[i]; + + for (int i = 0; i < P::l; i++) + ordinary_trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + + // Apply DD + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int i = 0; i < P::lₐ; i++) { + std::array, P::l̅ₐ> decomposed; + TRLWEBaseBbarDecomposeNonce

(decomposed, + ordinary_trgsw[i + k_idx * P::lₐ]); + for (int j = 0; j < P::l̅ₐ; j++) + trgsw[(i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ] = + decomposed[j]; + } + } + for (int i = 0; i < P::l; i++) { + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, + ordinary_trgsw[i + P::k * P::lₐ]); + for (int j = 0; j < P::l̅; j++) + trgsw[(i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ] = decomposed[j]; + } + } + else { + // Standard path + for (TRLWE

&trlwe : trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + for (int i = 0; i < P::lₐ; i++) + for (int k_idx = 0; k_idx < P::k; k_idx++) + trgsw[i + k_idx * P::lₐ][k_idx][0] += nonceh[i]; + + for (int i = 0; i < P::l; i++) + trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + } +} + +template +void trgswSymEncryptOne(TRGSW

&trgsw, const double α, const Key

&key) +{ + trgswSymEncryptOneImpl

(trgsw, α, key); +} + +template +void trgswSymEncryptOne(TRGSW

&trgsw, const uint η, const Key

&key) +{ + trgswSymEncryptOneImpl

(trgsw, η, key); +} + +template +void trgswSymEncryptOne(TRGSW

&trgsw, const Key

&key) +{ + if constexpr (P::errordist == ErrorDistribution::ModularGaussian) + trgswSymEncryptOne

(trgsw, P::α, key); + else + trgswSymEncryptOne

(trgsw, P::η, key); +} + +template +void trgswSymEncryptImpl(TRGSW

&trgsw, const Polynomial

&p, + const NoiseType noise, const Key

&key) +{ + constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); + + if constexpr (P::l̅ > 1 || P::l̅ₐ > 1) { + // Double Decomposition path: + // Step 1: Create ordinary TRGSW with k*lₐ + l rows + constexpr int ordinary_rows = P::k * P::lₐ + P::l; + std::array, ordinary_rows> ordinary_trgsw; + for (auto &trlwe : ordinary_trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Step 2: Add gadget values to create ordinary TRGSW + // Nonce part + for (int i = 0; i < P::lₐ; i++) { + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int n = 0; n < P::n; n++) { + ordinary_trgsw[i + k_idx * P::lₐ][k_idx][n] += + static_cast(p[n]) * nonceh[i]; + } + } + } + // Main part + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + ordinary_trgsw[i + P::k * P::lₐ][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + + // Step 3: Apply DD to each ordinary row + // Nonce part: each of the k*lₐ rows expands to l̅ₐ rows + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int i = 0; i < P::lₐ; i++) { + std::array, P::l̅ₐ> decomposed; + TRLWEBaseBbarDecomposeNonce

(decomposed, + ordinary_trgsw[i + k_idx * P::lₐ]); + for (int j = 0; j < P::l̅ₐ; j++) { + trgsw[(i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ] = + decomposed[j]; + } + } + } + // Main part: each of the l rows expands to l̅ rows + for (int i = 0; i < P::l; i++) { + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, + ordinary_trgsw[i + P::k * P::lₐ]); + for (int j = 0; j < P::l̅; j++) { + trgsw[(i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ] = decomposed[j]; + } + } + } + else { + // Standard path (no DD): encrypt and add gadget directly + for (TRLWE

&trlwe : trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Nonce part + for (int i = 0; i < P::lₐ; i++) { + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int n = 0; n < P::n; n++) { + trgsw[i + k_idx * P::lₐ][k_idx][n] += + static_cast(p[n]) * nonceh[i]; + } + } + } + // Main part + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + trgsw[i + P::k * P::lₐ][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + } } template void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, const double α, const Key

&key) { - for (TRLWE

&trlwe : trgsw) trlweSymEncryptZero

(trlwe, α, key); - trgswhadd

(trgsw, p); + trgswSymEncryptImpl

(trgsw, p, α, key); } template void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, const uint η, const Key

&key) { - for (TRLWE

&trlwe : trgsw) trlweSymEncryptZero

(trlwe, η, key); - trgswhadd

(trgsw, p); + trgswSymEncryptImpl

(trgsw, p, η, key); } template @@ -521,28 +1091,62 @@ void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, trgswSymEncrypt

(trgsw, p, P::η, key); } +template +void halftrgswSymEncryptImpl(HalfTRGSW

&halftrgsw, const Polynomial

&p, + const NoiseType noise, const Key

&key) +{ + constexpr std::array h = hgen

(); + + if constexpr (P::l̅ > 1) { + // Double Decomposition path: + // Step 1: Create ordinary HalfTRGSW with l rows + std::array, P::l> ordinary_halftrgsw; + for (auto &trlwe : ordinary_halftrgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Step 2: Add gadget values + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + ordinary_halftrgsw[i][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + + // Step 3: Apply DD to each row, expanding l rows to l*l̅ rows + for (int i = 0; i < P::l; i++) { + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, ordinary_halftrgsw[i]); + for (int j = 0; j < P::l̅; j++) { + halftrgsw[i * P::l̅ + j] = decomposed[j]; + } + } + } + else { + // Standard path (no DD) + for (TRLWE

&trlwe : halftrgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + halftrgsw[i][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + } +} + template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const double α, const Key

&key) { - constexpr std::array h = hgen

(); - - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, α, key); - for (int i = 0; i < P::l; i++) - for (int j = 0; j < P::n; j++) - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + halftrgswSymEncryptImpl

(halftrgsw, p, α, key); } template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const uint η, const Key

&key) { - constexpr std::array h = hgen

(); - - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, η, key); - for (int i = 0; i < P::l; i++) - for (int j = 0; j < P::n; j++) - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + halftrgswSymEncryptImpl

(halftrgsw, p, η, key); } template diff --git a/include/trlwe.hpp b/include/trlwe.hpp index db9e1233..3835d9ba 100644 --- a/include/trlwe.hpp +++ b/include/trlwe.hpp @@ -8,11 +8,9 @@ namespace TFHEpp { template void trlweSymEncryptZero(TRLWE

&c, const double α, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); for (typename P::T &i : c[P::k]) i = ModularGaussian

(0, α); for (int k = 0; k < P::k; k++) { - for (typename P::T &i : c[k]) i = Torusdist(generator); + for (typename P::T &i : c[k]) i = UniformTorusRandom

(); std::array partkey; for (int i = 0; i < P::n; i++) partkey[i] = key[k * P::n + i]; Polynomial

temp; @@ -24,12 +22,10 @@ void trlweSymEncryptZero(TRLWE

&c, const double α, const Key

&key) template void trlweSymEncryptZero(TRLWE

&c, const uint η, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); for (typename P::T &i : c[P::k]) i = (CenteredBinomial

(η) << std::numeric_limits

::digits) / P::q; for (int k = 0; k < P::k; k++) { - for (typename P::T &i : c[k]) i = Torusdist(generator); + for (typename P::T &i : c[k]) i = UniformTorusRandom

(); alignas(64) std::array partkey; for (int i = 0; i < P::n; i++) partkey[i] = key[k * P::n + i]; alignas(64) Polynomial

temp; diff --git a/include/utils.hpp b/include/utils.hpp index 9eace5c3..38b0122d 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -29,6 +29,26 @@ static thread_local std::random_device generator; template constexpr bool false_v = false; +// Helper function to generate uniform random Torus values +// For standard types, use uniform_int_distribution +// For __uint128_t, combine two 64-bit random values +template +inline typename P::T UniformTorusRandom() +{ + if constexpr (std::is_same_v) { + std::uniform_int_distribution dist64( + 0, std::numeric_limits::max()); + __uint128_t high = dist64(generator); + __uint128_t low = dist64(generator); + return (high << 64) | low; + } + else { + std::uniform_int_distribution dist( + 0, std::numeric_limits::max()); + return dist(generator); + } +} + // https://qiita.com/negi-drums/items/a527c05050781a5af523 template concept hasq = requires { T::q; }; @@ -140,6 +160,15 @@ inline typename P::T ModularGaussian(typename P::T center, double stdev) const uint64_t ival = static_cast(val); return ival + center; } + else if constexpr (std::is_same_v) { + // 128bit fixed-point number version + // Use two 64-bit Gaussians for high and low parts + static const double _2p64 = std::pow(2., 64); + std::normal_distribution distribution(0., 1.0); + const double val = stdev * distribution(generator) * _2p64; + const __int128_t ival = static_cast<__int128_t>(val); + return static_cast<__uint128_t>(ival) + center; + } else static_assert(false_v, "Undefined Modular Gaussian!"); } diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp new file mode 100644 index 00000000..de7301e7 --- /dev/null +++ b/test/externalproductdoubledecomposition.cpp @@ -0,0 +1,295 @@ +#include +#include +#include +#include + +using namespace std; +using namespace TFHEpp; + +// Custom parameter set for testing standard decomposition (l̅=1) with 64-bit Torus +// This tests that the standard path still works correctly +struct DDTestParamStandard { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static constexpr std::uint32_t nbit = lvl2param::nbit; // Use lvl2param's nbit for FFT compatibility + static constexpr std::uint32_t n = 1 << nbit; + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 4; + static constexpr std::uint32_t l = 4; + static constexpr std::uint32_t Bgbit = 10; + static constexpr std::uint32_t Bgₐbit = 10; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -40); + using T = uint64_t; + static constexpr T μ = 1ULL << 61; + static constexpr uint32_t plain_modulus = 8; + static constexpr double Δ = μ; + // Standard decomposition (l̅=1) + static constexpr std::uint32_t l̅ = 1; + static constexpr std::uint32_t l̅ₐ = 1; + static constexpr std::uint32_t B̅gbit = 10; + static constexpr std::uint32_t B̅gₐbit = 10; +}; + +// Custom parameter set for testing actual Double Decomposition (l̅=2) with 64-bit Torus +// This tests the DD code path where l̅ > 1 +struct DDTestParam { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static constexpr std::uint32_t nbit = lvl2param::nbit; // Use lvl2param's nbit for FFT compatibility + static constexpr std::uint32_t n = 1 << nbit; + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 2; // Reduced to fit within bit budget + static constexpr std::uint32_t l = 2; // Reduced to fit within bit budget + static constexpr std::uint32_t Bgbit = 16; + static constexpr std::uint32_t Bgₐbit = 16; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -40); + using T = uint64_t; + static constexpr T μ = 1ULL << 61; + static constexpr uint32_t plain_modulus = 8; + static constexpr double Δ = μ; + // Actual Double Decomposition: l=2, Bgbit=16, l̅=2, B̅gbit=16 + // Total bits used: 2*16 + 2*16 = 64 (exact fit for 64-bit Torus) + static constexpr std::uint32_t l̅ = 2; + static constexpr std::uint32_t l̅ₐ = 2; + static constexpr std::uint32_t B̅gbit = 16; + static constexpr std::uint32_t B̅gₐbit = 16; +}; + +// Test Double Decomposition correctness +// Verifies that DoubleDecomposition correctly represents the original value +template +void testDoubleDecomposition() +{ + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution dist(0, + numeric_limits::max()); + + constexpr int num_test = 100; + // Tolerance based on decomposition precision + // The lowest shift is: width - l*Bgbit - (l̅-1)*B̅gbit (j goes from 0 to l̅-1) + // So remaining_bits = width - l*Bgbit - (l̅-1)*B̅gbit + constexpr int remaining_bits = std::numeric_limits::digits - + P::l * P::Bgbit - (P::l̅ - 1) * P::B̅gbit; + static_assert(remaining_bits >= 0, + "Invalid double decomposition parameters"); + + for (int test = 0; test < num_test; test++) { + // Create a random polynomial + Polynomial

poly; + for (int i = 0; i < P::n; i++) { + poly[i] = dist(engine); + } + + // Apply double decomposition + DecomposedPolynomialDD

decpoly; + DoubleDecomposition

(decpoly, poly); + + // Reconstruct the polynomial from decomposed components + Polynomial

reconstructed; + for (int n = 0; n < P::n; n++) { + typename P::T sum = 0; + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::l̅; j++) { + // The scaling factor for (i,j) position + // When l̅=1 (j=0 only), this reduces to standard decomposition + typename P::T h_val = + static_cast(1) + << (std::numeric_limits::digits - + (i + 1) * P::Bgbit - j * P::B̅gbit); + sum += static_cast( + static_cast>( + decpoly[i * P::l̅ + j][n])) * + h_val; + } + } + reconstructed[n] = sum; + } + + // Check that reconstruction is close to original + typename P::T max_error = + remaining_bits > 0 + ? (static_cast(1) << remaining_bits) + : static_cast(1); + + for (int n = 0; n < P::n; n++) { + typename P::T diff = poly[n] > reconstructed[n] + ? poly[n] - reconstructed[n] + : reconstructed[n] - poly[n]; + typename P::T max_val = ~static_cast(0); + typename P::T wrap_diff = max_val - diff + 1; + typename P::T min_diff = diff < wrap_diff ? diff : wrap_diff; + + if (min_diff > max_error * 2) { + cerr << "Decomposition error at n=" << n << endl; + cerr << " original=" << poly[n] << endl; + cerr << " reconstructed=" << reconstructed[n] << endl; + cerr << " max_error=" << max_error << endl; + cerr << " actual_diff=" << min_diff << endl; + assert(false); + } + } + } + cout << "DoubleDecomposition test passed" << endl; +} + +// Test External Product with Double Decomposition +template +void testExternalProduct() +{ + constexpr uint32_t num_test = 10; + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution binary(0, 1); + + cout << "Parameters: n=" << P::n << ", k=" << P::k << endl; + cout << "Primary: l=" << P::l << ", Bgbit=" << P::Bgbit << endl; + cout << "Auxiliary: l̅=" << P::l̅ << ", B̅gbit=" << P::B̅gbit << endl; + cout << "TRGSW rows: " << (P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅) << endl; + + // Test with p=1 (identity) + cout << "Testing with plaintext = 1 (identity)..." << endl; + for (int test = 0; test < num_test; test++) { + // Generate a random TRLWE key for this param set + std::array key; + for (int i = 0; i < P::n; i++) { + key[i] = (binary(engine) == 1) + ? static_cast(P::key_value_max) + : static_cast(P::key_value_min); + } + + // Create a random message polynomial + array p; + for (bool &i : p) i = (binary(engine) > 0); + Polynomial

pmu; + for (int i = 0; i < P::n; i++) + pmu[i] = p[i] ? P::μ : -P::μ; + + // Encrypt as TRLWE + TRLWE

c; + trlweSymEncrypt

(c, pmu, key); + + // Create TRGSW encrypting 1 (identity) + const Polynomial

plainpoly = {static_cast(1)}; + + TRGSWFFT

trgswfft; + trgswSymEncrypt

(trgswfft, plainpoly, key); + + // Apply external product (automatically uses DD when P::l̅ > 1) + TRLWE

result; + ExternalProduct

(result, c, trgswfft); + + // Decrypt and verify + const array p2 = trlweSymDecrypt

(result, key); + for (int i = 0; i < P::n; i++) { + if (p[i] != p2[i]) { + cerr << "ExternalProduct (DD) failed at index " << i + << ": expected " << p[i] << " got " << p2[i] << endl; + assert(false); + } + } + } + cout << "ExternalProduct (DD) test passed (p=1)" << endl; + + // Test with p=-1 (negation) + cout << "Testing with plaintext = -1 (negation)..." << endl; + for (int test = 0; test < num_test; test++) { + std::array key; + for (int i = 0; i < P::n; i++) { + key[i] = (binary(engine) == 1) + ? static_cast(P::key_value_max) + : static_cast(P::key_value_min); + } + + array p; + for (bool &i : p) i = (binary(engine) > 0); + Polynomial

pmu; + for (int i = 0; i < P::n; i++) + pmu[i] = p[i] ? P::μ : -P::μ; + + TRLWE

c; + trlweSymEncrypt

(c, pmu, key); + + // TRGSW encrypting -1 (negation) + const Polynomial

plainpoly = {static_cast(-1)}; + + TRGSWFFT

trgswfft; + trgswSymEncrypt

(trgswfft, plainpoly, key); + + TRLWE

result; + ExternalProduct

(result, c, trgswfft); + + const array p2 = trlweSymDecrypt

(result, key); + for (int i = 0; i < P::n; i++) { + if (p[i] != !p2[i]) { + cerr << "ExternalProduct (DD, p=-1) failed at index " << i + << ": expected " << !p[i] << " got " << p2[i] << endl; + assert(false); + } + } + } + cout << "ExternalProduct (DD) test passed (p=-1)" << endl; +} + +int main() +{ + cout << "=== Testing Double Decomposition ===" << endl; + cout << endl; + + // Test 1: Standard decomposition (l̅=1) + cout << "=== Test 1: Standard decomposition (l̅=1) ===" << endl; + cout << "DDTestParamStandard configuration:" << endl; + cout << " n = " << DDTestParamStandard::n << ", k = " << DDTestParamStandard::k << endl; + cout << " Primary: l = " << DDTestParamStandard::l + << ", Bgbit = " << DDTestParamStandard::Bgbit << endl; + cout << " Auxiliary: l̅ = " << DDTestParamStandard::l̅ + << ", B̅gbit = " << DDTestParamStandard::B̅gbit + << " (standard path)" << endl; + cout << " TRGSW rows: " + << (DDTestParamStandard::k * DDTestParamStandard::lₐ * DDTestParamStandard::l̅ₐ + + DDTestParamStandard::l * DDTestParamStandard::l̅) + << endl; + cout << endl; + + cout << "--- Testing ExternalProduct (standard) ---" << endl; + testExternalProduct(); + cout << endl; + + // Test 2: Actual Double Decomposition (l̅=2) + cout << "=== Test 2: Double Decomposition (l̅=2) ===" << endl; + cout << "DDTestParam configuration:" << endl; + cout << " n = " << DDTestParam::n << ", k = " << DDTestParam::k << endl; + cout << " Primary: l = " << DDTestParam::l + << ", Bgbit = " << DDTestParam::Bgbit << endl; + cout << " Auxiliary: l̅ = " << DDTestParam::l̅ + << ", B̅gbit = " << DDTestParam::B̅gbit + << " (DD path)" << endl; + cout << " Total bits: " + << (DDTestParam::l * DDTestParam::Bgbit + + DDTestParam::l̅ * DDTestParam::B̅gbit) + << " / 64" << endl; + cout << " TRGSW rows: " + << (DDTestParam::k * DDTestParam::lₐ * DDTestParam::l̅ₐ + + DDTestParam::l * DDTestParam::l̅) + << endl; + cout << endl; + + // Note: Skip testDoubleDecomposition for l̅>1 as that tests the old + // bivariate polynomial decomposition. The new DD algorithm uses + // TRLWEBaseBbarDecompose which is tested implicitly via ExternalProduct. + + cout << "--- Testing ExternalProduct (DD) ---" << endl; + testExternalProduct(); + cout << endl; + + cout << "=== All Double Decomposition tests passed ===" << endl; + return 0; +} diff --git a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp new file mode 100644 index 00000000..30cf984c --- /dev/null +++ b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp @@ -0,0 +1,108 @@ +#include +#include +#include +#include +#include + +int main() +{ +#if defined(USE_CONCRETE) || defined(USE_CONCRETE_FFT) + // Skip this test for CONCRETE builds - lvl3param uses nbit=13 which exceeds + // the FFT table sizes initialized for CONCRETE parameters + std::cout << "Skipping DD gate bootstrapping test for CONCRETE build" << std::endl; + return 0; +#else + using namespace std; + using namespace TFHEpp; + + constexpr uint32_t num_test = 10; // Reduced for faster testing + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution binary(0, 1); + + // Use lvl03param which bootstraps from lvl0param to lvl3param + // lvl3param has nbit=12 (n=4096) and non-trivial DD: l=2, l̅=2, Bgbit=16, B̅gbit=16 + using bkP = lvl03param; + + cout << "=== Testing GateBootstrappingTLWE2TLWE with DD (lvl3param, nbit=12) ===" << endl; + cout << "Domain: n=" << bkP::domainP::n << ", T=" << sizeof(typename bkP::domainP::T) * 8 << "-bit" << endl; + cout << "Target: n=" << bkP::targetP::n << ", nbit=" << bkP::targetP::nbit << ", T=" << sizeof(typename bkP::targetP::T) * 8 << "-bit" << endl; + cout << "Primary decomposition: l=" << bkP::targetP::l << ", Bgbit=" << bkP::targetP::Bgbit << endl; + cout << "Auxiliary decomposition: l̅=" << bkP::targetP::l̅ << ", B̅gbit=" << bkP::targetP::B̅gbit << endl; + cout << "Total decomposition levels: " << (bkP::targetP::l * bkP::targetP::l̅) << endl; + cout << "Bits used: " << (bkP::targetP::l * bkP::targetP::Bgbit + bkP::targetP::l̅ * bkP::targetP::B̅gbit) << " / " << std::numeric_limits::digits << endl; + cout << endl; + + // Generate keys + cout << "Generating secret keys..." << endl; + array domainKey; + for (int i = 0; i < bkP::domainP::n; i++) { + domainKey[i] = binary(engine); + } + + array targetKey; + for (int i = 0; i < bkP::targetP::n; i++) { + targetKey[i] = (binary(engine) == 1) + ? static_cast(bkP::targetP::key_value_max) + : static_cast(bkP::targetP::key_value_min); + } + + // Generate bootstrapping key + cout << "Generating bootstrapping key (this may take a while for n=4096)..." << endl; + auto bkfft = make_unique>(); + bkfftgen(*bkfft, domainKey, targetKey); + + // Test arrays + array, num_test> tlwe; + array, num_test> bootedtlwe; + array p; + + // Encrypt test values + cout << "Encrypting test values..." << endl; + for (int i = 0; i < num_test; i++) { + p[i] = binary(engine) > 0; + tlweSymEncrypt( + tlwe[i], p[i] ? bkP::domainP::μ : -bkP::domainP::μ, bkP::domainP::α, + domainKey); + } + + // Perform gate bootstrapping (automatically uses DD when l̅ > 1) + cout << "Running GateBootstrappingTLWE2TLWE (with DD)..." << endl; + chrono::system_clock::time_point start, end; + start = chrono::system_clock::now(); + + for (int test = 0; test < num_test; test++) { + GateBootstrappingTLWE2TLWE( + bootedtlwe[test], tlwe[test], *bkfft, + μpolygen()); + } + + end = chrono::system_clock::now(); + + // Verify results + cout << "Verifying results..." << endl; + int errors = 0; + for (int i = 0; i < num_test; i++) { + bool p2 = tlweSymDecrypt(bootedtlwe[i], targetKey); + if (p[i] != p2) { + cerr << "Error at index " << i << ": expected " << p[i] << " got " << p2 << endl; + errors++; + } + } + + if (errors == 0) { + cout << "All tests passed!" << endl; + } + else { + cerr << errors << " out of " << num_test << " tests failed!" << endl; + return 1; + } + + double elapsed = + chrono::duration_cast(end - start).count(); + cout << "Average time per bootstrapping: " << elapsed / num_test << "ms" << endl; + + cout << "=== GateBootstrappingTLWE2TLWE (DD) test completed ===" << endl; + return 0; +#endif +} diff --git a/thirdparties/spqlios/fft_processor_spqlios.cpp b/thirdparties/spqlios/fft_processor_spqlios.cpp index 31b0f340..c04d4b63 100644 --- a/thirdparties/spqlios/fft_processor_spqlios.cpp +++ b/thirdparties/spqlios/fft_processor_spqlios.cpp @@ -376,3 +376,4 @@ FFT_Processor_Spqlios::~FFT_Processor_Spqlios() { thread_local FFT_Processor_Spqlios fftplvl1(TFHEpp::lvl1param::n); thread_local FFT_Processor_Spqlios fftplvl2(TFHEpp::lvl2param::n); +thread_local FFT_Processor_Spqlios fftplvl3(TFHEpp::lvl3param::n); diff --git a/thirdparties/spqlios/fft_processor_spqlios.h b/thirdparties/spqlios/fft_processor_spqlios.h index ae21e42b..c0711619 100644 --- a/thirdparties/spqlios/fft_processor_spqlios.h +++ b/thirdparties/spqlios/fft_processor_spqlios.h @@ -54,4 +54,5 @@ class FFT_Processor_Spqlios { }; extern thread_local FFT_Processor_Spqlios fftplvl1; -extern thread_local FFT_Processor_Spqlios fftplvl2; \ No newline at end of file +extern thread_local FFT_Processor_Spqlios fftplvl2; +extern thread_local FFT_Processor_Spqlios fftplvl3; \ No newline at end of file