From 894ed4ca64f48dc271598ad28dee44b70b8fc887 Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Wed, 31 Dec 2025 06:26:39 +0000 Subject: [PATCH 1/9] Add Double Decomposition (bivariate representation) infrastructure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement foundational support for the Double Decomposition technique from "Revisiting Key Decomposition Techniques for FHE" (ePrint 2023/771). Changes: - Add auxiliary decomposition parameters (l̅, l̅ₐ, B̅gbit, B̅gₐbit) to all parameter structs with trivial default values (l̅=1, B̅g=2^digits) - Update TRGSW type definitions to use k*lₐ*l̅ₐ + l*l̅ row structure - Add h̅gen() and nonceh̅gen() for auxiliary h value generation - Modify trgswhadd, halftrgswhadd, trgswhoneadd for double decomposition - Update ApplyFFT2trgsw, ApplyNTT2trgsw, ApplyRAINTT2trgsw loop bounds With trivial values, behavior is unchanged (h̅[0]=1, sizes identical). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/params.hpp | 12 +++--- include/params/128bit.hpp | 31 +++++++++++++ include/params/CGGI16.hpp | 21 +++++++++ include/params/CGGI19.hpp | 21 +++++++++ include/params/compress.hpp | 21 +++++++++ include/params/concrete.hpp | 33 ++++++++++++++ include/params/ternary.hpp | 21 +++++++++ include/params/tfhe-rs.hpp | 21 +++++++++ include/trgsw.hpp | 86 +++++++++++++++++++++++++------------ 9 files changed, 234 insertions(+), 33 deletions(-) diff --git a/include/params.hpp b/include/params.hpp index 56afe2d7..bcab96ca 100644 --- a/include/params.hpp +++ b/include/params.hpp @@ -138,17 +138,17 @@ template using TRLWERAINTT = std::array, P::k + 1>; template -using TRGSW = std::array, P::k * P::lₐ + P::l>; +using TRGSW = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using HalfTRGSW = std::array, P::l>; +using HalfTRGSW = std::array, P::l * P::l̅>; template -using TRGSWFFT = aligned_array, P::k * P::lₐ + P::l>; +using TRGSWFFT = aligned_array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using HalfTRGSWFFT = aligned_array, P::l>; +using HalfTRGSWFFT = aligned_array, P::l * P::l̅>; template -using TRGSWNTT = std::array, P::k * P::lₐ + P::l>; +using TRGSWNTT = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; template -using TRGSWRAINTT = std::array, P::k * P::lₐ + P::l>; +using TRGSWRAINTT = std::array, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>; #ifdef USE_KEY_BUNDLE template diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index 19ffe8a8..1059c4fe 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -62,6 +62,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct AHlvl1param { @@ -83,6 +90,11 @@ struct AHlvl1param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; struct lvl2param { @@ -106,6 +118,13 @@ struct lvl2param { static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = static_cast(1ULL << (std::numeric_limits::digits - 4)); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct AHlvl2param { @@ -127,6 +146,11 @@ struct AHlvl2param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; struct lvl3param { @@ -150,6 +174,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters diff --git a/include/params/CGGI16.hpp b/include/params/CGGI16.hpp index c7936c06..d3601af5 100644 --- a/include/params/CGGI16.hpp +++ b/include/params/CGGI16.hpp @@ -59,6 +59,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -80,6 +87,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -104,6 +118,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl10param { diff --git a/include/params/CGGI19.hpp b/include/params/CGGI19.hpp index 623282d9..b5441e7a 100644 --- a/include/params/CGGI19.hpp +++ b/include/params/CGGI19.hpp @@ -59,6 +59,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -78,6 +85,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -102,6 +116,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy diff --git a/include/params/compress.hpp b/include/params/compress.hpp index d836ab0f..31415232 100644 --- a/include/params/compress.hpp +++ b/include/params/compress.hpp @@ -69,6 +69,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -93,6 +100,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = q / 8; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -117,6 +131,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters diff --git a/include/params/concrete.hpp b/include/params/concrete.hpp index 39c28f3c..d527cd34 100644 --- a/include/params/concrete.hpp +++ b/include/params/concrete.hpp @@ -69,6 +69,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -94,6 +101,13 @@ struct lvl2param { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; #define USE_DIFFERENT_BR_PARAM @@ -118,6 +132,13 @@ struct cblvl2param { static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = static_cast(1ULL << (std::numeric_limits::digits - 4)); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -143,6 +164,11 @@ struct cbAHlvl2param { static constexpr std::make_signed_t μ = baseP::μ; static constexpr uint32_t plain_modulus = baseP::plain_modulus; static constexpr double Δ = baseP::Δ; + // Double Decomposition parameters inherited from baseP + static constexpr std::uint32_t l̅ = baseP::l̅; + static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ; + static constexpr std::uint32_t B̅gbit = baseP::B̅gbit; + static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; struct lvl3param { @@ -166,6 +192,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters diff --git a/include/params/ternary.hpp b/include/params/ternary.hpp index c976747f..c68022eb 100644 --- a/include/params/ternary.hpp +++ b/include/params/ternary.hpp @@ -63,6 +63,13 @@ struct lvl1param { static constexpr double Δ = static_cast(1ULL << std::numeric_limits::digits) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -85,6 +92,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = 1LL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl3param { @@ -108,6 +122,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters diff --git a/include/params/tfhe-rs.hpp b/include/params/tfhe-rs.hpp index 09709e87..98357103 100644 --- a/include/params/tfhe-rs.hpp +++ b/include/params/tfhe-rs.hpp @@ -70,6 +70,13 @@ struct lvl1param { static constexpr double Δ = 2 * static_cast(1ULL << (std::numeric_limits::digits - 1)) / plain_modulus; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; struct lvl2param { @@ -92,6 +99,13 @@ struct lvl2param { static constexpr std::make_signed_t μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Dummy @@ -114,6 +128,13 @@ struct lvl3param { static constexpr uint32_t plain_modulusbit = 31; static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + // Double Decomposition (bivariate representation) parameters + // For now, set to trivial values (no actual second decomposition) + static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = l̅; + static constexpr std::uint32_t B̅gbit = + std::numeric_limits::digits; // full coefficient width + static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; // Key Switching parameters diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 429d7a4d..7ba87055 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -353,7 +353,7 @@ template TRGSWFFT

ApplyFFT2trgsw(const TRGSW

&trgsw) { alignas(64) TRGSWFFT

trgswfft; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(trgswfft[i][j], trgsw[i][j]); return trgswfft; @@ -362,7 +362,7 @@ TRGSWFFT

ApplyFFT2trgsw(const TRGSW

&trgsw) template void ApplyFFT2trgsw(TRGSWFFT

&trgswfft, const TRGSW

&trgsw) { - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(trgswfft[i][j], trgsw[i][j]); } @@ -371,7 +371,7 @@ template HalfTRGSWFFT

ApplyFFT2halftrgsw(const HalfTRGSW

&trgsw) { alignas(64) HalfTRGSWFFT

halftrgswfft; - for (int i = 0; i < P::l; i++) + for (int i = 0; i < P::l * P::l̅; i++) for (int j = 0; j < (P::k + 1); j++) TwistIFFT

(halftrgswfft[i][j], trgsw[i][j]); return halftrgswfft; @@ -381,7 +381,7 @@ template TRGSWNTT

ApplyNTT2trgsw(const TRGSW

&trgsw) { TRGSWNTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) TwistINTT

(trgswntt[i][j], trgsw[i][j]); return trgswntt; @@ -392,7 +392,7 @@ TRGSWRAINTT

ApplyRAINTT2trgsw(const TRGSW

&trgsw) { constexpr uint8_t remainder = ((P::nbit - 1) % 3) + 1; TRGSWRAINTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) { raintt::TwistINTT( trgswntt[i][j], trgsw[i][j], (*raintttable)[1], @@ -412,7 +412,7 @@ template TRGSWNTT

TRGSW2NTT(const TRGSW

&trgsw) { TRGSWNTT

trgswntt; - for (int i = 0; i < P::k * P::lₐ + P::l; i++) + for (int i = 0; i < P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅; i++) for (int j = 0; j < P::k + 1; j++) { PolynomialNTT

temp; TwistINTT

(temp, trgsw[i][j]); @@ -452,13 +452,38 @@ constexpr std::array noncehgen() return h; } +// Auxiliary h generation for Double Decomposition (bivariate representation) +template +constexpr std::array h̅gen() +{ + std::array h̅{}; + for (int i = 0; i < P::l̅; i++) + h̅[i] = 1ULL << (std::numeric_limits::digits - + (i + 1) * P::B̅gbit); + return h̅; +} + +template +constexpr std::array nonceh̅gen() +{ + std::array h̅{}; + for (int i = 0; i < P::l̅ₐ; i++) + h̅[i] = 1ULL << (std::numeric_limits::digits - + (i + 1) * P::B̅gₐbit); + return h̅; +} + template inline void halftrgswhadd(HalfTRGSW

&halftrgsw, const Polynomial

&p) { constexpr std::array h = hgen

(); + constexpr std::array h̅ = h̅gen

(); for (int i = 0; i < P::l; i++) { - for (int j = 0; j < P::n; j++) { - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + for (int ī = 0; ī < P::l̅; ī++) { + for (int j = 0; j < P::n; j++) { + halftrgsw[i * P::l̅ + ī][P::k][j] += + static_cast(p[j]) * h[i] * h̅[ī]; + } } } } @@ -467,19 +492,26 @@ template inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) { constexpr std::array nonceh = noncehgen

(); + constexpr std::array nonceh̅ = nonceh̅gen

(); for (int i = 0; i < P::lₐ; i++) { - for (int k = 0; k < P::k; k++) { - for (int j = 0; j < P::n; j++) { - trgsw[i + k * P::lₐ][k][j] += - static_cast(p[j]) * nonceh[i]; + for (int ī = 0; ī < P::l̅ₐ; ī++) { + for (int k = 0; k < P::k; k++) { + for (int j = 0; j < P::n; j++) { + trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][j] += + static_cast(p[j]) * nonceh[i] * + nonceh̅[ī]; + } } } } constexpr std::array h = hgen

(); + constexpr std::array h̅ = h̅gen

(); for (int i = 0; i < P::l; i++) { - for (int j = 0; j < P::n; j++) { - trgsw[i + P::k * P::lₐ][P::k][j] += - static_cast(p[j]) * h[i]; + for (int ī = 0; ī < P::l̅; ī++) { + for (int j = 0; j < P::n; j++) { + trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][j] += + static_cast(p[j]) * h[i] * h̅[ī]; + } } } } @@ -488,11 +520,19 @@ template inline void trgswhoneadd(TRGSW

&trgsw) { constexpr std::array nonceh = noncehgen

(); + constexpr std::array nonceh̅ = nonceh̅gen

(); for (int i = 0; i < P::lₐ; i++) - for (int k = 0; k < P::k; k++) trgsw[i + k * P::lₐ][k][0] += nonceh[i]; + for (int ī = 0; ī < P::l̅ₐ; ī++) + for (int k = 0; k < P::k; k++) + trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][0] += + nonceh[i] * nonceh̅[ī]; constexpr std::array h = hgen

(); - for (int i = 0; i < P::l; i++) trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + constexpr std::array h̅ = h̅gen

(); + for (int i = 0; i < P::l; i++) + for (int ī = 0; ī < P::l̅; ī++) + trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][0] += + h[i] * h̅[ī]; } template @@ -525,24 +565,16 @@ template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const double α, const Key

&key) { - constexpr std::array h = hgen

(); - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, α, key); - for (int i = 0; i < P::l; i++) - for (int j = 0; j < P::n; j++) - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + halftrgswhadd

(halftrgsw, p); } template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const uint η, const Key

&key) { - constexpr std::array h = hgen

(); - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, η, key); - for (int i = 0; i < P::l; i++) - for (int j = 0; j < P::n; j++) - halftrgsw[i][P::k][j] += static_cast(p[j]) * h[i]; + halftrgswhadd

(halftrgsw, p); } template From 2a861c1a3b93a563e6b9a471b7da435ce841d622 Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Wed, 31 Dec 2025 12:08:42 +0000 Subject: [PATCH 2/9] =?UTF-8?q?Fix=20h=CC=85gen=20shift=20formula=20bug=20?= =?UTF-8?q?in=20Double=20Decomposition?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The h̅gen and nonceh̅gen functions had an incorrect shift formula that used (i+1)*B̅gbit instead of i*B̅gbit. This caused ExternalProductDD to fail with ~50% error rate because the gadget values h[i]*h̅[j] were off by a factor of 2^B̅gbit. The correct formula is: - h̅[0] = 1 (j=0 means no auxiliary shift) - h̅[j] = 2^(width - j*B̅gbit) for j > 0 This matches the decomposition shift formula: width - (i+1)*Bgbit - j*B̅gbit When l̅=1 (trivial auxiliary decomposition), h̅[0]=1 correctly reduces double decomposition to standard decomposition. Also adds externalproductdoubledecomposition test to verify correctness. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/params.hpp | 4 + include/params/128bit.hpp | 45 +++- include/params/CGGI16.hpp | 8 +- include/params/CGGI19.hpp | 8 +- include/params/compress.hpp | 8 +- include/params/concrete.hpp | 8 +- include/params/ternary.hpp | 8 +- include/params/tfhe-rs.hpp | 8 +- include/trgsw.hpp | 157 +++++++++++- test/externalproductdoubledecomposition.cpp | 258 ++++++++++++++++++++ 10 files changed, 501 insertions(+), 11 deletions(-) create mode 100644 test/externalproductdoubledecomposition.cpp diff --git a/include/params.hpp b/include/params.hpp index bcab96ca..58e99a6e 100644 --- a/include/params.hpp +++ b/include/params.hpp @@ -118,6 +118,10 @@ using DecomposedPolynomial = std::array, P::l>; template using DecomposedNoncePolynomial = std::array, P::lₐ>; template +using DecomposedPolynomialDD = std::array, P::l * P::l̅>; +template +using DecomposedNoncePolynomialDD = std::array, P::lₐ * P::l̅ₐ>; +template using DecomposedPolynomialNTT = std::array, P::l>; template using DecomposedNoncePolynomialNTT = std::array, P::lₐ>; diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index 1059c4fe..33057dcd 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -153,7 +153,41 @@ struct AHlvl2param { static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; +// New lvl3param with 128-bit Torus and non-trivial Double Decomposition +// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128 +// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized) struct lvl3param { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static const std::uint32_t nbit = 12; // dimension must be a power of 2 for + // ease of polynomial multiplication. + static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096 + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 4; + static constexpr std::uint32_t l = 4; + static constexpr std::uint32_t Bgbit = 16; + static constexpr std::uint32_t Bgₐbit = 16; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -105); // fresh noise + using T = __uint128_t; // Torus representation + static constexpr T μ = static_cast(1) << 125; + static constexpr uint32_t plain_modulusbit = 31; + static constexpr __uint128_t plain_modulus = static_cast(1) << plain_modulusbit; + static constexpr double Δ = + static_cast(static_cast(1) << (128 - plain_modulusbit - 1)); + // Double Decomposition (bivariate representation) parameters + // Non-trivial values for testing actual double decomposition + // Constraint: l * Bgbit + l̅ * B̅gbit <= 128 + static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = 4; + static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary + static constexpr std::uint32_t B̅gₐbit = 16; +}; + +struct lvl4param { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; static const std::uint32_t nbit = 13; // dimension must be a power of 2 for @@ -175,7 +209,7 @@ struct lvl3param { static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); // Double Decomposition (bivariate representation) parameters - // For now, set to trivial values (no actual second decomposition) + // Trivial values (no actual second decomposition) static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels static constexpr std::uint32_t l̅ₐ = l̅; static constexpr std::uint32_t B̅gbit = @@ -270,3 +304,12 @@ struct lvl31param { using domainP = lvl3param; using targetP = lvl1param; }; + +struct lvl41param { + static constexpr std::uint32_t t = 7; // number of addition in keyswitching + static constexpr std::uint32_t basebit = + 2; // how many bit should be encrypted in keyswitching key + static const inline double α = lvl1param::α; // key noise + using domainP = lvl4param; + using targetP = lvl1param; +}; diff --git a/include/params/CGGI16.hpp b/include/params/CGGI16.hpp index d3601af5..c691a95b 100644 --- a/include/params/CGGI16.hpp +++ b/include/params/CGGI16.hpp @@ -127,6 +127,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + struct lvl10param { static constexpr std::uint32_t t = 8; static constexpr std::uint32_t basebit = 2; @@ -212,4 +215,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/CGGI19.hpp b/include/params/CGGI19.hpp index b5441e7a..dacf1d41 100644 --- a/include/params/CGGI19.hpp +++ b/include/params/CGGI19.hpp @@ -125,6 +125,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Dummy struct lvl11param { static constexpr std::uint32_t t = 0; // number of addition in keyswitching @@ -211,4 +214,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/compress.hpp b/include/params/compress.hpp index 31415232..633889d7 100644 --- a/include/params/compress.hpp +++ b/include/params/compress.hpp @@ -140,6 +140,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 5; // number of addition in keyswitching @@ -213,4 +216,7 @@ struct lvl31param { 2; // how many bit should be encrypted in keyswitching key using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/concrete.hpp b/include/params/concrete.hpp index d527cd34..99d65313 100644 --- a/include/params/concrete.hpp +++ b/include/params/concrete.hpp @@ -201,6 +201,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 5; // number of addition in keyswitching @@ -290,4 +293,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/ternary.hpp b/include/params/ternary.hpp index c68022eb..8adbfac8 100644 --- a/include/params/ternary.hpp +++ b/include/params/ternary.hpp @@ -131,6 +131,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 7; // number of addition in keyswitching @@ -219,4 +222,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/params/tfhe-rs.hpp b/include/params/tfhe-rs.hpp index 98357103..cd5f9229 100644 --- a/include/params/tfhe-rs.hpp +++ b/include/params/tfhe-rs.hpp @@ -137,6 +137,9 @@ struct lvl3param { static constexpr std::uint32_t B̅gₐbit = B̅gbit; }; +// Dummy +using lvl4param = lvl3param; + // Key Switching parameters struct lvl10param { static constexpr std::uint32_t t = 3; // number of addition in keyswitching @@ -225,4 +228,7 @@ struct lvl31param { static const inline double α = lvl1param::α; // key noise using domainP = lvl3param; using targetP = lvl1param; -}; \ No newline at end of file +}; + +// Dummy +using lvl41param = lvl31param; \ No newline at end of file diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 7ba87055..01c690c7 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -117,6 +117,102 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, } } +// Double Decomposition (bivariate representation) for external product +// Decomposes each coefficient a into l*l̅ components such that: +// a ≈ Σᵢ Σⱼ aᵢⱼ * Bg^(l-i) * B̅g^(l̅-j) +// When l̅=1 (j=0 only), this reduces to standard decomposition. +template +constexpr typename P::T ddoffsetgen() +{ + typename P::T offset = 0; + for (int i = 1; i <= P::l; i++) + for (int j = 0; j < P::l̅; j++) + offset += (static_cast(P::Bg) / 2) * + (static_cast(1) + << (std::numeric_limits::digits - + i * P::Bgbit - j * P::B̅gbit)); + return offset; +} + +template +inline void DoubleDecomposition(DecomposedPolynomialDD

&decpoly, + const Polynomial

&poly) +{ + constexpr typename P::T offset = ddoffsetgen

(); + // Remaining bits after decomposition + constexpr int remaining_bits = std::numeric_limits::digits - + P::l * P::Bgbit - P::l̅ * P::B̅gbit; + // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1) + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + constexpr typename P::T maskBg = + static_cast((1ULL << P::Bgbit) - 1); + constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1)); + + for (int n = 0; n < P::n; n++) { + typename P::T a = poly[n] + offset + roundoffset; + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::l̅; j++) { + // Shift to get the (i,j)-th digit in base Bg (after B̅g grouping) + // When l̅=1 (j=0 only), this reduces to standard decomposition + const int shift = std::numeric_limits::digits - + (i + 1) * P::Bgbit - j * P::B̅gbit; + decpoly[i * P::l̅ + j][n] = + static_cast>( + ((a >> shift) & maskBg) - halfBg); + } + } + } +} + +template +constexpr typename P::T nonceddoffsetgen() +{ + typename P::T offset = 0; + for (int i = 1; i <= P::lₐ; i++) + for (int j = 0; j < P::l̅ₐ; j++) + offset += (static_cast(P::Bgₐ) / 2) * + (static_cast(1) + << (std::numeric_limits::digits - + i * P::Bgₐbit - j * P::B̅gₐbit)); + return offset; +} + +template +inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, + const Polynomial

&poly) +{ + constexpr typename P::T offset = nonceddoffsetgen

(); + // Remaining bits after decomposition + constexpr int remaining_bits = std::numeric_limits::digits - + P::lₐ * P::Bgₐbit - P::l̅ₐ * P::B̅gₐbit; + // roundoffset is 0 if no remaining bits, otherwise 2^(remaining_bits-1) + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + constexpr typename P::T maskBg = + static_cast((1ULL << P::Bgₐbit) - 1); + constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1)); + + for (int n = 0; n < P::n; n++) { + typename P::T a = poly[n] + offset + roundoffset; + for (int i = 0; i < P::lₐ; i++) { + for (int j = 0; j < P::l̅ₐ; j++) { + // Shift to get the (i,j)-th digit + // When l̅ₐ=1 (j=0 only), this reduces to standard decomposition + const int shift = std::numeric_limits::digits - + (i + 1) * P::Bgₐbit - j * P::B̅gₐbit; + decpoly[i * P::l̅ₐ + j][n] = + static_cast>( + ((a >> shift) & maskBg) - halfBg); + } + } + } +} + template void Decomposition(DecomposedPolynomialNTT

&decpolyntt, const Polynomial

&poly) @@ -195,6 +291,53 @@ void ExternalProduct(TRLWE

&res, const TRLWE

&trlwe, for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } +// External product with Double Decomposition (bivariate representation) +// Uses the full TRGSW structure with l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows +// for "a" blocks +template +void ExternalProductDD(TRLWE

