diff --git a/c/alt_bn128_test.cpp b/c/alt_bn128_test.cpp index ddcc5ad..93d8e6c 100644 --- a/c/alt_bn128_test.cpp +++ b/c/alt_bn128_test.cpp @@ -1,16 +1,17 @@ -#include #include +#include +#include #include "gtest/gtest.h" #include "alt_bn128.hpp" #include "fft.hpp" +#include "../../../build/mp.hpp" using namespace AltBn128; namespace { TEST(altBn128, f2_simpleMul) { - F2Element e1; F2.fromString(e1, "(2,2)"); @@ -23,8 +24,6 @@ TEST(altBn128, f2_simpleMul) { F2Element e33; F2.fromString(e33, "(0,12)"); - // std::cout << F2.toString(e3) << std::endl; - ASSERT_TRUE(F2.eq(e3, e33)); } @@ -94,17 +93,13 @@ TEST(altBn128, f12_inv) { TEST(altBn128, g1_PlusZero) { G1Point p1; - G1.add(p1, G1.one(), G1.zero()); - ASSERT_TRUE(G1.eq(p1, G1.one())); } TEST(altBn128, g1_minus_g1) { G1Point p1; - G1.sub(p1, G1.one(), G1.one()); - ASSERT_TRUE(G1.isZero(p1)); } @@ -121,7 +116,6 @@ TEST(altBn128, g1_times_4) { ASSERT_TRUE(G1.eq(p1,p2)); } - TEST(altBn128, g1_times_3) { G1Point p1; G1.add(p1, G1.one(), G1.one()); @@ -135,21 +129,18 @@ TEST(altBn128, g1_times_3) { ASSERT_TRUE(G1.eq(p1,p2)); } -TEST(altBn128, g1_times_3_exp) { + TEST(altBn128, g1_times_3_exp) { G1Point p1; G1.add(p1, G1.one(), G1.one()); G1.add(p1, p1, G1.one()); - mpz_t e; - mpz_init_set_str(e, "3", 10); - - uint8_t scalar[32]; - for (int i=0;i<32;i++) scalar[i] = 0; - mpz_export((void *)scalar, NULL, -1, 8, -1, 0, e); - mpz_clear(e); + mp_uint_t scalar; + mp_uint_t x; + mp_set(x, 3); + mp_copy(scalar, x); G1Point p2; - G1.mulByScalar(p2, G1.one(), scalar, 32); + G1.mulByScalar(p2, G1.one(), (uint8_t*)scalar, MP_N); ASSERT_TRUE(G1.eq(p1,p2)); } @@ -174,8 +165,7 @@ TEST(altBn128, g1_times_5) { ASSERT_TRUE(G1.eq(p1,p6)); } -TEST(altBn128, g1_times_65_exp) { - + TEST(altBn128, g1_times_65_exp) { G1Point p1; G1.dbl(p1, G1.one()); G1.dbl(p1, p1); @@ -185,89 +175,79 @@ TEST(altBn128, g1_times_65_exp) { G1.dbl(p1, p1); G1.add(p1, p1, G1.one()); - mpz_t e; - mpz_init_set_str(e, "65", 10); - - uint8_t scalar[32]; - for (int i=0;i<32;i++) scalar[i] = 0; - mpz_export((void *)scalar, NULL, -1, 8, -1, 0, e); - mpz_clear(e); + mp_uint_t scalar; + mp_uint_t x; + mp_set(x, 65); + mp_copy(scalar, x); G1Point p2; - G1.mulByScalar(p2, G1.one(), scalar, 32); + G1.mulByScalar(p2, G1.one(), (uint8_t*)scalar, MP_N); ASSERT_TRUE(G1.eq(p1,p2)); } -TEST(altBn128, g1_expToOrder) { - mpz_t e; - mpz_init_set_str(e, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10); - - uint8_t scalar[32]; - - for (int i=0;i<32;i++) scalar[i] = 0; - mpz_export((void *)scalar, NULL, -1, 8, -1, 0, e); - mpz_clear(e); + TEST(altBn128, g1_expToOrder) { + mp_uint_t scalar; + mp_uint_t x; + ASSERT_EQ(mp_set(x, + "21888242871839275222246405745257275088548364400416034343698204186575808495617", + 10 + ), 0); + mp_copy(scalar, x); G1Point p1; - - G1.mulByScalar(p1, G1.one(), scalar, 32); + G1.mulByScalar(p1, G1.one(), (uint8_t *)scalar, MP_N); ASSERT_TRUE(G1.isZero(p1)); } -TEST(altBn128, g2_expToOrder) { - mpz_t e; - mpz_init_set_str(e, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10); - - uint8_t scalar[32]; - - for (int i=0;i<32;i++) scalar[i] = 0; - mpz_export((void *)scalar, NULL, -1, 8, -1, 0, e); - mpz_clear(e); + TEST(altBn128, g2_expToOrder) { + mp_uint_t scalar; + mp_uint_t x; + ASSERT_EQ(mp_set(x, + "21888242871839275222246405745257275088548364400416034343698204186575808495617", + 10 + ), 0); + mp_copy(scalar, x); Curve>::Point p1; - - G2.mulByScalar(p1, G2.one(), scalar, 32); + G2.mulByScalar(p1, G2.one(), (uint8_t *)scalar, MP_N); ASSERT_TRUE(G2.isZero(p1)); } TEST(altBn128, multiExp) { - int NMExp = 40000; - typedef uint8_t Scalar[32]; + typedef mp_uint_t Scalar; Scalar *scalars = new Scalar[NMExp]; G1PointAffine *bases = new G1PointAffine[NMExp]; - uint64_t acc=0; - for (int i=0; i #include +#include +#include +#include + #include "misc.hpp" +#include "../../../build/mp.hpp" using namespace std; -// The function we want to execute on the new thread. - template u_int32_t FFT::log2(u_int64_t n) { assert(n!=0); @@ -23,7 +26,7 @@ static inline u_int64_t BR(u_int64_t x, u_int64_t domainPow) x = ((x & 0xFF00FF00) >> 8) | ((x & 0x00FF00FF) << 8); x = ((x & 0xF0F0F0F0) >> 4) | ((x & 0x0F0F0F0F) << 4); x = ((x & 0xCCCCCCCC) >> 2) | ((x & 0x33333333) << 2); - return (((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (32-domainPow); + return (((x & 0xAAAAAAAA) >> 1) | ((x & 0x55555555) << 1)) >> (MP_N-domainPow); } #define ROOT(s,j) (rootsOfUnit[(1<<(s))+(j)]) @@ -34,84 +37,86 @@ FFT::FFT(u_int64_t maxDomainSize, uint32_t _nThreads) { f = Field::field; + static_assert(Field::N64 == 4, "FFT expects 256-bit field elements (4x64 limbs)."); + u_int32_t domainPow = log2(maxDomainSize); - mpz_t m_qm1d2; - mpz_t m_q; - mpz_t m_nqr; - mpz_t m_aux; - mpz_init(m_qm1d2); - mpz_init(m_q); - mpz_init(m_nqr); - mpz_init(m_aux); - - f.toMpz(m_aux, f.negOne()); - - mpz_add_ui(m_q, m_aux, 1); - mpz_fdiv_q_2exp(m_qm1d2, m_aux, 1); - - mpz_set_ui(m_nqr, 2); - mpz_powm(m_aux, m_nqr, m_qm1d2, m_q); - while (mpz_cmp_ui(m_aux, 1) == 0) { - mpz_add_ui(m_nqr, m_nqr, 1); - mpz_powm(m_aux, m_nqr, m_qm1d2, m_q); + Element qm1_norm; + f.fromMontgomery(qm1_norm, f.negOne()); + + mp_uint_t qm1; + std::memcpy(qm1, (const void*)qm1_norm.v, sizeof(qm1)); + + mp_uint_t qm1d2; + mp_shr(qm1d2, qm1, 1); + + Element cand, res; + uint64_t cand_ui = 2; + for (;;) { + f.fromUI(cand, cand_ui); + f.exp(res, cand, + reinterpret_cast(qm1d2), + (unsigned)sizeof(qm1d2)); + if (!f.eq(res, f.one())) { + f.copy(nqr, cand); + break; + } + cand_ui++; } - f.fromMpz(nqr, m_nqr); - - // std::cout << "nqr: " << f.toString(nqr) << std::endl; + mp_uint_t aux; + mp_copy(aux, qm1d2); - s = 1; - mpz_set(m_aux, m_qm1d2); - while ((!mpz_tstbit(m_aux, 0))&&(s1) { - mpz_powm(m_aux, m_nqr, m_aux, m_q); - f.fromMpz(roots[1], m_aux); - mpz_set_ui(m_aux, 2); - mpz_invert(m_aux, m_aux, m_q); - f.fromMpz(powTwoInv[1], m_aux); + if (nRoots > 1) { + f.exp(roots[1], nqr, + reinterpret_cast(aux), + (unsigned)sizeof(aux)); + + Element two; + f.fromUI(two, 2); + f.inv(powTwoInv[1], two); } threadPool.parallelBlock([&] (uint64_t nThreads, uint64_t idThread) { - uint64_t increment = nRoots / nThreads; - uint64_t start = idThread==0 ? 2 : idThread * increment; - uint64_t end = idThread==nThreads-1 ? nRoots : (idThread+1) * increment; - if (end>start) { + uint64_t start = (idThread == 0) ? 2 : idThread * increment; + uint64_t end = (idThread == nThreads - 1) ? nRoots : (idThread + 1) * increment; + + if (end > start) { f.exp(roots[start], roots[1], (uint8_t *)(&start), sizeof(start)); } - for (uint64_t i=start+1; i @@ -120,50 +125,15 @@ FFT::~FFT() { delete[] powTwoInv; } -/* -template -void FFT::reversePermutationInnerLoop(Element *a, u_int64_t from, u_int64_t to, u_int32_t domainPow) { - Element tmp; - for (u_int64_t i=from; ir) { - f.copy(tmp, a[i]); - f.copy(a[i], a[r]); - f.copy(a[r], tmp); - } - } -} - - -template -void FFT::reversePermutation(Element *a, u_int64_t n) { - int domainPow = log2(n); - std::vector threads(nThreads-1); - u_int64_t increment = n / nThreads; - if (increment) { - for (u_int64_t i=0; i::reversePermutationInnerLoop, this, a, i*increment, (i+1)*increment, domainPow); - } - } - reversePermutationInnerLoop(a, (nThreads-1)*increment, n, domainPow); - if (increment) { - for (u_int32_t i=0; i void FFT::reversePermutation(Element *a, u_int64_t n) { int domainPow = log2(n); threadPool.parallelFor(0, n, [&] (int begin, int end, int numThread) { - for (u_int64_t i=begin; ir) { + if (i > r) { f.copy(tmp, a[i]); f.copy(a[i], a[r]); f.copy(a[r], tmp); @@ -172,59 +142,57 @@ void FFT::reversePermutation(Element *a, u_int64_t n) { }); } - template void FFT::fft(Element *a, u_int64_t n) { reversePermutation(a, n); - u_int64_t domainPow =log2(n); + u_int64_t domainPow = log2(n); assert(((u_int64_t)1 << domainPow) == n); - for (u_int32_t s=1; s<=domainPow; s++) { + + for (u_int32_t s = 1; s <= domainPow; s++) { u_int64_t m = 1 << s; u_int64_t mdiv2 = m >> 1; - threadPool.parallelFor(0, (n>>1), [&] (int begin, int end, int numThread) { - for (u_int64_t i=begin; i< end; i++) { + threadPool.parallelFor(0, (n >> 1), [&] (int begin, int end, int numThread) { + for (u_int64_t i = (u_int64_t)begin; i < (u_int64_t)end; i++) { Element t; Element u; - u_int64_t k=(i/mdiv2)*m; - u_int64_t j=i%mdiv2; + u_int64_t k = (i / mdiv2) * m; + u_int64_t j = i % mdiv2; - f.mul(t, root(s, j), a[k+j+mdiv2]); - f.copy(u,a[k+j]); - f.add(a[k+j], t, u); - f.sub(a[k+j+mdiv2], u, t); + f.mul(t, root(s, j), a[k + j + mdiv2]); + f.copy(u, a[k + j]); + f.add(a[k + j], t, u); + f.sub(a[k + j + mdiv2], u, t); } }); } } template -void FFT::ifft(Element *a, u_int64_t n ) { +void FFT::ifft(Element *a, u_int64_t n) { fft(a, n); - u_int64_t domainPow =log2(n); - u_int64_t nDiv2= n >> 1; + u_int64_t domainPow = log2(n); + u_int64_t nDiv2 = n >> 1; threadPool.parallelFor(1, nDiv2, [&] (int begin, int end, int numThread) { - for (u_int64_t i=begin; i> 1], a[n >> 1], powTwoInv[domainPow]); } - - template -void FFT::printVector(Element *a, u_int64_t n ) { +void FFT::printVector(Element *a, u_int64_t n) { cout << "[" << endl; - for (u_int64_t i=0; i loadHeader(BinFileUtils::BinFile *f) { - Header *h = new Header(); + std::unique_ptr
h(new Header()); + f->startReadSection(1); h->n8 = f->readU32LE(); - mpz_init(h->prime); - mpz_import(h->prime, h->n8, -1, 1, -1, 0, f->read(h->n8)); + { + const uint8_t* p = reinterpret_cast(f->read(h->n8)); + h->prime.assign(p, p + h->n8); + } h->nVars = f->readU32LE(); f->endReadSection(); - return std::unique_ptr
(h); + return h; } -} // NAMESPACE \ No newline at end of file +} // namespace WtnsUtils diff --git a/c/wtns_utils.hpp b/c/wtns_utils.hpp index 70aebc2..1766b2e 100644 --- a/c/wtns_utils.hpp +++ b/c/wtns_utils.hpp @@ -1,7 +1,9 @@ #ifndef WTNS_UTILS #define WTNS_UTILS -#include +#include +#include +#include #include "binfile_utils.hpp" @@ -10,7 +12,7 @@ namespace WtnsUtils { class Header { public: u_int32_t n8; - mpz_t prime; + std::vector prime; u_int32_t nVars; @@ -20,6 +22,6 @@ namespace WtnsUtils { std::unique_ptr
loadHeader(BinFileUtils::BinFile *f); -} +} // namespace WtnsUtils -#endif // ZKEY_UTILS_H \ No newline at end of file +#endif // WTNS_UTILS diff --git a/c/zkey_utils.cpp b/c/zkey_utils.cpp index 52a5629..f8ecbbf 100644 --- a/c/zkey_utils.cpp +++ b/c/zkey_utils.cpp @@ -1,17 +1,26 @@ - #include "zkey_utils.hpp" -namespace ZKeyUtils { - +#include -Header::Header() { -} +namespace ZKeyUtils { -Header::~Header() { - mpz_clear(qPrime); - mpz_clear(rPrime); +Header::Header() + : n8q(0), + n8r(0), + nVars(0), + nPublic(0), + domainSize(0), + nCoefs(0), + vk_alpha1(nullptr), + vk_beta1(nullptr), + vk_beta2(nullptr), + vk_gamma2(nullptr), + vk_delta1(nullptr), + vk_delta2(nullptr) +{ } +Header::~Header() = default; std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { auto h = new Header(); @@ -26,12 +35,16 @@ std::unique_ptr
loadHeader(BinFileUtils::BinFile *f) { f->startReadSection(2); h->n8q = f->readU32LE(); - mpz_init(h->qPrime); - mpz_import(h->qPrime, h->n8q, -1, 1, -1, 0, f->read(h->n8q)); + { + const uint8_t* p = reinterpret_cast(f->read(h->n8q)); + h->qPrime.assign(p, p + h->n8q); + } h->n8r = f->readU32LE(); - mpz_init(h->rPrime); - mpz_import(h->rPrime, h->n8r , -1, 1, -1, 0, f->read(h->n8r)); + { + const uint8_t* p = reinterpret_cast(f->read(h->n8r)); + h->rPrime.assign(p, p + h->n8r); + } h->nVars = f->readU32LE(); h->nPublic = f->readU32LE(); diff --git a/c/zkey_utils.hpp b/c/zkey_utils.hpp index 5bd5e67..f3503d0 100644 --- a/c/zkey_utils.hpp +++ b/c/zkey_utils.hpp @@ -1,8 +1,9 @@ #ifndef ZKEY_UTILS_H #define ZKEY_UTILS_H -#include #include +#include +#include #include "binfile_utils.hpp" @@ -13,9 +14,10 @@ namespace ZKeyUtils { public: u_int32_t n8q; - mpz_t qPrime; + std::vector qPrime; + u_int32_t n8r; - mpz_t rPrime; + std::vector rPrime; u_int32_t nVars; u_int32_t nPublic; diff --git a/tasksfile.js b/tasksfile.js index 1f6883b..af91d4f 100644 --- a/tasksfile.js +++ b/tasksfile.js @@ -41,7 +41,7 @@ function testSplitParStr() { " ../c/splitparstr.cpp"+ " ../c/splitparstr_test.cpp"+ " googletest-release-1.10.0/libgtest.a"+ - " -pthread -std=c++11 -o splitparsestr_test", {cwd: "build", nopipe: true} + " -pthread -std=c++17 -o splitparsestr_test", {cwd: "build", nopipe: true} ); sh("./splitparsestr_test", {cwd: "build", nopipe: true}); } @@ -62,7 +62,7 @@ function testAltBn128() { " fr.o"+ " googletest-release-1.10.0/libgtest.a"+ " -o altbn128_test" + - " -fmax-errors=5 -pthread -std=c++11 -fopenmp -lgmp -g", {cwd: "build", nopipe: true} + " -fmax-errors=5 -pthread -std=c++17 -fopenmp -lgmp -g", {cwd: "build", nopipe: true} ); sh("./altbn128_test", {cwd: "build", nopipe: true}); } @@ -84,7 +84,7 @@ function benchMultiExpG1() { " fr.o"+ // " googletest-release-1.10.0/libgtest.a"+ " -o multiexp_g1_benchmark" + - " -lgmp -pthread -std=c++11 -fopenmp" , {cwd: "build", nopipe: true} + " -lgmp -pthread -std=c++17 -fopenmp" , {cwd: "build", nopipe: true} ); sh("./multiexp_g1_benchmark 16777216", {cwd: "build", nopipe: true}); } @@ -105,7 +105,7 @@ function benchMultiExpG2() { " fr.o"+ // " googletest-release-1.10.0/libgtest.a"+ " -o multiexp_g2_benchmark" + - " -lgmp -pthread -std=c++11 -fopenmp" , {cwd: "build", nopipe: true} + " -lgmp -pthread -std=c++17 -fopenmp" , {cwd: "build", nopipe: true} ); sh("./multiexp_g2_benchmark 16777216", {cwd: "build", nopipe: true}); }