Skip to content
28 changes: 14 additions & 14 deletions include/keyswitch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ constexpr typename P::domainP::T iksoffsetgen()
typename P::domainP::T offset = 0;
for (int i = 1; i <= P::t; i++)
offset +=
(1ULL << P::basebit) / 2 *
(1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
(static_cast<typename P::domainP::T>(1) << P::basebit) / 2 *
(static_cast<typename P::domainP::T>(1) << (std::numeric_limits<typename P::domainP::T>::digits -
i * P::basebit));
return offset;
}
Expand All @@ -35,15 +35,15 @@ void IdentityKeySwitch(TLWE<typename P::targetP> &res,
std::numeric_limits<typename P::targetP::T>::digits;
constexpr typename P::domainP::T roundoffset =
(P::basebit * P::t) < domain_digit
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
? static_cast<typename P::domainP::T>(1) << (domain_digit - (1 + P::basebit * P::t))
: 0;
if constexpr (domain_digit == target_digit)
res[P::targetP::k * P::targetP::n] =
tlwe[P::domainP::k * P::domainP::n];
else if constexpr (domain_digit > target_digit)
res[P::targetP::k * P::targetP::n] =
(tlwe[P::domainP::k * P::domainP::n] +
(1ULL << (domain_digit - target_digit - 1))) >>
(static_cast<typename P::domainP::T>(1) << (domain_digit - target_digit - 1))) >>
(domain_digit - target_digit);
else if constexpr (domain_digit < target_digit)
res[P::targetP::k * P::targetP::n] =
Expand Down Expand Up @@ -86,7 +86,7 @@ void CatIdentityKeySwitch(
std::numeric_limits<typename P::targetP::T>::digits;
constexpr typename P::domainP::T roundoffset =
(P::basebit * P::t) < domain_digit
? 1ULL << (domain_digit - (1 + P::basebit * P::t))
? static_cast<typename P::domainP::T>(1) << (domain_digit - (1 + P::basebit * P::t))
: 0;

for (int cat = 0; cat < numcat; cat++) {
Expand All @@ -96,7 +96,7 @@ void CatIdentityKeySwitch(
else if constexpr (domain_digit > target_digit)
res[cat][P::targetP::k * P::targetP::n] =
(tlwe[cat][P::domainP::k * P::domainP::n] +
(1ULL << (domain_digit - target_digit - 1))) >>
(static_cast<typename P::domainP::T>(1) << (domain_digit - target_digit - 1))) >>
(domain_digit - target_digit);
else if constexpr (domain_digit < target_digit)
res[cat][P::targetP::k * P::targetP::n] =
Expand Down Expand Up @@ -138,7 +138,7 @@ void SubsetIdentityKeySwitch(TLWE<typename P::targetP> &res,
const SubsetKeySwitchingKey<P> &ksk)
{
constexpr typename P::domainP::T prec_offset =
1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
static_cast<typename P::domainP::T>(1) << (std::numeric_limits<typename P::domainP::T>::digits -
(1 + P::basebit * P::t));
constexpr uint32_t mask = (1U << P::basebit) - 1;
res = {};
Expand All @@ -154,11 +154,11 @@ void SubsetIdentityKeySwitch(TLWE<typename P::targetP> &res,
}
else if constexpr (domain_digit > target_digit) {
for (int i = 0; i < P::targetP::k * P::targetP::n; i++)
res[i] = (tlwe[i] + (1ULL << (domain_digit - target_digit - 1))) >>
res[i] = (tlwe[i] + (static_cast<typename P::domainP::T>(1) << (domain_digit - target_digit - 1))) >>
(domain_digit - target_digit);
res[P::targetP::k * P::targetP::n] =
(tlwe[P::domainP::k * P::domainP::n] +
(1ULL << (domain_digit - target_digit - 1))) >>
(static_cast<typename P::domainP::T>(1) << (domain_digit - target_digit - 1))) >>
(domain_digit - target_digit);
}
else if constexpr (domain_digit < target_digit) {
Expand Down Expand Up @@ -189,13 +189,13 @@ void PrivKeySwitch(TRLWE<typename P::targetP> &res,
const PrivateKeySwitchingKey<P> &privksk)
{
constexpr typename P::domainP::T roundoffset =
1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
static_cast<typename P::domainP::T>(1) << (std::numeric_limits<typename P::domainP::T>::digits -
(1 + P::basebit * P::t));

// Koga's Optimization
constexpr typename P::domainP::T offset = iksoffsetgen<P>();
constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1;
constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1);
constexpr typename P::domainP::T mask = (static_cast<typename P::domainP::T>(1) << P::basebit) - 1;
constexpr typename P::domainP::T halfbase = static_cast<typename P::domainP::T>(1) << (P::basebit - 1);
res = {};
for (int i = 0; i <= P::domainP::k * P::domainP::n; i++) {
const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset;
Expand Down Expand Up @@ -258,7 +258,7 @@ void TLWE2TRLWEIKS(TRLWE<typename P::targetP> &res,
const TLWE2TRLWEIKSKey<P> &iksk)
{
constexpr typename P::domainP::T prec_offset =
1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
static_cast<typename P::domainP::T>(1) << (std::numeric_limits<typename P::domainP::T>::digits -
(1 + P::basebit * P::t));
constexpr uint32_t mask = (1U << P::basebit) - 1;
res = {};
Expand All @@ -270,7 +270,7 @@ void TLWE2TRLWEIKS(TRLWE<typename P::targetP> &res,
res[P::targetP::k][0] = tlwe[P::domainP::n];
else if constexpr (domain_digit > target_digit)
res[P::targetP::k][0] = (tlwe[P::domainP::n] +
(1ULL << (domain_digit - target_digit - 1))) >>
(static_cast<typename P::domainP::T>(1) << (domain_digit - target_digit - 1))) >>
(domain_digit - target_digit);
else if constexpr (domain_digit < target_digit)
res[P::targetP::k][0] = tlwe[P::domainP::n]
Expand Down
74 changes: 60 additions & 14 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ inline void TwistNTT(Polynomial<P> &res, PolynomialNTT<P> &a)
cuHEpp::TwistNTT<typename lvl1param::T, lvl1param::nbit>(
res, a, (*ntttablelvl1)[0], (*ntttwistlvl1)[0]);
#endif
else if constexpr (std::is_same_v<typename P::T, uint64_t>) {
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
cuHEpp::TwistNTT<typename lvl2param::T, lvl2param::nbit>(
res, a, (*ntttablelvl2)[0], (*ntttwistlvl2)[0]);
}
else
static_assert(false_v<typename P::T>, "Undefined TwistNTT!");
}
Expand All @@ -89,6 +88,15 @@ inline void TwistFFT(Polynomial<P> &res, const PolynomialInFD<P> &a)
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
fftplvl1.execute_direct_torus64(res.data(), a.data());
}
else if constexpr (std::is_same_v<P, lvl3param>) {
// For 128-bit lvl3param with Double Decomposition:
// Output is intermediate result that will be recombined
// Store in low 64 bits - reconstruction handles proper positioning
alignas(64) std::array<uint64_t, P::n> temp;
fftplvl3.execute_direct_torus64(temp.data(), a.data());
for (int i = 0; i < P::n; i++)
res[i] = static_cast<__uint128_t>(temp[i]);
}
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
fftplvl2.execute_direct_torus64(res.data(), a.data());
else
Expand Down Expand Up @@ -143,6 +151,15 @@ inline void TwistIFFT(PolynomialInFD<P> &res, const Polynomial<P> &a)
if constexpr (std::is_same_v<typename P::T, uint64_t>)
fftplvl1.execute_reverse_torus64(res.data(), a.data());
}
else if constexpr (std::is_same_v<P, lvl3param>) {
// For 128-bit lvl3param with Double Decomposition:
// Input is always decomposition digits (small integers in low 64 bits)
// Use low 64 bits directly - no shift needed
alignas(64) std::array<uint64_t, P::n> temp;
for (int i = 0; i < P::n; i++)
temp[i] = static_cast<uint64_t>(a[i]);
fftplvl3.execute_reverse_torus64(res.data(), temp.data());
}
else if constexpr (std::is_same_v<typename P::T, uint64_t>)
fftplvl2.execute_reverse_torus64(res.data(), a.data());
else
Expand Down Expand Up @@ -301,8 +318,21 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
for (int i = 0; i < P::n; i++) ntta[i] *= nttb[i];
TwistNTT<P>(res, ntta);
}
else if constexpr (std::is_same_v<typename P::T, __uint128_t>) {
// Naive for 128-bit types (FFT/NTT don't support 128-bit precision)
for (int i = 0; i < P::n; i++) {
__uint128_t ri = 0;
for (int j = 0; j <= i; j++)
ri += static_cast<__int128_t>(a[j]) *
static_cast<__int128_t>(b[i - j]);
for (int j = i + 1; j < P::n; j++)
ri -= static_cast<__int128_t>(a[j]) *
static_cast<__int128_t>(b[P::n + i - j]);
res[i] = ri;
}
}
else {
// Naieve
// Naive for other types
for (int i = 0; i < P::n; i++) {
typename P::T ri = 0;
for (int j = 0; j <= i; j++)
Expand Down Expand Up @@ -339,17 +369,33 @@ template <class P>
// Schoolbook negacyclic polynomial product:
//   res[i] = sum_{j <= i} a[j] * b[i - j]  -  sum_{j > i} a[j] * b[P::n + i - j]
// for every i in [0, P::n). All arithmetic wraps modulo 2^(bit width of P::T).
//
// res   : output polynomial; every coefficient is overwritten.
// a, b  : input polynomials. res must not alias a or b, because res[i] is
//         written while later iterations still read a and b.
inline void PolyMulNaive(Polynomial<P> &res, const Polynomial<P> &a,
                         const Polynomial<P> &b)
{
    if constexpr (std::is_same_v<typename P::T, __uint128_t>) {
        // 128-bit path: multiply in unsigned arithmetic. Casting both
        // operands to __int128_t (as was done before) makes the product a
        // signed multiplication whose overflow is undefined behaviour;
        // the unsigned product wraps modulo 2^128, which is the same residue
        // the accumulator needs.
        for (int i = 0; i < P::n; i++) {
            __uint128_t ri = 0;
            for (int j = 0; j <= i; j++)
                ri += static_cast<__uint128_t>(a[j]) *
                      static_cast<__uint128_t>(b[i - j]);
            // Terms that wrap past degree P::n - 1 pick up a minus sign.
            for (int j = i + 1; j < P::n; j++)
                ri -= static_cast<__uint128_t>(a[j]) *
                      static_cast<__uint128_t>(b[P::n + i - j]);
            res[i] = ri;
        }
    }
    else {
        // Generic path: the left operand is reinterpreted as signed and
        // multiplied with the unsigned right operand, exactly as before.
        for (int i = 0; i < P::n; i++) {
            typename P::T ri = 0;
            for (int j = 0; j <= i; j++)
                ri +=
                    static_cast<typename std::make_signed<typename P::T>::type>(
                        a[j]) *
                    b[i - j];
            for (int j = i + 1; j < P::n; j++)
                ri -=
                    static_cast<typename std::make_signed<typename P::T>::type>(
                        a[j]) *
                    b[P::n + i - j];
            res[i] = ri;
        }
    }
}

Expand Down
26 changes: 20 additions & 6 deletions include/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ struct lvl02param {
#endif
};

// Parameter pair whose domain parameter set is lvl0param and whose target
// parameter set is lvl3param.
struct lvl03param {
using domainP = lvl0param;
using targetP = lvl3param;
// Addends is 2 when USE_KEY_BUNDLE is defined at compile time, 1 otherwise.
#ifdef USE_KEY_BUNDLE
static constexpr uint32_t Addends = 2;
#else
static constexpr uint32_t Addends = 1;
#endif
};

struct lvlh2param {
using domainP = lvlhalfparam;
using targetP = lvl2param;
Expand Down Expand Up @@ -118,6 +128,10 @@ using DecomposedPolynomial = std::array<Polynomial<P>, P::l>;
template <class P>
using DecomposedNoncePolynomial = std::array<Polynomial<P>, P::lₐ>;
// Fixed-size array holding P::l * P::l̅ polynomials.
template <class P>
using DecomposedPolynomialDD = std::array<Polynomial<P>, P::l * P::l̅>;
// Fixed-size array holding P::lₐ * P::l̅ₐ polynomials.
template <class P>
using DecomposedNoncePolynomialDD = std::array<Polynomial<P>, P::lₐ * P::l̅ₐ>;
template <class P>
using DecomposedPolynomialNTT = std::array<PolynomialNTT<P>, P::l>;
template <class P>
using DecomposedNoncePolynomialNTT = std::array<PolynomialNTT<P>, P::lₐ>;
Expand All @@ -138,17 +152,17 @@ template <class P>
using TRLWERAINTT = std::array<PolynomialRAINTT<P>, P::k + 1>;

template <class P>
using TRGSW = std::array<TRLWE<P>, P::k * P::lₐ + P::l>;
using TRGSW = std::array<TRLWE<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
template <class P>
using HalfTRGSW = std::array<TRLWE<P>, P::l>;
using HalfTRGSW = std::array<TRLWE<P>, P::l * P::l̅>;
template <class P>
using TRGSWFFT = aligned_array<TRLWEInFD<P>, P::k * P::lₐ + P::l>;
using TRGSWFFT = aligned_array<TRLWEInFD<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
template <class P>
using HalfTRGSWFFT = aligned_array<TRLWEInFD<P>, P::l>;
using HalfTRGSWFFT = aligned_array<TRLWEInFD<P>, P::l * P::l̅>;
template <class P>
using TRGSWNTT = std::array<TRLWENTT<P>, P::k * P::lₐ + P::l>;
using TRGSWNTT = std::array<TRLWENTT<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;
template <class P>
using TRGSWRAINTT = std::array<TRLWERAINTT<P>, P::k * P::lₐ + P::l>;
using TRGSWRAINTT = std::array<TRLWERAINTT<P>, P::k * P::lₐ * P::l̅ₐ + P::l * P::l̅>;

#ifdef USE_KEY_BUNDLE
template <class P>
Expand Down
79 changes: 79 additions & 0 deletions include/params/128bit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ struct lvl1param {
static constexpr double Δ =
static_cast<double>(1ULL << std::numeric_limits<T>::digits) /
plain_modulus;
// Double Decomposition (bivariate representation) parameters
// For now, set to trivial values (no actual second decomposition)
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
static constexpr std::uint32_t l̅ₐ = l̅;
static constexpr std::uint32_t B̅gbit =
std::numeric_limits<T>::digits; // full coefficient width
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
};

struct AHlvl1param {
Expand All @@ -83,6 +90,11 @@ struct AHlvl1param {
static constexpr std::make_signed_t<T> μ = baseP::μ;
static constexpr uint32_t plain_modulus = baseP::plain_modulus;
static constexpr double Δ = baseP::Δ;
// Double Decomposition parameters inherited from baseP
static constexpr std::uint32_t l̅ = baseP::l̅;
static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ;
static constexpr std::uint32_t B̅gbit = baseP::B̅gbit;
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
};

struct lvl2param {
Expand All @@ -106,6 +118,13 @@ struct lvl2param {
static constexpr uint32_t plain_modulus = 8;
static constexpr double Δ =
static_cast<double>(1ULL << (std::numeric_limits<T>::digits - 4));
// Double Decomposition (bivariate representation) parameters
// For now, set to trivial values (no actual second decomposition)
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
static constexpr std::uint32_t l̅ₐ = l̅;
static constexpr std::uint32_t B̅gbit =
std::numeric_limits<T>::digits; // full coefficient width
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
};

struct AHlvl2param {
Expand All @@ -127,9 +146,53 @@ struct AHlvl2param {
static constexpr std::make_signed_t<T> μ = baseP::μ;
static constexpr uint32_t plain_modulus = baseP::plain_modulus;
static constexpr double Δ = baseP::Δ;
// Double Decomposition parameters inherited from baseP
static constexpr std::uint32_t l̅ = baseP::l̅;
static constexpr std::uint32_t l̅ₐ = baseP::l̅ₐ;
static constexpr std::uint32_t B̅gbit = baseP::B̅gbit;
static constexpr std::uint32_t B̅gₐbit = baseP::B̅gₐbit;
};

// lvl3param with 128-bit Torus and non-trivial Double Decomposition
// Double decomposition structure:
// - Primary decomposition (l, Bgbit): Decomposes plaintext by μ in TRGSW gadget
// - Auxiliary decomposition (l̅, B̅gbit): Decomposes TRLWE ciphertext coefficients
// in the external product. Must cover full 128-bit coefficient.
// Constraint for DD algorithm: l*Bgbit + (l̅-1)*B̅gbit ≤ 128
// Using l=2, Bgbit=16 for primary (32 bits); l̅=4, B̅gbit=32 for auxiliary (128 bits)
// This gives: 32 + 96 = 128 ✓
// Parameter set with a 128-bit torus (T = __uint128_t) and a non-trivial
// Double Decomposition: primary levels (l, Bgbit) plus auxiliary levels
// (l̅, B̅gbit). The documented sizing constraints are checked at compile time
// by the static_asserts at the end of the struct.
struct lvl3param {
    static constexpr int32_t key_value_max = 1;
    static constexpr int32_t key_value_min = -1;
    // constexpr (not plain const) so the member is implicitly inline and can
    // be odr-used without an out-of-class definition.
    static constexpr std::uint32_t nbit =
        12;  // dimension must be a power of 2 for
             // ease of polynomial multiplication.
    static constexpr std::uint32_t n = 1 << nbit;  // dimension = 4096
    static constexpr std::uint32_t k = 1;
    static constexpr std::uint32_t lₐ = 2;  // reduced to fit DD constraint
    static constexpr std::uint32_t l = 2;   // reduced to fit DD constraint
    static constexpr std::uint32_t Bgbit = 16;
    static constexpr std::uint32_t Bgₐbit = 16;
    static constexpr uint32_t Bg = 1U << Bgbit;
    static constexpr uint32_t Bgₐ = 1U << Bgₐbit;
    static constexpr ErrorDistribution errordist =
        ErrorDistribution::ModularGaussian;
    static const inline double α = std::pow(2.0, -105);  // fresh noise
    using T = __uint128_t;                               // Torus representation
    static constexpr T μ = static_cast<T>(1) << 125;
    static constexpr uint32_t plain_modulusbit = 31;
    static constexpr T plain_modulus = static_cast<T>(1) << plain_modulusbit;
    static constexpr double Δ =
        static_cast<double>(static_cast<T>(1) << (128 - plain_modulusbit - 1));
    // Double Decomposition (bivariate representation) parameters
    // Auxiliary decomposition must cover full 128-bit ciphertext coefficients
    // l̅ * B̅gbit = 4 * 32 = 128 bits
    static constexpr std::uint32_t l̅ = 4;  // auxiliary decomposition levels
    static constexpr std::uint32_t l̅ₐ = 4;
    static constexpr std::uint32_t B̅gbit =
        32;  // 2^32 base for auxiliary (covers 128-bit T)
    static constexpr std::uint32_t B̅gₐbit = 32;

    // Enforce the documented sizing rules so that a later parameter tweak
    // cannot silently violate them. The width is taken from sizeof(T) because
    // std::numeric_limits may not be specialised for __uint128_t in strict
    // (non-GNU) language modes.
    static_assert(l * Bgbit + (l̅ - 1) * B̅gbit <= 8 * sizeof(T),
                  "lvl3param: l*Bgbit + (l_bar-1)*B_bar_gbit must not exceed "
                  "the torus width");
    static_assert(l̅ * B̅gbit >= 8 * sizeof(T),
                  "lvl3param: auxiliary decomposition must cover the full "
                  "torus width");
};

struct lvl4param {
static constexpr int32_t key_value_max = 1;
static constexpr int32_t key_value_min = -1;
static const std::uint32_t nbit = 13; // dimension must be a power of 2 for
Expand All @@ -150,6 +213,13 @@ struct lvl3param {
static constexpr uint32_t plain_modulusbit = 31;
static constexpr uint64_t plain_modulus = 1ULL << plain_modulusbit;
static constexpr double Δ = 1ULL << (64 - plain_modulusbit - 1);
// Double Decomposition (bivariate representation) parameters
// Trivial values (no actual second decomposition)
static constexpr std::uint32_t l̅ = 1; // auxiliary decomposition levels
static constexpr std::uint32_t l̅ₐ = l̅;
static constexpr std::uint32_t B̅gbit =
std::numeric_limits<T>::digits; // full coefficient width
static constexpr std::uint32_t B̅gₐbit = B̅gbit;
};

// Key Switching parameters
Expand Down Expand Up @@ -239,3 +309,12 @@ struct lvl31param {
using domainP = lvl3param;
using targetP = lvl1param;
};

// Key-switching parameter set whose domain is lvl4param and whose target is
// lvl1param.
struct lvl41param {
static constexpr std::uint32_t t = 7; // number of additions in key switching
static constexpr std::uint32_t basebit =
2; // how many bits should be encrypted in the key-switching key
static const inline double α = lvl1param::α; // key noise (taken from lvl1param)
using domainP = lvl4param;
using targetP = lvl1param;
};
Loading
Loading