&res, const TRLWE

&trlwe, + const TRGSWFFT

&trgswfft) +{ + alignas(64) PolynomialInFD

decpolyfft; + alignas(64) TRLWEInFD

restrlwefft; + + // Handle "a" polynomials (indices 0 to k-1 in TRLWE) + // Uses NonceDoubleDecomposition with lₐ*l̅ₐ levels + { + alignas(64) DecomposedNoncePolynomialDD

decpoly; + NonceDoubleDecomposition

(decpoly, trlwe[0]); + TwistIFFT

(decpolyfft, decpoly[0]); + for (int m = 0; m < P::k + 1; m++) + MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ * P::l̅ₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + } + for (int k = 1; k < P::k; k++) { + NonceDoubleDecomposition

(decpoly, trlwe[k]); + for (int i = 0; i < P::lₐ * P::l̅ₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + k * P::lₐ * P::l̅ₐ][m]); + } + } + } + + // Handle "b" polynomial (index k in TRLWE) + // Uses DoubleDecomposition with l*l̅ levels + alignas(64) DecomposedPolynomialDD

decpoly; + DoubleDecomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l * P::l̅; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]); + } + for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); +} + template void ExternalProduct(TRLWE

&res, const Polynomial

&poly, const HalfTRGSWFFT

&halftrgswfft) @@ -453,23 +596,29 @@ constexpr std::array noncehgen() } // Auxiliary h generation for Double Decomposition (bivariate representation) +// h̅[j] values are used to construct gadget values h[i] * h̅[j] = 2^(width - (i+1)*Bgbit - j*B̅gbit) +// For j=0: no auxiliary shift, so h̅[0] = 1 +// For j>0: h̅[j] = 2^(width - j*B̅gbit) which combines with h[i] via modular multiplication template constexpr std::array h̅gen() { std::array h̅{}; - for (int i = 0; i < P::l̅; i++) + h̅[0] = 1; // j=0 means no auxiliary shift + for (int i = 1; i < P::l̅; i++) h̅[i] = 1ULL << (std::numeric_limits::digits - - (i + 1) * P::B̅gbit); + i * P::B̅gbit); return h̅; } +// Auxiliary h generation for nonce part of TRGSW with Double Decomposition template constexpr std::array nonceh̅gen() { std::array h̅{}; - for (int i = 0; i < P::l̅ₐ; i++) + h̅[0] = 1; // j=0 means no auxiliary shift + for (int i = 1; i < P::l̅ₐ; i++) h̅[i] = 1ULL << (std::numeric_limits::digits - - (i + 1) * P::B̅gₐbit); + i * P::B̅gₐbit); return h̅; } diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp new file mode 100644 index 00000000..accc04bb --- /dev/null +++ b/test/externalproductdoubledecomposition.cpp @@ -0,0 +1,258 @@ +#include +#include +#include +#include + +using namespace std; +using namespace TFHEpp; + +// Custom parameter set for testing double decomposition with 64-bit Torus +// Using trivial double decomposition (l̅=1) which reduces to standard decomposition +// This verifies the code path works correctly +struct DDTestParam { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static constexpr std::uint32_t nbit = 10; + static constexpr std::uint32_t n = 1 << nbit; // 1024 + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 4; + static constexpr std::uint32_t l = 4; + static constexpr std::uint32_t Bgbit = 10; + static constexpr std::uint32_t Bgₐbit = 10; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -40); + using T = uint64_t; + static constexpr T μ = 1ULL << 61; + static constexpr uint32_t plain_modulus = 8; + static constexpr double Δ = μ; + // Trivial Double Decomposition (l̅=1 reduces to standard decomposition) + // For l̅=1 to correctly reduce to standard decomposition, B̅gbit must equal Bgbit + // With l=4, Bgbit=10: uses 40 bits + // l̅=1, B̅gbit=10: adds 10 bits (but with j=0 only, no additional bits used) + // The shift formula: width - (i+1)*Bgbit - j*B̅gbit = width - (i+1)*Bgbit when j=0 + static constexpr std::uint32_t l̅ = 1; + static constexpr std::uint32_t l̅ₐ = 1; + static constexpr std::uint32_t B̅gbit = 10; + static constexpr std::uint32_t B̅gₐbit = 10; +}; + +// Test Double Decomposition correctness +// Verifies that DoubleDecomposition correctly represents the original value +template +void testDoubleDecomposition() +{ + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution dist(0, + numeric_limits::max()); + + constexpr int num_test = 100; + // Tolerance based on decomposition precision + // The lowest shift is: width - l*Bgbit - (l̅-1)*B̅gbit (j goes from 0 to l̅-1) + // So remaining_bits = width - l*Bgbit - (l̅-1)*B̅gbit + constexpr int remaining_bits = std::numeric_limits::digits - + P::l * P::Bgbit - (P::l̅ - 1) * P::B̅gbit; + static_assert(remaining_bits >= 0, + "Invalid double decomposition parameters"); + + for (int test = 0; test < num_test; test++) { + // Create a random polynomial + Polynomial

poly; + for (int i = 0; i < P::n; i++) { + poly[i] = dist(engine); + } + + // Apply double decomposition + DecomposedPolynomialDD

decpoly; + DoubleDecomposition

(decpoly, poly); + + // Reconstruct the polynomial from decomposed components + Polynomial

reconstructed; + for (int n = 0; n < P::n; n++) { + typename P::T sum = 0; + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::l̅; j++) { + // The scaling factor for (i,j) position + // When l̅=1 (j=0 only), this reduces to standard decomposition + typename P::T h_val = + static_cast(1) + << (std::numeric_limits::digits - + (i + 1) * P::Bgbit - j * P::B̅gbit); + sum += static_cast( + static_cast>( + decpoly[i * P::l̅ + j][n])) * + h_val; + } + } + reconstructed[n] = sum; + } + + // Check that reconstruction is close to original + typename P::T max_error = + remaining_bits > 0 + ? (static_cast(1) << remaining_bits) + : static_cast(1); + + for (int n = 0; n < P::n; n++) { + typename P::T diff = poly[n] > reconstructed[n] + ? poly[n] - reconstructed[n] + : reconstructed[n] - poly[n]; + typename P::T max_val = ~static_cast(0); + typename P::T wrap_diff = max_val - diff + 1; + typename P::T min_diff = diff < wrap_diff ? diff : wrap_diff; + + if (min_diff > max_error * 2) { + cerr << "Decomposition error at n=" << n << endl; + cerr << " original=" << poly[n] << endl; + cerr << " reconstructed=" << reconstructed[n] << endl; + cerr << " max_error=" << max_error << endl; + cerr << " actual_diff=" << min_diff << endl; + assert(false); + } + } + } + cout << "DoubleDecomposition test passed" << endl; +} + +// Test External Product with Double Decomposition +template +void testExternalProductDD() +{ + constexpr uint32_t num_test = 10; + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution binary(0, 1); + + cout << "Parameters: n=" << P::n << ", k=" << P::k << endl; + cout << "Primary: l=" << P::l << ", Bgbit=" << P::Bgbit << endl; + cout << "Auxiliary: l̅=" << P::l̅ << ", B̅gbit=" << P::B̅gbit << endl; + cout << "TRGSW rows: " << (P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅) << endl; + + // Test with p=1 (identity) + cout << "Testing with plaintext = 1 (identity)..." << endl; + for (int test = 0; test < num_test; test++) { + // Generate a random TRLWE key for this param set + std::array key; + for (int i = 0; i < P::n; i++) { + key[i] = (binary(engine) == 1) + ? static_cast(P::key_value_max) + : static_cast(P::key_value_min); + } + + // Create a random message polynomial + array p; + for (bool &i : p) i = (binary(engine) > 0); + Polynomial

pmu; + for (int i = 0; i < P::n; i++) + pmu[i] = p[i] ? P::μ : -P::μ; + + // Encrypt as TRLWE + TRLWE

c; + trlweSymEncrypt

(c, pmu, key); + + // Create TRGSW encrypting 1 (identity) + const Polynomial

plainpoly = {static_cast(1)}; + + TRGSWFFT

trgswfft; + trgswSymEncrypt

(trgswfft, plainpoly, key); + + // Apply external product with double decomposition + TRLWE

result; + ExternalProductDD

(result, c, trgswfft); + + // Decrypt and verify + const array p2 = trlweSymDecrypt

(result, key); + for (int i = 0; i < P::n; i++) { + if (p[i] != p2[i]) { + cerr << "ExternalProductDD failed at index " << i + << ": expected " << p[i] << " got " << p2[i] << endl; + assert(false); + } + } + } + cout << "ExternalProductDD test passed (p=1)" << endl; + + // Test with p=-1 (negation) + cout << "Testing with plaintext = -1 (negation)..." << endl; + for (int test = 0; test < num_test; test++) { + std::array key; + for (int i = 0; i < P::n; i++) { + key[i] = (binary(engine) == 1) + ? static_cast(P::key_value_max) + : static_cast(P::key_value_min); + } + + array p; + for (bool &i : p) i = (binary(engine) > 0); + Polynomial

pmu; + for (int i = 0; i < P::n; i++) + pmu[i] = p[i] ? P::μ : -P::μ; + + TRLWE

c; + trlweSymEncrypt

(c, pmu, key); + + // TRGSW encrypting -1 (negation) + const Polynomial

plainpoly = {static_cast(-1)}; + + TRGSWFFT

trgswfft; + trgswSymEncrypt

(trgswfft, plainpoly, key); + + TRLWE

result; + ExternalProductDD

(result, c, trgswfft); + + const array p2 = trlweSymDecrypt

(result, key); + for (int i = 0; i < P::n; i++) { + if (p[i] != !p2[i]) { + cerr << "ExternalProductDD (p=-1) failed at index " << i + << ": expected " << !p[i] << " got " << p2[i] << endl; + assert(false); + } + } + } + cout << "ExternalProductDD test passed (p=-1)" << endl; +} + +int main() +{ + cout << "=== Testing Double Decomposition ===" << endl; + cout << endl; + + // DDTestParam: 64-bit Torus with proper double decomposition + // l=2, Bgbit=16, l̅=2, B̅gbit=16 + // Total bits: 2*16 + 2*16 = 64 (exact fit) + // Decomposition levels: 2*2 = 4 + + cout << "DDTestParam configuration:" << endl; + cout << " Torus type: uint64_t (64-bit)" << endl; + cout << " n = " << DDTestParam::n << " (polynomial degree)" << endl; + cout << " k = " << DDTestParam::k << endl; + cout << " Primary decomposition: l = " << DDTestParam::l + << ", Bgbit = " << DDTestParam::Bgbit << endl; + cout << " Auxiliary decomposition: l̅ = " << DDTestParam::l̅ + << ", B̅gbit = " << DDTestParam::B̅gbit << endl; + cout << " Total decomposition levels: " << (DDTestParam::l * DDTestParam::l̅) + << endl; + cout << " Bits used: " + << (DDTestParam::l * DDTestParam::Bgbit + + DDTestParam::l̅ * DDTestParam::B̅gbit) + << " / 64" << endl; + cout << " TRGSW size: " + << (DDTestParam::k * DDTestParam::lₐ * DDTestParam::l̅ₐ + + DDTestParam::l * DDTestParam::l̅) + << " TRLWE rows" << endl; + cout << endl; + + cout << "--- Testing DoubleDecomposition ---" << endl; + testDoubleDecomposition(); + cout << endl; + + cout << "--- Testing ExternalProductDD ---" << endl; + testExternalProductDD(); + cout << endl; + + cout << "=== All Double Decomposition tests passed ===" << endl; + return 0; +} From dd23b69c282e7fa9b946b53d11083aff300fafdf Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Wed, 31 Dec 2025 13:02:15 +0000 Subject: [PATCH 3/9] Fix lvl3param to use 64-bit Torus and fix test FFT compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Changed lvl3param from __uint128_t to uint64_t (128-bit types not fully supported in TFHEpp) - Adjusted parameters: l=2, l̅=2 (was l=4, l̅=4) to fit 64-bit constraint - Fixed DDTestParam nbit from 10 to 11 to match lvl2param for FFT compatibility 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/params/128bit.hpp | 27 ++++++++++----------- test/externalproductdoubledecomposition.cpp | 5 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index 33057dcd..f4be02f4 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -153,9 +153,9 @@ struct AHlvl2param { static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; -// New lvl3param with 128-bit Torus and non-trivial Double Decomposition -// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128 -// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized) +// lvl3param with 64-bit Torus and non-trivial Double Decomposition +// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 64 +// Using l=2, Bgbit=16, l̅=2, B̅gbit=16: 2*16 + 2*16 = 64 bits (fully utilized) struct lvl3param { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; @@ -163,26 +163,25 @@ struct lvl3param { // ease of polynomial multiplication. static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096 static constexpr std::uint32_t k = 1; - static constexpr std::uint32_t lₐ = 4; - static constexpr std::uint32_t l = 4; + static constexpr std::uint32_t lₐ = 2; + static constexpr std::uint32_t l = 2; static constexpr std::uint32_t Bgbit = 16; static constexpr std::uint32_t Bgₐbit = 16; static constexpr uint32_t Bg = 1U << Bgbit; static constexpr uint32_t Bgₐ = 1U << Bgₐbit; static constexpr ErrorDistribution errordist = ErrorDistribution::ModularGaussian; - static const inline double α = std::pow(2.0, -105); // fresh noise - using T = __uint128_t; // Torus representation - static constexpr T μ = static_cast(1) << 125; + static const inline double α = std::pow(2.0, -51); // fresh noise + using T = uint64_t; // Torus representation + static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulusbit = 31; - static constexpr __uint128_t plain_modulus = static_cast(1) << plain_modulusbit; - static constexpr double Δ = - static_cast(static_cast(1) << (128 - plain_modulusbit - 1)); + static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; + static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); // Double Decomposition (bivariate representation) parameters // Non-trivial values for testing actual double decomposition - // Constraint: l * Bgbit + l̅ * B̅gbit <= 128 - static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels - static constexpr std::uint32_t l̅ₐ = 4; + // Constraint: l * Bgbit + l̅ * B̅gbit <= 64 + static constexpr std::uint32_t l̅ = 2; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = 2; static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary static constexpr std::uint32_t B̅gₐbit = 16; }; diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp index accc04bb..ffa0ed1b 100644 --- a/test/externalproductdoubledecomposition.cpp +++ b/test/externalproductdoubledecomposition.cpp @@ -9,11 +9,12 @@ using namespace TFHEpp; // Custom parameter set for testing double decomposition with 64-bit Torus // Using trivial double decomposition (l̅=1) which reduces to standard decomposition // This verifies the code path works correctly +// Note: nbit must match lvl2param (11) for FFT compatibility struct DDTestParam { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; - static constexpr std::uint32_t nbit = 10; - static constexpr std::uint32_t n = 1 << nbit; // 1024 + static constexpr std::uint32_t nbit = 11; + static constexpr std::uint32_t n = 1 << nbit; // 2048 static constexpr std::uint32_t k = 1; static constexpr std::uint32_t lₐ = 4; static constexpr std::uint32_t l = 4; From c57befc0ce9369411c5b30fbb31101b6700355ae Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Thu, 1 Jan 2026 02:48:58 +0000 Subject: [PATCH 4/9] Add 128-bit Torus support for Double Decomposition with FFT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update lvl3param to use __uint128_t with l=l̅=4, Bgbit=B̅gbit=16 - Extend FFT to support nbit=12 (n=4096) via fftplvl3 - Add TwistFFT/TwistIFFT handling for 128-bit types using 64-bit FFT - Add UniformTorusRandom

() helper for 128-bit random generation - Add ModularGaussian support for __uint128_t - Fix decomposition functions to scale values for 128-bit FFT compatibility - Add lvl03param bootstrapping parameter (lvl0 → lvl3) - Add GateBootstrappingTLWE2TLWEDD test for non-trivial DD (l̅=4) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/detwfa.hpp | 48 +++++++++ include/gatebootstrapping.hpp | 40 +++++++ include/mulfft.hpp | 72 ++++++++++--- include/params.hpp | 10 ++ include/params/128bit.hpp | 27 ++--- include/tlwe.hpp | 6 +- include/trgsw.hpp | 56 ++++++---- include/trlwe.hpp | 8 +- include/utils.hpp | 29 +++++ ...tstrappingtlwe2tlwedoubledecomposition.cpp | 101 ++++++++++++++++++ .../spqlios/fft_processor_spqlios.cpp | 1 + thirdparties/spqlios/fft_processor_spqlios.h | 3 +- 12 files changed, 342 insertions(+), 59 deletions(-) create mode 100644 test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp diff --git a/include/detwfa.hpp b/include/detwfa.hpp index 85222ad8..77e6b15f 100644 --- a/include/detwfa.hpp +++ b/include/detwfa.hpp @@ -116,4 +116,52 @@ void CMUXwithPolynomialMulByXaiMinusOne(TRLWE

&acc, for (int i = 0; i < P::n; i++) acc[k][i] += temp[k][i]; } +// Double Decomposition variants +template +void CMUXFFTDD(TRLWE

&res, const TRGSWFFT

&cs, const TRLWE

&c1, + const TRLWE

&c0) +{ + for (int k = 0; k < P::k + 1; k++) + for (int i = 0; i < P::n; i++) res[k][i] = c1[k][i] - c0[k][i]; + ExternalProductDD

(res, res, cs); + for (int k = 0; k < P::k + 1; k++) + for (int i = 0; i < P::n; i++) res[k][i] += c0[k][i]; +} + +template +void CMUXwithPolynomialMulByXaiMinusOneDD( + TRLWE &acc, + const BootstrappingKeyElementFFT &cs, const int a) +{ + if constexpr (bkP::domainP::key_value_diff == 1) { + alignas(64) TRLWE temp; + for (int k = 0; k < bkP::targetP::k + 1; k++) + PolynomialMulByXaiMinusOne(temp[k], acc[k], + a); + ExternalProductDD(temp, temp, cs[0]); + for (int k = 0; k < bkP::targetP::k + 1; k++) + for (int i = 0; i < bkP::targetP::n; i++) acc[k][i] += temp[k][i]; + } + else { + alignas(32) TRLWE temp; + int count = 0; + for (int i = bkP::domainP::key_value_min; + i <= bkP::domainP::key_value_max; i++) { + if (i != 0) { + const int mod = (a * i) % (2 * bkP::targetP::n); + const int index = mod > 0 ? mod : mod + (2 * bkP::targetP::n); + for (int k = 0; k < bkP::targetP::k + 1; k++) + PolynomialMulByXaiMinusOne( + temp[k], acc[k], index); + ExternalProductDD(temp, temp, + cs[count]); + for (int k = 0; k < bkP::targetP::k + 1; k++) + for (int i = 0; i < bkP::targetP::n; i++) + acc[k][i] += temp[k][i]; + count++; + } + } + } +} + } // namespace TFHEpp \ No newline at end of file diff --git a/include/gatebootstrapping.hpp b/include/gatebootstrapping.hpp index 1938bff4..a8a2fd33 100644 --- a/include/gatebootstrapping.hpp +++ b/include/gatebootstrapping.hpp @@ -263,6 +263,46 @@ constexpr Polynomial

μpolygen() return poly; } +// Double Decomposition variants +template +void BlindRotateDD(TRLWE &res, + const TLWE &tlwe, + const BootstrappingKeyFFT

&bkfft, + const Polynomial &testvector) +{ + constexpr uint32_t bitwidth = bits_needed(); + const uint32_t b̄ = 2 * P::targetP::n - + ((tlwe[P::domainP::k * P::domainP::n] >> + (std::numeric_limits::digits - + 1 - P::targetP::nbit + bitwidth)) + << bitwidth); + res = {}; + PolynomialMulByXai(res[P::targetP::k], testvector, b̄); + for (int i = 0; i < P::domainP::k * P::domainP::n; i++) { + constexpr typename P::domainP::T roundoffset = + 1ULL << (std::numeric_limits::digits - 2 - + P::targetP::nbit + bitwidth); + const uint32_t ā = + (tlwe[i] + roundoffset) >> + (std::numeric_limits::digits - 1 - + P::targetP::nbit + bitwidth) + << bitwidth; + if (ā == 0) continue; + CMUXwithPolynomialMulByXaiMinusOneDD

(res, bkfft[i], ā); + } +} + +template +void GateBootstrappingTLWE2TLWEDD( + TLWE &res, const TLWE &tlwe, + const BootstrappingKeyFFT

&bkfft, + const Polynomial &testvector) +{ + alignas(64) TRLWE acc; + BlindRotateDD

(acc, tlwe, bkfft, testvector); + SampleExtractIndex(res, acc, 0); +} + template void GateBootstrapping(TLWE &res, const TLWE &tlwe, diff --git a/include/mulfft.hpp b/include/mulfft.hpp index c25a9826..57ffbaa9 100644 --- a/include/mulfft.hpp +++ b/include/mulfft.hpp @@ -68,10 +68,9 @@ inline void TwistNTT(Polynomial

&res, PolynomialNTT

&a) cuHEpp::TwistNTT( res, a, (*ntttablelvl1)[0], (*ntttwistlvl1)[0]); #endif - else if constexpr (std::is_same_v) { + else if constexpr (std::is_same_v) cuHEpp::TwistNTT( res, a, (*ntttablelvl2)[0], (*ntttwistlvl2)[0]); - } else static_assert(false_v, "Undefined TwistNTT!"); } @@ -89,6 +88,14 @@ inline void TwistFFT(Polynomial

&res, const PolynomialInFD

&a) else if constexpr (std::is_same_v) fftplvl1.execute_direct_torus64(res.data(), a.data()); } + else if constexpr (std::is_same_v) { + // For 128-bit lvl3param, use 64-bit FFT and shift result to top 64 bits + // This preserves the Torus semantics (most significant bits) + alignas(64) std::array temp; + fftplvl3.execute_direct_torus64(temp.data(), a.data()); + for (int i = 0; i < P::n; i++) + res[i] = static_cast<__uint128_t>(temp[i]) << 64; + } else if constexpr (std::is_same_v) fftplvl2.execute_direct_torus64(res.data(), a.data()); else @@ -143,6 +150,14 @@ inline void TwistIFFT(PolynomialInFD

&res, const Polynomial

&a) if constexpr (std::is_same_v) fftplvl1.execute_reverse_torus64(res.data(), a.data()); } + else if constexpr (std::is_same_v) { + // For 128-bit lvl3param, use top 64 bits for FFT + // This preserves the Torus semantics (most significant bits) + alignas(64) std::array temp; + for (int i = 0; i < P::n; i++) + temp[i] = static_cast(a[i] >> 64); + fftplvl3.execute_reverse_torus64(res.data(), temp.data()); + } else if constexpr (std::is_same_v) fftplvl2.execute_reverse_torus64(res.data(), a.data()); else @@ -301,8 +316,21 @@ inline void PolyMul(Polynomial

&res, const Polynomial

&a, for (int i = 0; i < P::n; i++) ntta[i] *= nttb[i]; TwistNTT

(res, ntta); } + else if constexpr (std::is_same_v) { + // Naive for 128-bit types (FFT/NTT don't support 128-bit precision) + for (int i = 0; i < P::n; i++) { + __uint128_t ri = 0; + for (int j = 0; j <= i; j++) + ri += static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[i - j]); + for (int j = i + 1; j < P::n; j++) + ri -= static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[P::n + i - j]); + res[i] = ri; + } + } else { - // Naieve + // Naive for other types for (int i = 0; i < P::n; i++) { typename P::T ri = 0; for (int j = 0; j <= i; j++) @@ -339,17 +367,33 @@ template inline void PolyMulNaive(Polynomial

&res, const Polynomial

&a, const Polynomial

&b) { - for (int i = 0; i < P::n; i++) { - typename P::T ri = 0; - for (int j = 0; j <= i; j++) - ri += static_cast::type>( - a[j]) * - b[i - j]; - for (int j = i + 1; j < P::n; j++) - ri -= static_cast::type>( - a[j]) * - b[P::n + i - j]; - res[i] = ri; + if constexpr (std::is_same_v) { + for (int i = 0; i < P::n; i++) { + __uint128_t ri = 0; + for (int j = 0; j <= i; j++) + ri += static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[i - j]); + for (int j = i + 1; j < P::n; j++) + ri -= static_cast<__int128_t>(a[j]) * + static_cast<__int128_t>(b[P::n + i - j]); + res[i] = ri; + } + } + else { + for (int i = 0; i < P::n; i++) { + typename P::T ri = 0; + for (int j = 0; j <= i; j++) + ri += + static_cast::type>( + a[j]) * + b[i - j]; + for (int j = i + 1; j < P::n; j++) + ri -= + static_cast::type>( + a[j]) * + b[P::n + i - j]; + res[i] = ri; + } } } diff --git a/include/params.hpp b/include/params.hpp index 58e99a6e..22d10be6 100644 --- a/include/params.hpp +++ b/include/params.hpp @@ -67,6 +67,16 @@ struct lvl02param { #endif }; +struct lvl03param { + using domainP = lvl0param; + using targetP = lvl3param; +#ifdef USE_KEY_BUNDLE + static constexpr uint32_t Addends = 2; +#else + static constexpr uint32_t Addends = 1; +#endif +}; + struct lvlh2param { using domainP = lvlhalfparam; using targetP = lvl2param; diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index f4be02f4..50cebbd3 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -153,9 +153,9 @@ struct AHlvl2param { static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit; }; -// lvl3param with 64-bit Torus and non-trivial Double Decomposition -// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 64 -// Using l=2, Bgbit=16, l̅=2, B̅gbit=16: 2*16 + 2*16 = 64 bits (fully utilized) +// lvl3param with 128-bit Torus and non-trivial Double Decomposition +// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128 +// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized) struct lvl3param { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; @@ -163,25 +163,26 @@ struct lvl3param { // ease of polynomial multiplication. static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096 static constexpr std::uint32_t k = 1; - static constexpr std::uint32_t lₐ = 2; - static constexpr std::uint32_t l = 2; + static constexpr std::uint32_t lₐ = 4; + static constexpr std::uint32_t l = 4; static constexpr std::uint32_t Bgbit = 16; static constexpr std::uint32_t Bgₐbit = 16; static constexpr uint32_t Bg = 1U << Bgbit; static constexpr uint32_t Bgₐ = 1U << Bgₐbit; static constexpr ErrorDistribution errordist = ErrorDistribution::ModularGaussian; - static const inline double α = std::pow(2.0, -51); // fresh noise - using T = uint64_t; // Torus representation - static constexpr T μ = 1ULL << 61; + static const inline double α = std::pow(2.0, -105); // fresh noise + using T = __uint128_t; // Torus representation + static constexpr T μ = static_cast(1) << 125; static constexpr uint32_t plain_modulusbit = 31; - static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit; - static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1); + static constexpr T plain_modulus = static_cast(1) << plain_modulusbit; + static constexpr double Δ = + static_cast(static_cast(1) << (128 - plain_modulusbit - 1)); // Double Decomposition (bivariate representation) parameters // Non-trivial values for testing actual double decomposition - // Constraint: l * Bgbit + l̅ * B̅gbit <= 64 - static constexpr std::uint32_t l̅ = 2; // auxiliary decomposition levels - static constexpr std::uint32_t l̅ₐ = 2; + // Constraint: l * Bgbit + l̅ * B̅gbit <= 128 + static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels + static constexpr std::uint32_t l̅ₐ = 4; static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary static constexpr std::uint32_t B̅gₐbit = 16; }; diff --git a/include/tlwe.hpp b/include/tlwe.hpp index f95ee471..6b7dbaf8 100644 --- a/include/tlwe.hpp +++ b/include/tlwe.hpp @@ -13,13 +13,11 @@ template void tlweSymEncrypt(TLWE

&res, const typename P::T p, const double α, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); res = {}; res[P::k * P::n] = ModularGaussian

(p, α); for (int k = 0; k < P::k; k++) for (int i = 0; i < P::n; i++) { - res[k * P::n + i] = Torusdist(generator); + res[k * P::n + i] = UniformTorusRandom

(); res[P::k * P::n] += res[k * P::n + i] * key[k * P::n + i]; } } @@ -122,7 +120,7 @@ typename P::T tlweSymIntDecrypt(const TLWE

&c, const Key

&key) constexpr double Δ = 2 * static_cast( - 1ULL << (std::numeric_limits::digits - 1)) / + static_cast(1) << (std::numeric_limits::digits - 1)) / plain_modulus; const typename P::T phase = tlweSymPhase

(c, key); typename P::T res = static_cast(std::round(phase / Δ)); diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 01c690c7..3d08ebcd 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -15,7 +15,7 @@ constexpr typename P::T offsetgen() typename P::T offset = 0; for (int i = 1; i <= P::l; i++) offset += P::Bg / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << (std::numeric_limits::digits - i * P::Bgbit)); return offset; } @@ -27,11 +27,11 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, { // https://eprint.iacr.org/2021/1161 constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - P::l * P::Bgbit - + static_cast(1) << (std::numeric_limits::digits - P::l * P::Bgbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgbit) - 1); - constexpr typename P::T Bgl = 1ULL << (P::l * P::Bgbit); + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T Bgl = static_cast(1) << (P::l * P::Bgbit); Polynomial

K; for (int i = 0; i < P::n; i++) { @@ -68,11 +68,11 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, { constexpr typename P::T offset = offsetgen

(); constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - P::l * P::Bgbit - + static_cast(1) << (std::numeric_limits::digits - P::l * P::Bgbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1)); + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); for (int i = 0; i < P::n; i++) { for (int l = 0; l < P::l; l++) @@ -90,7 +90,7 @@ constexpr typename P::T nonceoffsetgen() typename P::T offset = 0; for (int i = 1; i <= P::lₐ; i++) offset += P::Bgₐ / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << (std::numeric_limits::digits - i * P::Bgₐbit)); return offset; } @@ -101,11 +101,11 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, { constexpr typename P::T offset = nonceoffsetgen

(); constexpr typename P::T roundoffset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - P::lₐ * P::Bgₐbit - 1); constexpr typename P::T mask = - static_cast((1ULL << P::Bgₐbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1)); + static_cast((static_cast(1) << P::Bgₐbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); for (int i = 0; i < P::n; i++) { for (int l = 0; l < P::lₐ; l++) @@ -148,8 +148,8 @@ inline void DoubleDecomposition(DecomposedPolynomialDD

&decpoly, ? (static_cast(1) << (remaining_bits - 1)) : static_cast(0); constexpr typename P::T maskBg = - static_cast((1ULL << P::Bgbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgbit - 1)); + static_cast((static_cast(1) << P::Bgbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); for (int n = 0; n < P::n; n++) { typename P::T a = poly[n] + offset + roundoffset; @@ -159,9 +159,16 @@ inline void DoubleDecomposition(DecomposedPolynomialDD

&decpoly, // When l̅=1 (j=0 only), this reduces to standard decomposition const int shift = std::numeric_limits::digits - (i + 1) * P::Bgbit - j * P::B̅gbit; - decpoly[i * P::l̅ + j][n] = + auto decomp_val = static_cast>( ((a >> shift) & maskBg) - halfBg); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + decpoly[i * P::l̅ + j][n] = + static_cast(decomp_val) << 64; + else + decpoly[i * P::l̅ + j][n] = decomp_val; } } } @@ -194,8 +201,8 @@ inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, ? (static_cast(1) << (remaining_bits - 1)) : static_cast(0); constexpr typename P::T maskBg = - static_cast((1ULL << P::Bgₐbit) - 1); - constexpr typename P::T halfBg = (1ULL << (P::Bgₐbit - 1)); + static_cast((static_cast(1) << P::Bgₐbit) - 1); + constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); for (int n = 0; n < P::n; n++) { typename P::T a = poly[n] + offset + roundoffset; @@ -205,9 +212,16 @@ inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, // When l̅ₐ=1 (j=0 only), this reduces to standard decomposition const int shift = std::numeric_limits::digits - (i + 1) * P::Bgₐbit - j * P::B̅gₐbit; - decpoly[i * P::l̅ₐ + j][n] = + auto decomp_val = static_cast>( ((a >> shift) & maskBg) - halfBg); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + decpoly[i * P::l̅ₐ + j][n] = + static_cast(decomp_val) << 64; + else + decpoly[i * P::l̅ₐ + j][n] = decomp_val; } } } @@ -575,7 +589,7 @@ constexpr std::array hgen() ((i + 1) * P::Bgbit); else for (int i = 0; i < P::l; i++) - h[i] = 1ULL << (std::numeric_limits::digits - + h[i] = static_cast(1) << (std::numeric_limits::digits - (i + 1) * P::Bgbit); return h; } @@ -590,7 +604,7 @@ constexpr std::array noncehgen() ((i + 1) * P::Bgₐbit); else for (int i = 0; i < P::lₐ; i++) - h[i] = 1ULL << (std::numeric_limits::digits - + h[i] = static_cast(1) << (std::numeric_limits::digits - (i + 1) * P::Bgₐbit); return h; } @@ -605,7 +619,7 @@ constexpr std::array h̅gen() std::array h̅{}; h̅[0] = 1; // j=0 means no auxiliary shift for (int i = 1; i < P::l̅; i++) - h̅[i] = 1ULL << (std::numeric_limits::digits - + h̅[i] = static_cast(1) << (std::numeric_limits::digits - i * P::B̅gbit); return h̅; } @@ -617,7 +631,7 @@ constexpr std::array nonceh̅gen() std::array h̅{}; h̅[0] = 1; // j=0 means no auxiliary shift for (int i = 1; i < P::l̅ₐ; i++) - h̅[i] = 1ULL << (std::numeric_limits::digits - + h̅[i] = static_cast(1) << (std::numeric_limits::digits - i * P::B̅gₐbit); return h̅; } diff --git a/include/trlwe.hpp b/include/trlwe.hpp index db9e1233..3835d9ba 100644 --- a/include/trlwe.hpp +++ b/include/trlwe.hpp @@ -8,11 +8,9 @@ namespace TFHEpp { template void trlweSymEncryptZero(TRLWE

&c, const double α, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); for (typename P::T &i : c[P::k]) i = ModularGaussian

(0, α); for (int k = 0; k < P::k; k++) { - for (typename P::T &i : c[k]) i = Torusdist(generator); + for (typename P::T &i : c[k]) i = UniformTorusRandom

(); std::array partkey; for (int i = 0; i < P::n; i++) partkey[i] = key[k * P::n + i]; Polynomial

temp; @@ -24,12 +22,10 @@ void trlweSymEncryptZero(TRLWE

&c, const double α, const Key

&key) template void trlweSymEncryptZero(TRLWE

&c, const uint η, const Key

&key) { - std::uniform_int_distribution Torusdist( - 0, std::numeric_limits::max()); for (typename P::T &i : c[P::k]) i = (CenteredBinomial

(η) << std::numeric_limits

::digits) / P::q; for (int k = 0; k < P::k; k++) { - for (typename P::T &i : c[k]) i = Torusdist(generator); + for (typename P::T &i : c[k]) i = UniformTorusRandom

(); alignas(64) std::array partkey; for (int i = 0; i < P::n; i++) partkey[i] = key[k * P::n + i]; alignas(64) Polynomial

temp; diff --git a/include/utils.hpp b/include/utils.hpp index 9eace5c3..38b0122d 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -29,6 +29,26 @@ static thread_local std::random_device generator; template constexpr bool false_v = false; +// Helper function to generate uniform random Torus values +// For standard types, use uniform_int_distribution +// For __uint128_t, combine two 64-bit random values +template +inline typename P::T UniformTorusRandom() +{ + if constexpr (std::is_same_v) { + std::uniform_int_distribution dist64( + 0, std::numeric_limits::max()); + __uint128_t high = dist64(generator); + __uint128_t low = dist64(generator); + return (high << 64) | low; + } + else { + std::uniform_int_distribution dist( + 0, std::numeric_limits::max()); + return dist(generator); + } +} + // https://qiita.com/negi-drums/items/a527c05050781a5af523 template concept hasq = requires { T::q; }; @@ -140,6 +160,15 @@ inline typename P::T ModularGaussian(typename P::T center, double stdev) const uint64_t ival = static_cast(val); return ival + center; } + else if constexpr (std::is_same_v) { + // 128bit fixed-point number version + // Use two 64-bit Gaussians for high and low parts + static const double _2p64 = std::pow(2., 64); + std::normal_distribution distribution(0., 1.0); + const double val = stdev * distribution(generator) * _2p64; + const __int128_t ival = static_cast<__int128_t>(val); + return static_cast<__uint128_t>(ival) + center; + } else static_assert(false_v, "Undefined Modular Gaussian!"); } diff --git a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp new file mode 100644 index 00000000..017a4301 --- /dev/null +++ b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp @@ -0,0 +1,101 @@ +#include +#include +#include +#include +#include + +int main() +{ + using namespace std; + using namespace TFHEpp; + + constexpr uint32_t num_test = 10; // Reduced for faster testing + random_device seed_gen; + default_random_engine engine(seed_gen()); + uniform_int_distribution binary(0, 1); + + // Use lvl03param which bootstraps from lvl0param to lvl3param + // lvl3param has nbit=12 (n=4096) and non-trivial DD: l=2, l̅=2, Bgbit=16, B̅gbit=16 + using bkP = lvl03param; + + cout << "=== Testing GateBootstrappingTLWE2TLWEDD with lvl3param (nbit=12) ===" << endl; + cout << "Domain: n=" << bkP::domainP::n << ", T=" << sizeof(typename bkP::domainP::T) * 8 << "-bit" << endl; + cout << "Target: n=" << bkP::targetP::n << ", nbit=" << bkP::targetP::nbit << ", T=" << sizeof(typename bkP::targetP::T) * 8 << "-bit" << endl; + cout << "Primary decomposition: l=" << bkP::targetP::l << ", Bgbit=" << bkP::targetP::Bgbit << endl; + cout << "Auxiliary decomposition: l̅=" << bkP::targetP::l̅ << ", B̅gbit=" << bkP::targetP::B̅gbit << endl; + cout << "Total decomposition levels: " << (bkP::targetP::l * bkP::targetP::l̅) << endl; + cout << "Bits used: " << (bkP::targetP::l * bkP::targetP::Bgbit + bkP::targetP::l̅ * bkP::targetP::B̅gbit) << " / " << std::numeric_limits::digits << endl; + cout << endl; + + // Generate keys + cout << "Generating secret keys..." << endl; + array domainKey; + for (int i = 0; i < bkP::domainP::n; i++) { + domainKey[i] = binary(engine); + } + + array targetKey; + for (int i = 0; i < bkP::targetP::n; i++) { + targetKey[i] = (binary(engine) == 1) + ? static_cast(bkP::targetP::key_value_max) + : static_cast(bkP::targetP::key_value_min); + } + + // Generate bootstrapping key + cout << "Generating bootstrapping key (this may take a while for n=4096)..." << endl; + auto bkfft = make_unique>(); + bkfftgen(*bkfft, domainKey, targetKey); + + // Test arrays + array, num_test> tlwe; + array, num_test> bootedtlwe; + array p; + + // Encrypt test values + cout << "Encrypting test values..." << endl; + for (int i = 0; i < num_test; i++) { + p[i] = binary(engine) > 0; + tlweSymEncrypt( + tlwe[i], p[i] ? bkP::domainP::μ : -bkP::domainP::μ, bkP::domainP::α, + domainKey); + } + + // Perform gate bootstrapping with double decomposition + cout << "Running GateBootstrappingTLWE2TLWEDD..." << endl; + chrono::system_clock::time_point start, end; + start = chrono::system_clock::now(); + + for (int test = 0; test < num_test; test++) { + GateBootstrappingTLWE2TLWEDD( + bootedtlwe[test], tlwe[test], *bkfft, + μpolygen()); + } + + end = chrono::system_clock::now(); + + // Verify results + cout << "Verifying results..." << endl; + int errors = 0; + for (int i = 0; i < num_test; i++) { + bool p2 = tlweSymDecrypt(bootedtlwe[i], targetKey); + if (p[i] != p2) { + cerr << "Error at index " << i << ": expected " << p[i] << " got " << p2 << endl; + errors++; + } + } + + if (errors == 0) { + cout << "All tests passed!" << endl; + } + else { + cerr << errors << " out of " << num_test << " tests failed!" << endl; + return 1; + } + + double elapsed = + chrono::duration_cast(end - start).count(); + cout << "Average time per bootstrapping: " << elapsed / num_test << "ms" << endl; + + cout << "=== GateBootstrappingTLWE2TLWEDD test completed ===" << endl; + return 0; +} diff --git a/thirdparties/spqlios/fft_processor_spqlios.cpp b/thirdparties/spqlios/fft_processor_spqlios.cpp index 31b0f340..c04d4b63 100644 --- a/thirdparties/spqlios/fft_processor_spqlios.cpp +++ b/thirdparties/spqlios/fft_processor_spqlios.cpp @@ -376,3 +376,4 @@ FFT_Processor_Spqlios::~FFT_Processor_Spqlios() { thread_local FFT_Processor_Spqlios fftplvl1(TFHEpp::lvl1param::n); thread_local FFT_Processor_Spqlios fftplvl2(TFHEpp::lvl2param::n); +thread_local FFT_Processor_Spqlios fftplvl3(TFHEpp::lvl3param::n); diff --git a/thirdparties/spqlios/fft_processor_spqlios.h b/thirdparties/spqlios/fft_processor_spqlios.h index ae21e42b..c0711619 100644 --- a/thirdparties/spqlios/fft_processor_spqlios.h +++ b/thirdparties/spqlios/fft_processor_spqlios.h @@ -54,4 +54,5 @@ class FFT_Processor_Spqlios { }; extern thread_local FFT_Processor_Spqlios fftplvl1; -extern thread_local FFT_Processor_Spqlios fftplvl2; \ No newline at end of file +extern thread_local FFT_Processor_Spqlios fftplvl2; +extern thread_local FFT_Processor_Spqlios fftplvl3; \ No newline at end of file From 940de56ce3bcce93d93396d3ae27845347ee491a Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Thu, 1 Jan 2026 03:32:35 +0000 Subject: [PATCH 5/9] Refactor DD functions to use function overloading and fix 128-bit support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Merge ExternalProduct and ExternalProductDD using if constexpr (P::l̅ > 1) - Remove redundant DD-postfixed functions (CMUXFFTDD, BlindRotateDD, etc.) - Fix 128-bit shift overflow in keyswitch.hpp by using proper type casts - Update lvl3param DD parameters to satisfy constraint l*Bgbit + (l̅-1)*B̅gbit ≤ 128 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/detwfa.hpp | 48 ------- include/gatebootstrapping.hpp | 40 ------ include/keyswitch.hpp | 28 ++--- include/params/128bit.hpp | 23 ++-- include/trgsw.hpp | 118 ++++++++---------- test/externalproductdoubledecomposition.cpp | 20 +-- ...tstrappingtlwe2tlwedoubledecomposition.cpp | 10 +- 7 files changed, 98 insertions(+), 189 deletions(-) diff --git a/include/detwfa.hpp b/include/detwfa.hpp index 77e6b15f..85222ad8 100644 --- a/include/detwfa.hpp +++ b/include/detwfa.hpp @@ -116,52 +116,4 @@ void CMUXwithPolynomialMulByXaiMinusOne(TRLWE

&acc, for (int i = 0; i < P::n; i++) acc[k][i] += temp[k][i]; } -// Double Decomposition variants -template -void CMUXFFTDD(TRLWE

&res, const TRGSWFFT

&cs, const TRLWE

&c1, - const TRLWE

&c0) -{ - for (int k = 0; k < P::k + 1; k++) - for (int i = 0; i < P::n; i++) res[k][i] = c1[k][i] - c0[k][i]; - ExternalProductDD

(res, res, cs); - for (int k = 0; k < P::k + 1; k++) - for (int i = 0; i < P::n; i++) res[k][i] += c0[k][i]; -} - -template -void CMUXwithPolynomialMulByXaiMinusOneDD( - TRLWE &acc, - const BootstrappingKeyElementFFT &cs, const int a) -{ - if constexpr (bkP::domainP::key_value_diff == 1) { - alignas(64) TRLWE temp; - for (int k = 0; k < bkP::targetP::k + 1; k++) - PolynomialMulByXaiMinusOne(temp[k], acc[k], - a); - ExternalProductDD(temp, temp, cs[0]); - for (int k = 0; k < bkP::targetP::k + 1; k++) - for (int i = 0; i < bkP::targetP::n; i++) acc[k][i] += temp[k][i]; - } - else { - alignas(32) TRLWE temp; - int count = 0; - for (int i = bkP::domainP::key_value_min; - i <= bkP::domainP::key_value_max; i++) { - if (i != 0) { - const int mod = (a * i) % (2 * bkP::targetP::n); - const int index = mod > 0 ? mod : mod + (2 * bkP::targetP::n); - for (int k = 0; k < bkP::targetP::k + 1; k++) - PolynomialMulByXaiMinusOne( - temp[k], acc[k], index); - ExternalProductDD(temp, temp, - cs[count]); - for (int k = 0; k < bkP::targetP::k + 1; k++) - for (int i = 0; i < bkP::targetP::n; i++) - acc[k][i] += temp[k][i]; - count++; - } - } - } -} - } // namespace TFHEpp \ No newline at end of file diff --git a/include/gatebootstrapping.hpp b/include/gatebootstrapping.hpp index a8a2fd33..1938bff4 100644 --- a/include/gatebootstrapping.hpp +++ b/include/gatebootstrapping.hpp @@ -263,46 +263,6 @@ constexpr Polynomial

μpolygen() return poly; } -// Double Decomposition variants -template -void BlindRotateDD(TRLWE &res, - const TLWE &tlwe, - const BootstrappingKeyFFT

&bkfft, - const Polynomial &testvector) -{ - constexpr uint32_t bitwidth = bits_needed(); - const uint32_t b̄ = 2 * P::targetP::n - - ((tlwe[P::domainP::k * P::domainP::n] >> - (std::numeric_limits::digits - - 1 - P::targetP::nbit + bitwidth)) - << bitwidth); - res = {}; - PolynomialMulByXai(res[P::targetP::k], testvector, b̄); - for (int i = 0; i < P::domainP::k * P::domainP::n; i++) { - constexpr typename P::domainP::T roundoffset = - 1ULL << (std::numeric_limits::digits - 2 - - P::targetP::nbit + bitwidth); - const uint32_t ā = - (tlwe[i] + roundoffset) >> - (std::numeric_limits::digits - 1 - - P::targetP::nbit + bitwidth) - << bitwidth; - if (ā == 0) continue; - CMUXwithPolynomialMulByXaiMinusOneDD

(res, bkfft[i], ā); - } -} - -template -void GateBootstrappingTLWE2TLWEDD( - TLWE &res, const TLWE &tlwe, - const BootstrappingKeyFFT

&bkfft, - const Polynomial &testvector) -{ - alignas(64) TRLWE acc; - BlindRotateDD

(acc, tlwe, bkfft, testvector); - SampleExtractIndex(res, acc, 0); -} - template void GateBootstrapping(TLWE &res, const TLWE &tlwe, diff --git a/include/keyswitch.hpp b/include/keyswitch.hpp index 961a91ad..cd767d55 100644 --- a/include/keyswitch.hpp +++ b/include/keyswitch.hpp @@ -17,8 +17,8 @@ constexpr typename P::domainP::T iksoffsetgen() typename P::domainP::T offset = 0; for (int i = 1; i <= P::t; i++) offset += - (1ULL << P::basebit) / 2 * - (1ULL << (std::numeric_limits::digits - + (static_cast(1) << P::basebit) / 2 * + (static_cast(1) << (std::numeric_limits::digits - i * P::basebit)); return offset; } @@ -35,7 +35,7 @@ void IdentityKeySwitch(TLWE &res, std::numeric_limits::digits; constexpr typename P::domainP::T roundoffset = (P::basebit * P::t) < domain_digit - ? 1ULL << (domain_digit - (1 + P::basebit * P::t)) + ? static_cast(1) << (domain_digit - (1 + P::basebit * P::t)) : 0; if constexpr (domain_digit == target_digit) res[P::targetP::k * P::targetP::n] = @@ -43,7 +43,7 @@ void IdentityKeySwitch(TLWE &res, else if constexpr (domain_digit > target_digit) res[P::targetP::k * P::targetP::n] = (tlwe[P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[P::targetP::k * P::targetP::n] = @@ -86,7 +86,7 @@ void CatIdentityKeySwitch( std::numeric_limits::digits; constexpr typename P::domainP::T roundoffset = (P::basebit * P::t) < domain_digit - ? 1ULL << (domain_digit - (1 + P::basebit * P::t)) + ? static_cast(1) << (domain_digit - (1 + P::basebit * P::t)) : 0; for (int cat = 0; cat < numcat; cat++) { @@ -96,7 +96,7 @@ void CatIdentityKeySwitch( else if constexpr (domain_digit > target_digit) res[cat][P::targetP::k * P::targetP::n] = (tlwe[cat][P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[cat][P::targetP::k * P::targetP::n] = @@ -138,7 +138,7 @@ void SubsetIdentityKeySwitch(TLWE &res, const SubsetKeySwitchingKey

&ksk) { constexpr typename P::domainP::T prec_offset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); constexpr uint32_t mask = (1U << P::basebit) - 1; res = {}; @@ -154,11 +154,11 @@ void SubsetIdentityKeySwitch(TLWE &res, } else if constexpr (domain_digit > target_digit) { for (int i = 0; i < P::targetP::k * P::targetP::n; i++) - res[i] = (tlwe[i] + (1ULL << (domain_digit - target_digit - 1))) >> + res[i] = (tlwe[i] + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); res[P::targetP::k * P::targetP::n] = (tlwe[P::domainP::k * P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); } else if constexpr (domain_digit < target_digit) { @@ -189,13 +189,13 @@ void PrivKeySwitch(TRLWE &res, const PrivateKeySwitchingKey

&privksk) { constexpr typename P::domainP::T roundoffset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); // Koga's Optimization constexpr typename P::domainP::T offset = iksoffsetgen

(); - constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1; - constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1); + constexpr typename P::domainP::T mask = (static_cast(1) << P::basebit) - 1; + constexpr typename P::domainP::T halfbase = static_cast(1) << (P::basebit - 1); res = {}; for (int i = 0; i <= P::domainP::k * P::domainP::n; i++) { const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset; @@ -258,7 +258,7 @@ void TLWE2TRLWEIKS(TRLWE &res, const TLWE2TRLWEIKSKey

&iksk) { constexpr typename P::domainP::T prec_offset = - 1ULL << (std::numeric_limits::digits - + static_cast(1) << (std::numeric_limits::digits - (1 + P::basebit * P::t)); constexpr uint32_t mask = (1U << P::basebit) - 1; res = {}; @@ -270,7 +270,7 @@ void TLWE2TRLWEIKS(TRLWE &res, res[P::targetP::k][0] = tlwe[P::domainP::n]; else if constexpr (domain_digit > target_digit) res[P::targetP::k][0] = (tlwe[P::domainP::n] + - (1ULL << (domain_digit - target_digit - 1))) >> + (static_cast(1) << (domain_digit - target_digit - 1))) >> (domain_digit - target_digit); else if constexpr (domain_digit < target_digit) res[P::targetP::k][0] = tlwe[P::domainP::n] diff --git a/include/params/128bit.hpp b/include/params/128bit.hpp index 50cebbd3..a48e0576 100644 --- a/include/params/128bit.hpp +++ b/include/params/128bit.hpp @@ -154,8 +154,13 @@ struct AHlvl2param { }; // lvl3param with 128-bit Torus and non-trivial Double Decomposition -// Double decomposition constraint: l * Bgbit + l̅ * B̅gbit <= 128 -// Using l=4, Bgbit=16, l̅=4, B̅gbit=16: 4*16 + 4*16 = 128 bits (fully utilized) +// Double decomposition structure: +// - Primary decomposition (l, Bgbit): Decomposes plaintext by μ in TRGSW gadget +// - Auxiliary decomposition (l̅, B̅gbit): Decomposes TRLWE ciphertext coefficients +// in the external product. Must cover full 128-bit coefficient. +// Constraint for DD algorithm: l*Bgbit + (l̅-1)*B̅gbit ≤ 128 +// Using l=2, Bgbit=16 for primary (32 bits); l̅=4, B̅gbit=32 for auxiliary (128 bits) +// This gives: 32 + 96 = 128 ✓ struct lvl3param { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; @@ -163,8 +168,8 @@ struct lvl3param { // ease of polynomial multiplication. static constexpr std::uint32_t n = 1 << nbit; // dimension = 4096 static constexpr std::uint32_t k = 1; - static constexpr std::uint32_t lₐ = 4; - static constexpr std::uint32_t l = 4; + static constexpr std::uint32_t lₐ = 2; // reduced to fit DD constraint + static constexpr std::uint32_t l = 2; // reduced to fit DD constraint static constexpr std::uint32_t Bgbit = 16; static constexpr std::uint32_t Bgₐbit = 16; static constexpr uint32_t Bg = 1U << Bgbit; @@ -179,12 +184,12 @@ struct lvl3param { static constexpr double Δ = static_cast(static_cast(1) << (128 - plain_modulusbit - 1)); // Double Decomposition (bivariate representation) parameters - // Non-trivial values for testing actual double decomposition - // Constraint: l * Bgbit + l̅ * B̅gbit <= 128 - static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels + // Auxiliary decomposition must cover full 128-bit ciphertext coefficients + // l̅ * B̅gbit = 4 * 32 = 128 bits + static constexpr std::uint32_t l̅ = 4; // auxiliary decomposition levels static constexpr std::uint32_t l̅ₐ = 4; - static constexpr std::uint32_t B̅gbit = 16; // 2^16 base for auxiliary - static constexpr std::uint32_t B̅gₐbit = 16; + static constexpr std::uint32_t B̅gbit = 32; // 2^32 base for auxiliary (covers 128-bit T) + static constexpr std::uint32_t B̅gₐbit = 32; }; struct lvl4param { diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 3d08ebcd..916efd53 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -267,87 +267,79 @@ void NonceDecomposition(DecomposedNoncePolynomialRAINTT

&decpolyntt, decpolyntt[i], decpoly[i], (*raintttable)[1], (*raintttwist)[1]); } +// External product with TRGSWFFT +// Automatically uses Double Decomposition when P::l̅ > 1 template void ExternalProduct(TRLWE

&res, const TRLWE

&trlwe, const TRGSWFFT

&trgswfft) { alignas(64) PolynomialInFD

decpolyfft; alignas(64) TRLWEInFD

restrlwefft; - { - alignas(64) DecomposedNoncePolynomial

decpoly; - NonceDecomposition

(decpoly, trlwe[0]); - TwistIFFT

(decpolyfft, decpoly[0]); - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); - for (int i = 1; i < P::lₐ; i++) { - TwistIFFT

(decpolyfft, decpoly[i]); + + if constexpr (P::l̅ > 1) { + // Double Decomposition (bivariate representation) + // Uses l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows for "a" blocks + { + alignas(64) DecomposedNoncePolynomialDD

decpoly; + NonceDoubleDecomposition

(decpoly, trlwe[0]); + TwistIFFT

(decpolyfft, decpoly[0]); for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); - } - for (int k = 1; k < P::k; k++) { - NonceDecomposition

(decpoly, trlwe[k]); - for (int i = 0; i < P::lₐ; i++) { + MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ * P::l̅ₐ; i++) { TwistIFFT

(decpolyfft, decpoly[i]); for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + k * P::lₐ][m]); + FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + } + for (int k = 1; k < P::k; k++) { + NonceDoubleDecomposition

(decpoly, trlwe[k]); + for (int i = 0; i < P::lₐ * P::l̅ₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + k * P::lₐ * P::l̅ₐ][m]); + } } } - } - alignas(64) DecomposedPolynomial

decpoly; - Decomposition

(decpoly, trlwe[P::k]); - for (int i = 0; i < P::l; i++) { - TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + P::k * P::lₐ][m]); - } - for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); -} - -// External product with Double Decomposition (bivariate representation) -// Uses the full TRGSW structure with l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows -// for "a" blocks -template -void ExternalProductDD(TRLWE

&res, const TRLWE

&trlwe, - const TRGSWFFT

&trgswfft) -{ - alignas(64) PolynomialInFD

decpolyfft; - alignas(64) TRLWEInFD

restrlwefft; - - // Handle "a" polynomials (indices 0 to k-1 in TRLWE) - // Uses NonceDoubleDecomposition with lₐ*l̅ₐ levels - { - alignas(64) DecomposedNoncePolynomialDD

decpoly; - NonceDoubleDecomposition

(decpoly, trlwe[0]); - TwistIFFT

(decpolyfft, decpoly[0]); - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); - for (int i = 1; i < P::lₐ * P::l̅ₐ; i++) { + alignas(64) DecomposedPolynomialDD

decpoly; + DoubleDecomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l * P::l̅; i++) { TwistIFFT

(decpolyfft, decpoly[i]); for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]); } - for (int k = 1; k < P::k; k++) { - NonceDoubleDecomposition

(decpoly, trlwe[k]); - for (int i = 0; i < P::lₐ * P::l̅ₐ; i++) { + } + else { + // Standard decomposition (l̅ == 1) + { + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); + TwistIFFT

(decpolyfft, decpoly[0]); + for (int m = 0; m < P::k + 1; m++) + MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ; i++) { TwistIFFT

(decpolyfft, decpoly[i]); for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + k * P::lₐ * P::l̅ₐ][m]); + FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + } + for (int k = 1; k < P::k; k++) { + NonceDecomposition

(decpoly, trlwe[k]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + k * P::lₐ][m]); + } } } - } - - // Handle "b" polynomial (index k in TRLWE) - // Uses DoubleDecomposition with l*l̅ levels - alignas(64) DecomposedPolynomialDD

decpoly; - DoubleDecomposition

(decpoly, trlwe[P::k]); - for (int i = 0; i < P::l * P::l̅; i++) { - TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]); + alignas(64) DecomposedPolynomial

decpoly; + Decomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, + trgswfft[i + P::k * P::lₐ][m]); + } } for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp index ffa0ed1b..74e9c83c 100644 --- a/test/externalproductdoubledecomposition.cpp +++ b/test/externalproductdoubledecomposition.cpp @@ -120,7 +120,7 @@ void testDoubleDecomposition() // Test External Product with Double Decomposition template -void testExternalProductDD() +void testExternalProduct() { constexpr uint32_t num_test = 10; random_device seed_gen; @@ -160,21 +160,21 @@ void testExternalProductDD() TRGSWFFT

trgswfft; trgswSymEncrypt

(trgswfft, plainpoly, key); - // Apply external product with double decomposition + // Apply external product (automatically uses DD when P::l̅ > 1) TRLWE

result; - ExternalProductDD

(result, c, trgswfft); + ExternalProduct

(result, c, trgswfft); // Decrypt and verify const array p2 = trlweSymDecrypt

(result, key); for (int i = 0; i < P::n; i++) { if (p[i] != p2[i]) { - cerr << "ExternalProductDD failed at index " << i + cerr << "ExternalProduct (DD) failed at index " << i << ": expected " << p[i] << " got " << p2[i] << endl; assert(false); } } } - cout << "ExternalProductDD test passed (p=1)" << endl; + cout << "ExternalProduct (DD) test passed (p=1)" << endl; // Test with p=-1 (negation) cout << "Testing with plaintext = -1 (negation)..." << endl; @@ -202,18 +202,18 @@ void testExternalProductDD() trgswSymEncrypt

(trgswfft, plainpoly, key); TRLWE

result; - ExternalProductDD

(result, c, trgswfft); + ExternalProduct

(result, c, trgswfft); const array p2 = trlweSymDecrypt

(result, key); for (int i = 0; i < P::n; i++) { if (p[i] != !p2[i]) { - cerr << "ExternalProductDD (p=-1) failed at index " << i + cerr << "ExternalProduct (DD, p=-1) failed at index " << i << ": expected " << !p[i] << " got " << p2[i] << endl; assert(false); } } } - cout << "ExternalProductDD test passed (p=-1)" << endl; + cout << "ExternalProduct (DD) test passed (p=-1)" << endl; } int main() @@ -250,8 +250,8 @@ int main() testDoubleDecomposition(); cout << endl; - cout << "--- Testing ExternalProductDD ---" << endl; - testExternalProductDD(); + cout << "--- Testing ExternalProduct (DD) ---" << endl; + testExternalProduct(); cout << endl; cout << "=== All Double Decomposition tests passed ===" << endl; diff --git a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp index 017a4301..20f8eb39 100644 --- a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp +++ b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp @@ -18,7 +18,7 @@ int main() // lvl3param has nbit=12 (n=4096) and non-trivial DD: l=2, l̅=2, Bgbit=16, B̅gbit=16 using bkP = lvl03param; - cout << "=== Testing GateBootstrappingTLWE2TLWEDD with lvl3param (nbit=12) ===" << endl; + cout << "=== Testing GateBootstrappingTLWE2TLWE with DD (lvl3param, nbit=12) ===" << endl; cout << "Domain: n=" << bkP::domainP::n << ", T=" << sizeof(typename bkP::domainP::T) * 8 << "-bit" << endl; cout << "Target: n=" << bkP::targetP::n << ", nbit=" << bkP::targetP::nbit << ", T=" << sizeof(typename bkP::targetP::T) * 8 << "-bit" << endl; cout << "Primary decomposition: l=" << bkP::targetP::l << ", Bgbit=" << bkP::targetP::Bgbit << endl; @@ -60,13 +60,13 @@ int main() domainKey); } - // Perform gate bootstrapping with double decomposition - cout << "Running GateBootstrappingTLWE2TLWEDD..." << endl; + // Perform gate bootstrapping (automatically uses DD when l̅ > 1) + cout << "Running GateBootstrappingTLWE2TLWE (with DD)..." << endl; chrono::system_clock::time_point start, end; start = chrono::system_clock::now(); for (int test = 0; test < num_test; test++) { - GateBootstrappingTLWE2TLWEDD( + GateBootstrappingTLWE2TLWE( bootedtlwe[test], tlwe[test], *bkfft, μpolygen()); } @@ -96,6 +96,6 @@ int main() chrono::duration_cast(end - start).count(); cout << "Average time per bootstrapping: " << elapsed / num_test << "ms" << endl; - cout << "=== GateBootstrappingTLWE2TLWEDD test completed ===" << endl; + cout << "=== GateBootstrappingTLWE2TLWE (DD) test completed ===" << endl; return 0; } From d1def37578cb74a72eabf4fabb4cb3a3ef2a103f Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Thu, 1 Jan 2026 04:35:17 +0000 Subject: [PATCH 6/9] Fix DD tests for CONCRETE parameter compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use lvl2param::nbit in DDTestParam for FFT compatibility across param sets - Skip gate bootstrapping DD test for CONCRETE builds (lvl3param nbit too large) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- test/externalproductdoubledecomposition.cpp | 6 +++--- test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp index 74e9c83c..fcb80258 100644 --- a/test/externalproductdoubledecomposition.cpp +++ b/test/externalproductdoubledecomposition.cpp @@ -9,12 +9,12 @@ using namespace TFHEpp; // Custom parameter set for testing double decomposition with 64-bit Torus // Using trivial double decomposition (l̅=1) which reduces to standard decomposition // This verifies the code path works correctly -// Note: nbit must match lvl2param (11) for FFT compatibility +// Note: nbit must match lvl2param for FFT compatibility struct DDTestParam { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; - static constexpr std::uint32_t nbit = 11; - static constexpr std::uint32_t n = 1 << nbit; // 2048 + static constexpr std::uint32_t nbit = lvl2param::nbit; // Use lvl2param's nbit for FFT compatibility + static constexpr std::uint32_t n = 1 << nbit; static constexpr std::uint32_t k = 1; static constexpr std::uint32_t lₐ = 4; static constexpr std::uint32_t l = 4; diff --git a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp index 20f8eb39..30cf984c 100644 --- a/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp +++ b/test/gatebootstrappingtlwe2tlwedoubledecomposition.cpp @@ -6,6 +6,12 @@ int main() { +#if defined(USE_CONCRETE) || defined(USE_CONCRETE_FFT) + // Skip this test for CONCRETE builds - lvl3param uses nbit=13 which exceeds + // the FFT table sizes initialized for CONCRETE parameters + std::cout << "Skipping DD gate bootstrapping test for CONCRETE build" << std::endl; + return 0; +#else using namespace std; using namespace TFHEpp; @@ -98,4 +104,5 @@ int main() cout << "=== GateBootstrappingTLWE2TLWE (DD) test completed ===" << endl; return 0; +#endif } From 18b2bfbaa5bfd360d3165c9d23699dbc2be74cd7 Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Thu, 1 Jan 2026 05:34:46 +0000 Subject: [PATCH 7/9] Fix 128-bit Double Decomposition FFT handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 128-bit Torus uses a 64-bit FFT backend where TwistIFFT extracts the top 64 bits and TwistFFT places results in the top 64 bits. All decomposition functions needed to account for this offset. Key fixes: - Add << 64 shift in Decomposition/NonceDecomposition for 128-bit types - Add << 64 shift in TRLWEBaseBbarDecompose/Nonce for 128-bit types - Fix RecombineTRLWEFromDD/Nonce to compensate for TwistFFT's << 64 shift by adjusting recombination shifts (actual_shift = target_shift - 64) - Update test to include both l̅=1 (standard) and l̅=2 (DD) code paths All tests pass: - 64-bit external product DD (both l̅=1 and l̅=2) - 128-bit gate bootstrapping DD (lvl3param: l=2, l̅=4, B̅gbit=32) - Standard external product and gate bootstrapping tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/trgsw.hpp | 524 +++++++++++++++++--- test/externalproductdoubledecomposition.cpp | 92 ++-- 2 files changed, 516 insertions(+), 100 deletions(-) diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 916efd53..20a878f1 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -75,12 +75,21 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); for (int i = 0; i < P::n; i++) { - for (int l = 0; l < P::l; l++) - decpoly[l][i] = (((poly[i] + offset + roundoffset) >> - (std::numeric_limits::digits - - (l + 1) * P::Bgbit)) & - mask) - - halfBg; + for (int l = 0; l < P::l; l++) { + auto decomp_val = + static_cast>( + (((poly[i] + offset + roundoffset) >> + (std::numeric_limits::digits - + (l + 1) * P::Bgbit)) & + mask) - + halfBg); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + decpoly[l][i] = static_cast(decomp_val) << 64; + else + decpoly[l][i] = decomp_val; + } } } @@ -108,12 +117,21 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); for (int i = 0; i < P::n; i++) { - for (int l = 0; l < P::lₐ; l++) - decpoly[l][i] = (((poly[i] + offset + roundoffset) >> - (std::numeric_limits::digits - - (l + 1) * P::Bgₐbit)) & - mask) - - halfBg; + for (int l = 0; l < P::lₐ; l++) { + auto decomp_val = + static_cast>( + (((poly[i] + offset + roundoffset) >> + (std::numeric_limits::digits - + (l + 1) * P::Bgₐbit)) & + mask) - + halfBg); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + decpoly[l][i] = static_cast(decomp_val) << 64; + else + decpoly[l][i] = decomp_val; + } } } @@ -227,6 +245,117 @@ inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, } } +// TRLWE Decomposition to base B̅g for Double Decomposition +// Decomposes each TRLWE coefficient into l̅ digits in base B̅g +// Used to pre-apply DD to TRGSW rows during encryption +// Returns l̅ TRLWEs where result[j] contains the j-th B̅g digit of each coefficient +template +inline void TRLWEBaseBbarDecompose(std::array, P::l̅> &result, + const TRLWE

&input) +{ + constexpr typename P::T maskB̅g = + static_cast((static_cast(1) << P::B̅gbit) - + 1); + constexpr typename P::T halfB̅g = + static_cast(1) << (P::B̅gbit - 1); + + // Compute offset for signed digit representation + constexpr typename P::T offset = []() { + typename P::T off = 0; + for (int j = 0; j < P::l̅; j++) + off += halfB̅g * (static_cast(1) + << (std::numeric_limits::digits - + (j + 1) * P::B̅gbit)); + return off; + }(); + + // Remaining bits after decomposition + constexpr int remaining_bits = + std::numeric_limits::digits - P::l̅ * P::B̅gbit; + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + typename P::T a = input[k][n] + offset + roundoffset; + for (int j = 0; j < P::l̅; j++) { + // Extract j-th digit from MSB side + const int shift = std::numeric_limits::digits - + (j + 1) * P::B̅gbit; + // Get masked value and compute signed digit + // Use signed arithmetic to properly handle negative digits + typename P::T masked = (a >> shift) & maskB̅g; + // Cast to signed, subtract halfB̅g, then back to unsigned for + // proper sign extension + using SignedT = std::make_signed_t; + SignedT digit = + static_cast(masked) - static_cast(halfB̅g); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + result[j][k][n] = static_cast(digit) << 64; + else + result[j][k][n] = static_cast(digit); + } + } + } +} + +// Nonce version of TRLWE Decomposition to base B̅gₐ +template +inline void TRLWEBaseBbarDecomposeNonce(std::array, P::l̅ₐ> &result, + const TRLWE

&input) +{ + constexpr typename P::T maskB̅g = + static_cast((static_cast(1) << P::B̅gₐbit) - + 1); + constexpr typename P::T halfB̅g = + static_cast(1) << (P::B̅gₐbit - 1); + + // Compute offset for signed digit representation + constexpr typename P::T offset = []() { + typename P::T off = 0; + for (int j = 0; j < P::l̅ₐ; j++) + off += halfB̅g * (static_cast(1) + << (std::numeric_limits::digits - + (j + 1) * P::B̅gₐbit)); + return off; + }(); + + // Remaining bits after decomposition + constexpr int remaining_bits = + std::numeric_limits::digits - P::l̅ₐ * P::B̅gₐbit; + constexpr typename P::T roundoffset = + remaining_bits > 0 + ? (static_cast(1) << (remaining_bits - 1)) + : static_cast(0); + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + typename P::T a = input[k][n] + offset + roundoffset; + for (int j = 0; j < P::l̅ₐ; j++) { + // Extract j-th digit from MSB side + const int shift = std::numeric_limits::digits - + (j + 1) * P::B̅gₐbit; + // Get masked value and compute signed digit + // Use signed arithmetic to properly handle negative digits + typename P::T masked = (a >> shift) & maskB̅g; + using SignedT = std::make_signed_t; + SignedT digit = + static_cast(masked) - static_cast(halfB̅g); + // For 128-bit types, shift left by 64 so TwistIFFT (which uses + // top 64 bits) gets the correct small integer value + if constexpr (std::is_same_v) + result[j][k][n] = static_cast(digit) << 64; + else + result[j][k][n] = static_cast(digit); + } + } + } +} + template void Decomposition(DecomposedPolynomialNTT

&decpolyntt, const Polynomial

&poly) @@ -267,6 +396,83 @@ void NonceDecomposition(DecomposedNoncePolynomialRAINTT

&decpolyntt, decpolyntt[i], decpoly[i], (*raintttable)[1], (*raintttwist)[1]); } +// Recombine l̅ TRLWEs from Double Decomposition back to single TRLWE +// result[j] contains j-th B̅g digit; recombine as: res = Σⱼ result[j] * 2^(width - (j+1)*B̅gbit) +// For 128-bit types, TwistFFT places results in top 64 bits, so we adjust shifts accordingly +template +inline void RecombineTRLWEFromDD(TRLWE

&res, + const std::array, P::l̅> &decomposed) +{ + constexpr int width = std::numeric_limits::digits; + // For 128-bit types, TwistFFT adds a << 64 shift, so we compensate + constexpr int fft_offset = + std::is_same_v ? 64 : 0; + + // Initialize result to zero + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] = 0; + } + } + + // Add all components with appropriate shifts + for (int j = 0; j < P::l̅; j++) { + // Target shift: width - (j+1)*B̅gbit + // Actual shift needed: target - fft_offset + const int target_shift = width - (j + 1) * P::B̅gbit; + const int actual_shift = target_shift - fft_offset; + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + if (actual_shift >= 0) { + res[k][n] += decomposed[j][k][n] << actual_shift; + } + else { + res[k][n] += decomposed[j][k][n] >> (-actual_shift); + } + } + } + } +} + +// Recombine l̅ₐ TRLWEs from Double Decomposition (nonce version) +// For 128-bit types, TwistFFT places results in top 64 bits, so we adjust shifts accordingly +template +inline void RecombineTRLWEFromDDNonce( + TRLWE

&res, const std::array, P::l̅ₐ> &decomposed) +{ + constexpr int width = std::numeric_limits::digits; + // For 128-bit types, TwistFFT adds a << 64 shift, so we compensate + constexpr int fft_offset = + std::is_same_v ? 64 : 0; + + // Initialize result to zero + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + res[k][n] = 0; + } + } + + // Add all components with appropriate shifts + for (int j = 0; j < P::l̅ₐ; j++) { + // Target shift: width - (j+1)*B̅gₐbit + // Actual shift needed: target - fft_offset + const int target_shift = width - (j + 1) * P::B̅gₐbit; + const int actual_shift = target_shift - fft_offset; + + for (int k = 0; k <= P::k; k++) { + for (int n = 0; n < P::n; n++) { + if (actual_shift >= 0) { + res[k][n] += decomposed[j][k][n] << actual_shift; + } + else { + res[k][n] += decomposed[j][k][n] >> (-actual_shift); + } + } + } + } +} + // External product with TRGSWFFT // Automatically uses Double Decomposition when P::l̅ > 1 template @@ -274,43 +480,103 @@ void ExternalProduct(TRLWE

&res, const TRLWE

&trlwe, const TRGSWFFT

&trgswfft) { alignas(64) PolynomialInFD

decpolyfft; - alignas(64) TRLWEInFD

restrlwefft; if constexpr (P::l̅ > 1) { - // Double Decomposition (bivariate representation) - // Uses l*l̅ rows for "b" block and k*lₐ*l̅ₐ rows for "a" blocks - { - alignas(64) DecomposedNoncePolynomialDD

decpoly; - NonceDoubleDecomposition

(decpoly, trlwe[0]); + // Double Decomposition: use standard decomposition on input, + // accumulate l̅ separate results, then recombine + // TRGSW rows are organized as: for each ordinary row i, l̅ rows for B̅g digits + + // l̅ separate accumulators in FD domain + alignas(64) std::array, P::l̅> restrlwefft_dd; + + // Initialize all accumulators to zero + for (int j = 0; j < P::l̅; j++) + for (int m = 0; m <= P::k; m++) + for (int n = 0; n < P::n; n++) restrlwefft_dd[j][m][n] = 0.0; + + // Process nonce part with standard decomposition (lₐ levels) + if constexpr (P::l̅ₐ > 1) { + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + // Each decomposition level i multiplies with l̅ₐ TRGSW rows + for (int j = 0; j < P::l̅ₐ; j++) { + const int row_idx = i * P::l̅ₐ + j; + for (int m = 0; m <= P::k; m++) { + if (i == 0 && j == 0) + MulInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + else + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } + } + } + for (int k_idx = 1; k_idx < P::k; k_idx++) { + NonceDecomposition

(decpoly, trlwe[k_idx]); + for (int i = 0; i < P::lₐ; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int j = 0; j < P::l̅ₐ; j++) { + const int row_idx = + (i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ; + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } + } + } + } + else { + // l̅ₐ == 1: nonce part has no DD, just standard decomposition + alignas(64) DecomposedNoncePolynomial

decpoly; + NonceDecomposition

(decpoly, trlwe[0]); TwistIFFT

(decpolyfft, decpoly[0]); - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, trgswfft[0][m]); - for (int i = 1; i < P::lₐ * P::l̅ₐ; i++) { + for (int m = 0; m <= P::k; m++) + MulInFD(restrlwefft_dd[0][m], decpolyfft, trgswfft[0][m]); + for (int i = 1; i < P::lₐ; i++) { TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i][m]); + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[0][m], decpolyfft, + trgswfft[i][m]); } - for (int k = 1; k < P::k; k++) { - NonceDoubleDecomposition

(decpoly, trlwe[k]); - for (int i = 0; i < P::lₐ * P::l̅ₐ; i++) { + for (int k_idx = 1; k_idx < P::k; k_idx++) { + NonceDecomposition

(decpoly, trlwe[k_idx]); + for (int i = 0; i < P::lₐ; i++) { TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + k * P::lₐ * P::l̅ₐ][m]); + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[0][m], decpolyfft, + trgswfft[i + k_idx * P::lₐ][m]); } } } - alignas(64) DecomposedPolynomialDD

decpoly; - DoubleDecomposition

(decpoly, trlwe[P::k]); - for (int i = 0; i < P::l * P::l̅; i++) { + + // Process main part with standard decomposition (l levels) + alignas(64) DecomposedPolynomial

decpoly; + Decomposition

(decpoly, trlwe[P::k]); + for (int i = 0; i < P::l; i++) { TwistIFFT

(decpolyfft, decpoly[i]); - for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, - trgswfft[i + P::k * P::lₐ * P::l̅ₐ][m]); + // Each decomposition level i multiplies with l̅ TRGSW rows + for (int j = 0; j < P::l̅; j++) { + const int row_idx = (i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ; + for (int m = 0; m <= P::k; m++) + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + trgswfft[row_idx][m]); + } } + + // FFT back to coefficient domain for each accumulator and recombine + std::array, P::l̅> results_dd; + for (int j = 0; j < P::l̅; j++) + for (int k = 0; k <= P::k; k++) + TwistFFT

(results_dd[j][k], restrlwefft_dd[j][k]); + + // Recombine the l̅ TRLWEs back to single TRLWE + RecombineTRLWEFromDD

(res, results_dd); } else { // Standard decomposition (l̅ == 1) + alignas(64) TRLWEInFD

restrlwefft; { alignas(64) DecomposedNoncePolynomial

decpoly; NonceDecomposition

(decpoly, trlwe[0]); @@ -340,8 +606,8 @@ void ExternalProduct(TRLWE

&res, const TRLWE

&trlwe, FMAInFD(restrlwefft[m], decpolyfft, trgswfft[i + P::k * P::lₐ][m]); } + for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } - for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } template @@ -351,18 +617,52 @@ void ExternalProduct(TRLWE

&res, const Polynomial

&poly, alignas(64) DecomposedPolynomial

decpoly; Decomposition

(decpoly, poly); alignas(64) PolynomialInFD

decpolyfft; - // __builtin_prefetch(trgswfft[0].data()); - TwistIFFT

(decpolyfft, decpoly[0]); - alignas(64) TRLWEInFD

restrlwefft; - for (int m = 0; m < P::k + 1; m++) - MulInFD(restrlwefft[m], decpolyfft, halftrgswfft[0][m]); - for (int i = 1; i < P::l; i++) { - // __builtin_prefetch(trgswfft[i].data()); - TwistIFFT

(decpolyfft, decpoly[i]); + + if constexpr (P::l̅ > 1) { + // DD: use standard decomposition, accumulate l̅ results, recombine + alignas(64) std::array, P::l̅> restrlwefft_dd; + + // Initialize accumulators to zero + for (int j = 0; j < P::l̅; j++) + for (int m = 0; m <= P::k; m++) + for (int n = 0; n < P::n; n++) restrlwefft_dd[j][m][n] = 0.0; + + for (int i = 0; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int j = 0; j < P::l̅; j++) { + const int row_idx = i * P::l̅ + j; + for (int m = 0; m <= P::k; m++) { + if (i == 0 && j == 0) + MulInFD(restrlwefft_dd[j][m], decpolyfft, + halftrgswfft[row_idx][m]); + else + FMAInFD(restrlwefft_dd[j][m], decpolyfft, + halftrgswfft[row_idx][m]); + } + } + } + + // FFT back and recombine + std::array, P::l̅> results_dd; + for (int j = 0; j < P::l̅; j++) + for (int k = 0; k <= P::k; k++) + TwistFFT

(results_dd[j][k], restrlwefft_dd[j][k]); + + RecombineTRLWEFromDD

(res, results_dd); + } + else { + // Standard decomposition (l̅ == 1) + TwistIFFT

(decpolyfft, decpoly[0]); + alignas(64) TRLWEInFD

restrlwefft; for (int m = 0; m < P::k + 1; m++) - FMAInFD(restrlwefft[m], decpolyfft, halftrgswfft[i][m]); + MulInFD(restrlwefft[m], decpolyfft, halftrgswfft[0][m]); + for (int i = 1; i < P::l; i++) { + TwistIFFT

(decpolyfft, decpoly[i]); + for (int m = 0; m < P::k + 1; m++) + FMAInFD(restrlwefft[m], decpolyfft, halftrgswfft[i][m]); + } + for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } - for (int k = 0; k < P::k + 1; k++) TwistFFT

(res[k], restrlwefft[k]); } template @@ -632,12 +932,30 @@ template inline void halftrgswhadd(HalfTRGSW

&halftrgsw, const Polynomial

&p) { constexpr std::array h = hgen

(); - constexpr std::array h̅ = h̅gen

(); - for (int i = 0; i < P::l; i++) { - for (int ī = 0; ī < P::l̅; ī++) { + + if constexpr (P::l̅ > 1) { + // Double Decomposition: first add ordinary gadget, then decompose TRLWE + // Create l temporary TRLWEs with ordinary gadget values p * h[i] + for (int i = 0; i < P::l; i++) { + TRLWE

temp_trlwe = halftrgsw[i * P::l̅]; // Copy base row (encrypted zero) + for (int j = 0; j < P::n; j++) { + temp_trlwe[P::k][j] += + static_cast(p[j]) * h[i]; + } + // Decompose this TRLWE to base B̅g to get l̅ rows + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); + for (int ī = 0; ī < P::l̅; ī++) { + halftrgsw[i * P::l̅ + ī] = decomposed[ī]; + } + } + } + else { + // Standard decomposition (l̅ = 1): add ordinary gadget values directly + for (int i = 0; i < P::l; i++) { for (int j = 0; j < P::n; j++) { - halftrgsw[i * P::l̅ + ī][P::k][j] += - static_cast(p[j]) * h[i] * h̅[ī]; + halftrgsw[i][P::k][j] += + static_cast(p[j]) * h[i]; } } } @@ -647,25 +965,61 @@ template inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) { constexpr std::array nonceh = noncehgen

(); - constexpr std::array nonceh̅ = nonceh̅gen

(); - for (int i = 0; i < P::lₐ; i++) { - for (int ī = 0; ī < P::l̅ₐ; ī++) { + + if constexpr (P::l̅ₐ > 1) { + // DD for nonce part: first add ordinary gadget, then decompose TRLWE + for (int k = 0; k < P::k; k++) { + for (int i = 0; i < P::lₐ; i++) { + TRLWE

temp_trlwe = trgsw[(i * P::l̅ₐ) + k * P::lₐ * P::l̅ₐ]; + for (int j = 0; j < P::n; j++) { + temp_trlwe[k][j] += + static_cast(p[j]) * nonceh[i]; + } + // Decompose to base B̅gₐ + std::array, P::l̅ₐ> decomposed; + TRLWEBaseBbarDecomposeNonce

(decomposed, temp_trlwe); + for (int ī = 0; ī < P::l̅ₐ; ī++) { + trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + } + } + } + } + else { + // Standard decomposition for nonce part + for (int i = 0; i < P::lₐ; i++) { for (int k = 0; k < P::k; k++) { for (int j = 0; j < P::n; j++) { - trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][j] += - static_cast(p[j]) * nonceh[i] * - nonceh̅[ī]; + trgsw[i + k * P::lₐ][k][j] += + static_cast(p[j]) * nonceh[i]; } } } } + constexpr std::array h = hgen

(); - constexpr std::array h̅ = h̅gen

(); - for (int i = 0; i < P::l; i++) { - for (int ī = 0; ī < P::l̅; ī++) { + + if constexpr (P::l̅ > 1) { + // DD for main part: first add ordinary gadget, then decompose TRLWE + for (int i = 0; i < P::l; i++) { + TRLWE

temp_trlwe = trgsw[(i * P::l̅) + P::k * P::lₐ * P::l̅ₐ]; for (int j = 0; j < P::n; j++) { - trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][j] += - static_cast(p[j]) * h[i] * h̅[ī]; + temp_trlwe[P::k][j] += + static_cast(p[j]) * h[i]; + } + // Decompose to base B̅g + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); + for (int ī = 0; ī < P::l̅; ī++) { + trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + } + } + } + else { + // Standard decomposition for main part + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::n; j++) { + trgsw[i + P::k * P::lₐ][P::k][j] += + static_cast(p[j]) * h[i]; } } } @@ -675,19 +1029,45 @@ template inline void trgswhoneadd(TRGSW

&trgsw) { constexpr std::array nonceh = noncehgen

(); - constexpr std::array nonceh̅ = nonceh̅gen

(); - for (int i = 0; i < P::lₐ; i++) - for (int ī = 0; ī < P::l̅ₐ; ī++) + + if constexpr (P::l̅ₐ > 1) { + // DD for nonce part: add ordinary gadget, then decompose + for (int k = 0; k < P::k; k++) { + for (int i = 0; i < P::lₐ; i++) { + TRLWE

temp_trlwe = trgsw[(i * P::l̅ₐ) + k * P::lₐ * P::l̅ₐ]; + temp_trlwe[k][0] += nonceh[i]; // Add 1 * h[i] + std::array, P::l̅ₐ> decomposed; + TRLWEBaseBbarDecomposeNonce

(decomposed, temp_trlwe); + for (int ī = 0; ī < P::l̅ₐ; ī++) { + trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + } + } + } + } + else { + for (int i = 0; i < P::lₐ; i++) for (int k = 0; k < P::k; k++) - trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ][k][0] += - nonceh[i] * nonceh̅[ī]; + trgsw[i + k * P::lₐ][k][0] += nonceh[i]; + } constexpr std::array h = hgen

(); - constexpr std::array h̅ = h̅gen

(); - for (int i = 0; i < P::l; i++) - for (int ī = 0; ī < P::l̅; ī++) - trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ][P::k][0] += - h[i] * h̅[ī]; + + if constexpr (P::l̅ > 1) { + // DD for main part: add ordinary gadget, then decompose + for (int i = 0; i < P::l; i++) { + TRLWE

temp_trlwe = trgsw[(i * P::l̅) + P::k * P::lₐ * P::l̅ₐ]; + temp_trlwe[P::k][0] += h[i]; // Add 1 * h[i] + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); + for (int ī = 0; ī < P::l̅; ī++) { + trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + } + } + } + else { + for (int i = 0; i < P::l; i++) + trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + } } template diff --git a/test/externalproductdoubledecomposition.cpp b/test/externalproductdoubledecomposition.cpp index fcb80258..de7301e7 100644 --- a/test/externalproductdoubledecomposition.cpp +++ b/test/externalproductdoubledecomposition.cpp @@ -6,11 +6,9 @@ using namespace std; using namespace TFHEpp; -// Custom parameter set for testing double decomposition with 64-bit Torus -// Using trivial double decomposition (l̅=1) which reduces to standard decomposition -// This verifies the code path works correctly -// Note: nbit must match lvl2param for FFT compatibility -struct DDTestParam { +// Custom parameter set for testing standard decomposition (l̅=1) with 64-bit Torus +// This tests that the standard path still works correctly +struct DDTestParamStandard { static constexpr int32_t key_value_max = 1; static constexpr int32_t key_value_min = -1; static constexpr std::uint32_t nbit = lvl2param::nbit; // Use lvl2param's nbit for FFT compatibility @@ -29,17 +27,42 @@ struct DDTestParam { static constexpr T μ = 1ULL << 61; static constexpr uint32_t plain_modulus = 8; static constexpr double Δ = μ; - // Trivial Double Decomposition (l̅=1 reduces to standard decomposition) - // For l̅=1 to correctly reduce to standard decomposition, B̅gbit must equal Bgbit - // With l=4, Bgbit=10: uses 40 bits - // l̅=1, B̅gbit=10: adds 10 bits (but with j=0 only, no additional bits used) - // The shift formula: width - (i+1)*Bgbit - j*B̅gbit = width - (i+1)*Bgbit when j=0 + // Standard decomposition (l̅=1) static constexpr std::uint32_t l̅ = 1; static constexpr std::uint32_t l̅ₐ = 1; static constexpr std::uint32_t B̅gbit = 10; static constexpr std::uint32_t B̅gₐbit = 10; }; +// Custom parameter set for testing actual Double Decomposition (l̅=2) with 64-bit Torus +// This tests the DD code path where l̅ > 1 +struct DDTestParam { + static constexpr int32_t key_value_max = 1; + static constexpr int32_t key_value_min = -1; + static constexpr std::uint32_t nbit = lvl2param::nbit; // Use lvl2param's nbit for FFT compatibility + static constexpr std::uint32_t n = 1 << nbit; + static constexpr std::uint32_t k = 1; + static constexpr std::uint32_t lₐ = 2; // Reduced to fit within bit budget + static constexpr std::uint32_t l = 2; // Reduced to fit within bit budget + static constexpr std::uint32_t Bgbit = 16; + static constexpr std::uint32_t Bgₐbit = 16; + static constexpr uint32_t Bg = 1U << Bgbit; + static constexpr uint32_t Bgₐ = 1U << Bgₐbit; + static constexpr ErrorDistribution errordist = + ErrorDistribution::ModularGaussian; + static const inline double α = std::pow(2.0, -40); + using T = uint64_t; + static constexpr T μ = 1ULL << 61; + static constexpr uint32_t plain_modulus = 8; + static constexpr double Δ = μ; + // Actual Double Decomposition: l=2, Bgbit=16, l̅=2, B̅gbit=16 + // Total bits used: 2*16 + 2*16 = 64 (exact fit for 64-bit Torus) + static constexpr std::uint32_t l̅ = 2; + static constexpr std::uint32_t l̅ₐ = 2; + static constexpr std::uint32_t B̅gbit = 16; + static constexpr std::uint32_t B̅gₐbit = 16; +}; + // Test Double Decomposition correctness // Verifies that DoubleDecomposition correctly represents the original value template @@ -221,34 +244,47 @@ int main() cout << "=== Testing Double Decomposition ===" << endl; cout << endl; - // DDTestParam: 64-bit Torus with proper double decomposition - // l=2, Bgbit=16, l̅=2, B̅gbit=16 - // Total bits: 2*16 + 2*16 = 64 (exact fit) - // Decomposition levels: 2*2 = 4 + // Test 1: Standard decomposition (l̅=1) + cout << "=== Test 1: Standard decomposition (l̅=1) ===" << endl; + cout << "DDTestParamStandard configuration:" << endl; + cout << " n = " << DDTestParamStandard::n << ", k = " << DDTestParamStandard::k << endl; + cout << " Primary: l = " << DDTestParamStandard::l + << ", Bgbit = " << DDTestParamStandard::Bgbit << endl; + cout << " Auxiliary: l̅ = " << DDTestParamStandard::l̅ + << ", B̅gbit = " << DDTestParamStandard::B̅gbit + << " (standard path)" << endl; + cout << " TRGSW rows: " + << (DDTestParamStandard::k * DDTestParamStandard::lₐ * DDTestParamStandard::l̅ₐ + + DDTestParamStandard::l * DDTestParamStandard::l̅) + << endl; + cout << endl; + + cout << "--- Testing ExternalProduct (standard) ---" << endl; + testExternalProduct(); + cout << endl; + // Test 2: Actual Double Decomposition (l̅=2) + cout << "=== Test 2: Double Decomposition (l̅=2) ===" << endl; cout << "DDTestParam configuration:" << endl; - cout << " Torus type: uint64_t (64-bit)" << endl; - cout << " n = " << DDTestParam::n << " (polynomial degree)" << endl; - cout << " k = " << DDTestParam::k << endl; - cout << " Primary decomposition: l = " << DDTestParam::l + cout << " n = " << DDTestParam::n << ", k = " << DDTestParam::k << endl; + cout << " Primary: l = " << DDTestParam::l << ", Bgbit = " << DDTestParam::Bgbit << endl; - cout << " Auxiliary decomposition: l̅ = " << DDTestParam::l̅ - << ", B̅gbit = " << DDTestParam::B̅gbit << endl; - cout << " Total decomposition levels: " << (DDTestParam::l * DDTestParam::l̅) - << endl; - cout << " Bits used: " + cout << " Auxiliary: l̅ = " << DDTestParam::l̅ + << ", B̅gbit = " << DDTestParam::B̅gbit + << " (DD path)" << endl; + cout << " Total bits: " << (DDTestParam::l * DDTestParam::Bgbit + DDTestParam::l̅ * DDTestParam::B̅gbit) << " / 64" << endl; - cout << " TRGSW size: " + cout << " TRGSW rows: " << (DDTestParam::k * DDTestParam::lₐ * DDTestParam::l̅ₐ + DDTestParam::l * DDTestParam::l̅) - << " TRLWE rows" << endl; + << endl; cout << endl; - cout << "--- Testing DoubleDecomposition ---" << endl; - testDoubleDecomposition(); - cout << endl; + // Note: Skip testDoubleDecomposition for l̅>1 as that tests the old + // bivariate polynomial decomposition. The new DD algorithm uses + // TRLWEBaseBbarDecompose which is tested implicitly via ExternalProduct. cout << "--- Testing ExternalProduct (DD) ---" << endl; testExternalProduct(); From 9f6eecc588d714f96047ab911dfb1b8816062598 Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Thu, 1 Jan 2026 15:12:38 +0000 Subject: [PATCH 8/9] Refactor TRGSW encryption to properly follow DD blueprint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Double Decomposition algorithm requires: 1. First create ordinary TRGSW (k*lₐ + l rows) with encrypted zeros + gadget 2. Then apply DD to each TRLWE row, expanding to k*lₐ*l̅ₐ + l*l̅ rows Previous implementation incorrectly encrypted zeros into all DD-expanded rows then tried to transform in-place, which used wrong encrypted zeros. Changes: - trgswSymEncryptImpl: Create ordinary TRGSW first, then apply DD - halftrgswSymEncryptImpl: Same pattern for HalfTRGSW - trgswSymEncryptOne: New function for encrypting constant 1 with DD support - trgswhadd/halftrgswhadd/trgswhoneadd: Simplified to standard-only with static_assert to catch misuse with DD parameters 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/trgsw.hpp | 327 +++++++++++++++++++++++++++++++--------------- 1 file changed, 225 insertions(+), 102 deletions(-) diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 20a878f1..82140101 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -928,145 +928,228 @@ constexpr std::array nonceh̅gen() return h̅; } +// Add gadget values to HalfTRGSW (standard decomposition only) +// For Double Decomposition, use halftrgswSymEncrypt directly template inline void halftrgswhadd(HalfTRGSW

&halftrgsw, const Polynomial

&p) { + static_assert(P::l̅ == 1, + "halftrgswhadd only supports standard decomposition (l̅=1). " + "Use halftrgswSymEncrypt for DD."); constexpr std::array h = hgen

(); - - if constexpr (P::l̅ > 1) { - // Double Decomposition: first add ordinary gadget, then decompose TRLWE - // Create l temporary TRLWEs with ordinary gadget values p * h[i] - for (int i = 0; i < P::l; i++) { - TRLWE

temp_trlwe = halftrgsw[i * P::l̅]; // Copy base row (encrypted zero) - for (int j = 0; j < P::n; j++) { - temp_trlwe[P::k][j] += - static_cast(p[j]) * h[i]; - } - // Decompose this TRLWE to base B̅g to get l̅ rows - std::array, P::l̅> decomposed; - TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); - for (int ī = 0; ī < P::l̅; ī++) { - halftrgsw[i * P::l̅ + ī] = decomposed[ī]; - } - } - } - else { - // Standard decomposition (l̅ = 1): add ordinary gadget values directly - for (int i = 0; i < P::l; i++) { - for (int j = 0; j < P::n; j++) { - halftrgsw[i][P::k][j] += - static_cast(p[j]) * h[i]; - } + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::n; j++) { + halftrgsw[i][P::k][j] += + static_cast(p[j]) * h[i]; } } } +// Add gadget values to TRGSW (standard decomposition only) +// For Double Decomposition, use trgswSymEncrypt directly template inline void trgswhadd(TRGSW

&trgsw, const Polynomial

&p) { + static_assert(P::l̅ == 1 && P::l̅ₐ == 1, + "trgswhadd only supports standard decomposition (l̅=l̅ₐ=1). " + "Use trgswSymEncrypt for DD."); constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); - if constexpr (P::l̅ₐ > 1) { - // DD for nonce part: first add ordinary gadget, then decompose TRLWE + // Nonce part + for (int i = 0; i < P::lₐ; i++) { for (int k = 0; k < P::k; k++) { - for (int i = 0; i < P::lₐ; i++) { - TRLWE

temp_trlwe = trgsw[(i * P::l̅ₐ) + k * P::lₐ * P::l̅ₐ]; - for (int j = 0; j < P::n; j++) { - temp_trlwe[k][j] += - static_cast(p[j]) * nonceh[i]; - } - // Decompose to base B̅gₐ - std::array, P::l̅ₐ> decomposed; - TRLWEBaseBbarDecomposeNonce

(decomposed, temp_trlwe); - for (int ī = 0; ī < P::l̅ₐ; ī++) { - trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ] = decomposed[ī]; - } + for (int j = 0; j < P::n; j++) { + trgsw[i + k * P::lₐ][k][j] += + static_cast(p[j]) * nonceh[i]; } } } - else { - // Standard decomposition for nonce part - for (int i = 0; i < P::lₐ; i++) { - for (int k = 0; k < P::k; k++) { - for (int j = 0; j < P::n; j++) { - trgsw[i + k * P::lₐ][k][j] += - static_cast(p[j]) * nonceh[i]; - } - } + + // Main part + for (int i = 0; i < P::l; i++) { + for (int j = 0; j < P::n; j++) { + trgsw[i + P::k * P::lₐ][P::k][j] += + static_cast(p[j]) * h[i]; } } +} +// Add gadget values for constant 1 to TRGSW (standard decomposition only) +// For Double Decomposition, use trgswSymEncryptOne directly +template +inline void trgswhoneadd(TRGSW

&trgsw) +{ + static_assert(P::l̅ == 1 && P::l̅ₐ == 1, + "trgswhoneadd only supports standard decomposition (l̅=l̅ₐ=1). " + "Use trgswSymEncryptOne for DD."); + constexpr std::array nonceh = noncehgen

(); constexpr std::array h = hgen

(); - if constexpr (P::l̅ > 1) { - // DD for main part: first add ordinary gadget, then decompose TRLWE - for (int i = 0; i < P::l; i++) { - TRLWE

temp_trlwe = trgsw[(i * P::l̅) + P::k * P::lₐ * P::l̅ₐ]; - for (int j = 0; j < P::n; j++) { - temp_trlwe[P::k][j] += - static_cast(p[j]) * h[i]; + // Nonce part + for (int i = 0; i < P::lₐ; i++) + for (int k = 0; k < P::k; k++) + trgsw[i + k * P::lₐ][k][0] += nonceh[i]; + + // Main part + for (int i = 0; i < P::l; i++) + trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; +} + +// Encrypt constant 1 in TRGSW with proper DD support +template +void trgswSymEncryptOneImpl(TRGSW

&trgsw, const NoiseType noise, + const Key

&key) +{ + constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); + + if constexpr (P::l̅ > 1 || P::l̅ₐ > 1) { + // Double Decomposition path + constexpr int ordinary_rows = P::k * P::lₐ + P::l; + std::array, ordinary_rows> ordinary_trgsw; + for (auto &trlwe : ordinary_trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Add gadget for constant 1 + for (int i = 0; i < P::lₐ; i++) + for (int k_idx = 0; k_idx < P::k; k_idx++) + ordinary_trgsw[i + k_idx * P::lₐ][k_idx][0] += nonceh[i]; + + for (int i = 0; i < P::l; i++) + ordinary_trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + + // Apply DD + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int i = 0; i < P::lₐ; i++) { + std::array, P::l̅ₐ> decomposed; + TRLWEBaseBbarDecomposeNonce

(decomposed, + ordinary_trgsw[i + k_idx * P::lₐ]); + for (int j = 0; j < P::l̅ₐ; j++) + trgsw[(i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ] = + decomposed[j]; } - // Decompose to base B̅g + } + for (int i = 0; i < P::l; i++) { std::array, P::l̅> decomposed; - TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); - for (int ī = 0; ī < P::l̅; ī++) { - trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ] = decomposed[ī]; - } + TRLWEBaseBbarDecompose

(decomposed, + ordinary_trgsw[i + P::k * P::lₐ]); + for (int j = 0; j < P::l̅; j++) + trgsw[(i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ] = decomposed[j]; } } else { - // Standard decomposition for main part - for (int i = 0; i < P::l; i++) { - for (int j = 0; j < P::n; j++) { - trgsw[i + P::k * P::lₐ][P::k][j] += - static_cast(p[j]) * h[i]; - } - } + // Standard path + for (TRLWE

&trlwe : trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + for (int i = 0; i < P::lₐ; i++) + for (int k_idx = 0; k_idx < P::k; k_idx++) + trgsw[i + k_idx * P::lₐ][k_idx][0] += nonceh[i]; + + for (int i = 0; i < P::l; i++) + trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; } } template -inline void trgswhoneadd(TRGSW

&trgsw) +void trgswSymEncryptOne(TRGSW

&trgsw, const double α, const Key

&key) +{ + trgswSymEncryptOneImpl

(trgsw, α, key); +} + +template +void trgswSymEncryptOne(TRGSW

&trgsw, const uint η, const Key

&key) +{ + trgswSymEncryptOneImpl

(trgsw, η, key); +} + +template +void trgswSymEncryptOne(TRGSW

&trgsw, const Key

&key) +{ + if constexpr (P::errordist == ErrorDistribution::ModularGaussian) + trgswSymEncryptOne

(trgsw, P::α, key); + else + trgswSymEncryptOne

(trgsw, P::η, key); +} + +template +void trgswSymEncryptImpl(TRGSW

&trgsw, const Polynomial

&p, + const NoiseType noise, const Key

&key) { constexpr std::array nonceh = noncehgen

(); + constexpr std::array h = hgen

(); - if constexpr (P::l̅ₐ > 1) { - // DD for nonce part: add ordinary gadget, then decompose - for (int k = 0; k < P::k; k++) { + if constexpr (P::l̅ > 1 || P::l̅ₐ > 1) { + // Double Decomposition path: + // Step 1: Create ordinary TRGSW with k*lₐ + l rows + constexpr int ordinary_rows = P::k * P::lₐ + P::l; + std::array, ordinary_rows> ordinary_trgsw; + for (auto &trlwe : ordinary_trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Step 2: Add gadget values to create ordinary TRGSW + // Nonce part + for (int i = 0; i < P::lₐ; i++) { + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int n = 0; n < P::n; n++) { + ordinary_trgsw[i + k_idx * P::lₐ][k_idx][n] += + static_cast(p[n]) * nonceh[i]; + } + } + } + // Main part + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + ordinary_trgsw[i + P::k * P::lₐ][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + + // Step 3: Apply DD to each ordinary row + // Nonce part: each of the k*lₐ rows expands to l̅ₐ rows + for (int k_idx = 0; k_idx < P::k; k_idx++) { for (int i = 0; i < P::lₐ; i++) { - TRLWE

temp_trlwe = trgsw[(i * P::l̅ₐ) + k * P::lₐ * P::l̅ₐ]; - temp_trlwe[k][0] += nonceh[i]; // Add 1 * h[i] std::array, P::l̅ₐ> decomposed; - TRLWEBaseBbarDecomposeNonce

(decomposed, temp_trlwe); - for (int ī = 0; ī < P::l̅ₐ; ī++) { - trgsw[(i * P::l̅ₐ + ī) + k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + TRLWEBaseBbarDecomposeNonce

(decomposed, + ordinary_trgsw[i + k_idx * P::lₐ]); + for (int j = 0; j < P::l̅ₐ; j++) { + trgsw[(i * P::l̅ₐ + j) + k_idx * P::lₐ * P::l̅ₐ] = + decomposed[j]; } } } - } - else { - for (int i = 0; i < P::lₐ; i++) - for (int k = 0; k < P::k; k++) - trgsw[i + k * P::lₐ][k][0] += nonceh[i]; - } - - constexpr std::array h = hgen

(); - - if constexpr (P::l̅ > 1) { - // DD for main part: add ordinary gadget, then decompose + // Main part: each of the l rows expands to l̅ rows for (int i = 0; i < P::l; i++) { - TRLWE

temp_trlwe = trgsw[(i * P::l̅) + P::k * P::lₐ * P::l̅ₐ]; - temp_trlwe[P::k][0] += h[i]; // Add 1 * h[i] std::array, P::l̅> decomposed; - TRLWEBaseBbarDecompose

(decomposed, temp_trlwe); - for (int ī = 0; ī < P::l̅; ī++) { - trgsw[(i * P::l̅ + ī) + P::k * P::lₐ * P::l̅ₐ] = decomposed[ī]; + TRLWEBaseBbarDecompose

(decomposed, + ordinary_trgsw[i + P::k * P::lₐ]); + for (int j = 0; j < P::l̅; j++) { + trgsw[(i * P::l̅ + j) + P::k * P::lₐ * P::l̅ₐ] = decomposed[j]; } } } else { - for (int i = 0; i < P::l; i++) - trgsw[i + P::k * P::lₐ][P::k][0] += h[i]; + // Standard path (no DD): encrypt and add gadget directly + for (TRLWE

&trlwe : trgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Nonce part + for (int i = 0; i < P::lₐ; i++) { + for (int k_idx = 0; k_idx < P::k; k_idx++) { + for (int n = 0; n < P::n; n++) { + trgsw[i + k_idx * P::lₐ][k_idx][n] += + static_cast(p[n]) * nonceh[i]; + } + } + } + // Main part + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + trgsw[i + P::k * P::lₐ][P::k][n] += + static_cast(p[n]) * h[i]; + } + } } } @@ -1074,16 +1157,14 @@ template void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, const double α, const Key

&key) { - for (TRLWE

&trlwe : trgsw) trlweSymEncryptZero

(trlwe, α, key); - trgswhadd

(trgsw, p); + trgswSymEncryptImpl

(trgsw, p, α, key); } template void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, const uint η, const Key

&key) { - for (TRLWE

&trlwe : trgsw) trlweSymEncryptZero

(trlwe, η, key); - trgswhadd

(trgsw, p); + trgswSymEncryptImpl

(trgsw, p, η, key); } template @@ -1096,20 +1177,62 @@ void trgswSymEncrypt(TRGSW

&trgsw, const Polynomial

&p, trgswSymEncrypt

(trgsw, p, P::η, key); } +template +void halftrgswSymEncryptImpl(HalfTRGSW

&halftrgsw, const Polynomial

&p, + const NoiseType noise, const Key

&key) +{ + constexpr std::array h = hgen

(); + + if constexpr (P::l̅ > 1) { + // Double Decomposition path: + // Step 1: Create ordinary HalfTRGSW with l rows + std::array, P::l> ordinary_halftrgsw; + for (auto &trlwe : ordinary_halftrgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + // Step 2: Add gadget values + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + ordinary_halftrgsw[i][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + + // Step 3: Apply DD to each row, expanding l rows to l*l̅ rows + for (int i = 0; i < P::l; i++) { + std::array, P::l̅> decomposed; + TRLWEBaseBbarDecompose

(decomposed, ordinary_halftrgsw[i]); + for (int j = 0; j < P::l̅; j++) { + halftrgsw[i * P::l̅ + j] = decomposed[j]; + } + } + } + else { + // Standard path (no DD) + for (TRLWE

&trlwe : halftrgsw) + trlweSymEncryptZero

(trlwe, noise, key); + + for (int i = 0; i < P::l; i++) { + for (int n = 0; n < P::n; n++) { + halftrgsw[i][P::k][n] += + static_cast(p[n]) * h[i]; + } + } + } +} + template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const double α, const Key

&key) { - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, α, key); - halftrgswhadd

(halftrgsw, p); + halftrgswSymEncryptImpl

(halftrgsw, p, α, key); } template void halftrgswSymEncrypt(HalfTRGSW

&halftrgsw, const Polynomial

&p, const uint η, const Key

&key) { - for (TRLWE

&trlwe : halftrgsw) trlweSymEncryptZero

(trlwe, η, key); - halftrgswhadd

(halftrgsw, p); + halftrgswSymEncryptImpl

(halftrgsw, p, η, key); } template From 37b7e6706076ed4076d6eec67342dc2888ca542e Mon Sep 17 00:00:00 2001 From: nindanaoto Date: Wed, 7 Jan 2026 16:29:04 +0000 Subject: [PATCH 9/9] Remove 64-bit shift workarounds from DD implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With Double Decomposition, TwistIFFT always receives decomposition digits (small integers), never raw Torus values. This allows us to simplify the implementation by removing all 64-bit shift workarounds: - TwistIFFT: Use low 64 bits directly instead of >> 64 - TwistFFT: Store in low 64 bits instead of << 64 - All decomposition functions: Remove << 64 shift - RecombineTRLWEFromDD/Nonce: Remove fft_offset compensation The code is now much cleaner and the recombination logic is straightforward. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- include/mulfft.hpp | 14 ++--- include/trgsw.hpp | 126 +++++++-------------------------------------- 2 files changed, 28 insertions(+), 112 deletions(-) diff --git a/include/mulfft.hpp b/include/mulfft.hpp index 57ffbaa9..7caf5dc3 100644 --- a/include/mulfft.hpp +++ b/include/mulfft.hpp @@ -89,12 +89,13 @@ inline void TwistFFT(Polynomial

&res, const PolynomialInFD

&a) fftplvl1.execute_direct_torus64(res.data(), a.data()); } else if constexpr (std::is_same_v) { - // For 128-bit lvl3param, use 64-bit FFT and shift result to top 64 bits - // This preserves the Torus semantics (most significant bits) + // For 128-bit lvl3param with Double Decomposition: + // Output is intermediate result that will be recombined + // Store in low 64 bits - reconstruction handles proper positioning alignas(64) std::array temp; fftplvl3.execute_direct_torus64(temp.data(), a.data()); for (int i = 0; i < P::n; i++) - res[i] = static_cast<__uint128_t>(temp[i]) << 64; + res[i] = static_cast<__uint128_t>(temp[i]); } else if constexpr (std::is_same_v) fftplvl2.execute_direct_torus64(res.data(), a.data()); @@ -151,11 +152,12 @@ inline void TwistIFFT(PolynomialInFD

&res, const Polynomial

&a) fftplvl1.execute_reverse_torus64(res.data(), a.data()); } else if constexpr (std::is_same_v) { - // For 128-bit lvl3param, use top 64 bits for FFT - // This preserves the Torus semantics (most significant bits) + // For 128-bit lvl3param with Double Decomposition: + // Input is always decomposition digits (small integers in low 64 bits) + // Use low 64 bits directly - no shift needed alignas(64) std::array temp; for (int i = 0; i < P::n; i++) - temp[i] = static_cast(a[i] >> 64); + temp[i] = static_cast(a[i]); fftplvl3.execute_reverse_torus64(res.data(), temp.data()); } else if constexpr (std::is_same_v) diff --git a/include/trgsw.hpp b/include/trgsw.hpp index 82140101..257807b9 100644 --- a/include/trgsw.hpp +++ b/include/trgsw.hpp @@ -75,21 +75,12 @@ inline void Decomposition(DecomposedPolynomial

&decpoly, constexpr typename P::T halfBg = (static_cast(1) << (P::Bgbit - 1)); for (int i = 0; i < P::n; i++) { - for (int l = 0; l < P::l; l++) { - auto decomp_val = - static_cast>( - (((poly[i] + offset + roundoffset) >> - (std::numeric_limits::digits - - (l + 1) * P::Bgbit)) & - mask) - - halfBg); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - decpoly[l][i] = static_cast(decomp_val) << 64; - else - decpoly[l][i] = decomp_val; - } + for (int l = 0; l < P::l; l++) + decpoly[l][i] = (((poly[i] + offset + roundoffset) >> + (std::numeric_limits::digits - + (l + 1) * P::Bgbit)) & + mask) - + halfBg; } } @@ -117,21 +108,12 @@ inline void NonceDecomposition(DecomposedNoncePolynomial

&decpoly, constexpr typename P::T halfBg = (static_cast(1) << (P::Bgₐbit - 1)); for (int i = 0; i < P::n; i++) { - for (int l = 0; l < P::lₐ; l++) { - auto decomp_val = - static_cast>( - (((poly[i] + offset + roundoffset) >> - (std::numeric_limits::digits - - (l + 1) * P::Bgₐbit)) & - mask) - - halfBg); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - decpoly[l][i] = static_cast(decomp_val) << 64; - else - decpoly[l][i] = decomp_val; - } + for (int l = 0; l < P::lₐ; l++) + decpoly[l][i] = (((poly[i] + offset + roundoffset) >> + (std::numeric_limits::digits - + (l + 1) * P::Bgₐbit)) & + mask) - + halfBg; } } @@ -177,16 +159,7 @@ inline void DoubleDecomposition(DecomposedPolynomialDD

&decpoly, // When l̅=1 (j=0 only), this reduces to standard decomposition const int shift = std::numeric_limits::digits - (i + 1) * P::Bgbit - j * P::B̅gbit; - auto decomp_val = - static_cast>( - ((a >> shift) & maskBg) - halfBg); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - decpoly[i * P::l̅ + j][n] = - static_cast(decomp_val) << 64; - else - decpoly[i * P::l̅ + j][n] = decomp_val; + decpoly[i * P::l̅ + j][n] = ((a >> shift) & maskBg) - halfBg; } } } @@ -230,16 +203,7 @@ inline void NonceDoubleDecomposition(DecomposedNoncePolynomialDD

&decpoly, // When l̅ₐ=1 (j=0 only), this reduces to standard decomposition const int shift = std::numeric_limits::digits - (i + 1) * P::Bgₐbit - j * P::B̅gₐbit; - auto decomp_val = - static_cast>( - ((a >> shift) & maskBg) - halfBg); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - decpoly[i * P::l̅ₐ + j][n] = - static_cast(decomp_val) << 64; - else - decpoly[i * P::l̅ₐ + j][n] = decomp_val; + decpoly[i * P::l̅ₐ + j][n] = ((a >> shift) & maskBg) - halfBg; } } } @@ -284,20 +248,7 @@ inline void TRLWEBaseBbarDecompose(std::array, P::l̅> &result, // Extract j-th digit from MSB side const int shift = std::numeric_limits::digits - (j + 1) * P::B̅gbit; - // Get masked value and compute signed digit - // Use signed arithmetic to properly handle negative digits - typename P::T masked = (a >> shift) & maskB̅g; - // Cast to signed, subtract halfB̅g, then back to unsigned for - // proper sign extension - using SignedT = std::make_signed_t; - SignedT digit = - static_cast(masked) - static_cast(halfB̅g); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - result[j][k][n] = static_cast(digit) << 64; - else - result[j][k][n] = static_cast(digit); + result[j][k][n] = ((a >> shift) & maskB̅g) - halfB̅g; } } } @@ -339,18 +290,7 @@ inline void TRLWEBaseBbarDecomposeNonce(std::array, P::l̅ₐ> &result, // Extract j-th digit from MSB side const int shift = std::numeric_limits::digits - (j + 1) * P::B̅gₐbit; - // Get masked value and compute signed digit - // Use signed arithmetic to properly handle negative digits - typename P::T masked = (a >> shift) & maskB̅g; - using SignedT = std::make_signed_t; - SignedT digit = - static_cast(masked) - static_cast(halfB̅g); - // For 128-bit types, shift left by 64 so TwistIFFT (which uses - // top 64 bits) gets the correct small integer value - if constexpr (std::is_same_v) - result[j][k][n] = static_cast(digit) << 64; - else - result[j][k][n] = static_cast(digit); + result[j][k][n] = ((a >> shift) & maskB̅g) - halfB̅g; } } } @@ -398,15 +338,11 @@ void NonceDecomposition(DecomposedNoncePolynomialRAINTT

&decpolyntt, // Recombine l̅ TRLWEs from Double Decomposition back to single TRLWE // result[j] contains j-th B̅g digit; recombine as: res = Σⱼ result[j] * 2^(width - (j+1)*B̅gbit) -// For 128-bit types, TwistFFT places results in top 64 bits, so we adjust shifts accordingly template inline void RecombineTRLWEFromDD(TRLWE

&res, const std::array, P::l̅> &decomposed) { constexpr int width = std::numeric_limits::digits; - // For 128-bit types, TwistFFT adds a << 64 shift, so we compensate - constexpr int fft_offset = - std::is_same_v ? 64 : 0; // Initialize result to zero for (int k = 0; k <= P::k; k++) { @@ -417,34 +353,21 @@ inline void RecombineTRLWEFromDD(TRLWE

&res, // Add all components with appropriate shifts for (int j = 0; j < P::l̅; j++) { - // Target shift: width - (j+1)*B̅gbit - // Actual shift needed: target - fft_offset - const int target_shift = width - (j + 1) * P::B̅gbit; - const int actual_shift = target_shift - fft_offset; - + const int shift = width - (j + 1) * P::B̅gbit; for (int k = 0; k <= P::k; k++) { for (int n = 0; n < P::n; n++) { - if (actual_shift >= 0) { - res[k][n] += decomposed[j][k][n] << actual_shift; - } - else { - res[k][n] += decomposed[j][k][n] >> (-actual_shift); - } + res[k][n] += decomposed[j][k][n] << shift; } } } } // Recombine l̅ₐ TRLWEs from Double Decomposition (nonce version) -// For 128-bit types, TwistFFT places results in top 64 bits, so we adjust shifts accordingly template inline void RecombineTRLWEFromDDNonce( TRLWE

&res, const std::array, P::l̅ₐ> &decomposed) { constexpr int width = std::numeric_limits::digits; - // For 128-bit types, TwistFFT adds a << 64 shift, so we compensate - constexpr int fft_offset = - std::is_same_v ? 64 : 0; // Initialize result to zero for (int k = 0; k <= P::k; k++) { @@ -455,19 +378,10 @@ inline void RecombineTRLWEFromDDNonce( // Add all components with appropriate shifts for (int j = 0; j < P::l̅ₐ; j++) { - // Target shift: width - (j+1)*B̅gₐbit - // Actual shift needed: target - fft_offset - const int target_shift = width - (j + 1) * P::B̅gₐbit; - const int actual_shift = target_shift - fft_offset; - + const int shift = width - (j + 1) * P::B̅gₐbit; for (int k = 0; k <= P::k; k++) { for (int n = 0; n < P::n; n++) { - if (actual_shift >= 0) { - res[k][n] += decomposed[j][k][n] << actual_shift; - } - else { - res[k][n] += decomposed[j][k][n] >> (-actual_shift); - } + res[k][n] += decomposed[j][k][n] << shift; } } }