Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ if(CMAKE_BUILD_TYPE MATCHES Debug)
endif()

if(NOT CMAKE_CUDA_TOOLKIT_ROOT_DIR)
set(CMAKE_CUDA_TOOLKIT_ROOT_DIR "/usr/local/cuda")
set(CMAKE_CUDA_COMPILER "${CMAKE_CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc")
set(CMAKE_CUDA_TOOLKIT_ROOT_DIR "/usr")
set(CMAKE_CUDA_COMPILER "/usr/local/cuda-12.8/bin/nvcc")
endif()
if(NOT CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH})
Expand Down
19 changes: 19 additions & 0 deletions native/lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ extern "C"
NTT_Direction ntt_direction, NTT_Config cfg)
{
auto &gpu = select_gpu(device_id);
gpu.select();
cudaDeviceSynchronize();
return ntt::batch_ntt(gpu, (fr_t *)inout, lg_domain_size, ntt_direction, cfg);
}

Expand Down Expand Up @@ -172,4 +174,21 @@ extern "C"
init_cuda()
{
init_cuda_degree(24);
}

#if defined(EXPOSE_C_INTERFACE)
extern "C"
#endif
void
clear_cuda_errors_all_devices()
{
int num_gpus = ngpus();
for (int i = 0; i < num_gpus; i++) {
auto &gpu = select_gpu(i);
gpu.select();
// Clear the sticky error state from this device
cudaGetLastError();
// Ensure all operations on the default stream complete
cudaStreamSynchronize(0);
}
}
2 changes: 2 additions & 0 deletions native/lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ EXTERN RustError compute_transpose_rev(size_t device_id, void *output, void *inp
EXTERN RustError compute_naive_transpose_rev(size_t device_id, void *output, void *input, uint32_t lg_n,
NTT_TransposeConfig cfg);

EXTERN void clear_cuda_errors_all_devices();

#endif // __ZEKNOX_CUDA_LIB_H__
85 changes: 43 additions & 42 deletions native/poseidon2/poseidon2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,36 @@ __device__ __constant__ u64 GPU_RC12[360] = {
#else
const u64 RC12[360] = {
#endif
1431286215153372998ull, 3509349009260703107ull, 2289575380984896342ull, 10625215922958251110ull, 17137022507167291684ull, 17143426961497010024ull, 9589775313463224365ull, 7736066733515538648ull, 2217569167061322248ull, 10394930802584583083ull, 4612393375016695705ull, 5332470884919453534ull,
8724526834049581439ull, 17673787971454860688ull, 2519987773101056005ull, 7999687124137420323ull, 18312454652563306701ull, 15136091233824155669ull, 1257110570403430003ull, 5665449074466664773ull, 16178737609685266571ull, 52855143527893348ull, 8084454992943870230ull, 2597062441266647183ull,
3342624911463171251ull, 6781356195391537436ull, 4697929572322733707ull, 4179687232228901671ull, 17841073646522133059ull, 18340176721233187897ull, 13152929999122219197ull, 6306257051437840427ull, 4974451914008050921ull, 11258703678970285201ull, 581736081259960204ull, 18323286026903235604ull,
10250026231324330997ull, 13321947507807660157ull, 13020725208899496943ull, 11416990495425192684ull, 7221795794796219413ull, 2607917872900632985ull, 2591896057192169329ull, 10485489452304998145ull, 9480186048908910015ull, 2645141845409940474ull, 16242299839765162610ull, 12203738590896308135ull,
5395176197344543510ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
17941136338888340715ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
7559392505546762987ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
549633128904721280ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
15658455328409267684ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
10078371877170729592ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
2349868247408080783ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
13105911261634181239ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
12868653202234053626ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
9471330315555975806ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
4580289636625406680ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
13222733136951421572ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
4555032575628627551ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
7619130111929922899ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
4547848507246491777ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
5662043532568004632ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
15723873049665279492ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
13585630674756818185ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
6990417929677264473ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
6373257983538884779ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
1005856792729125863ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
17850970025369572891ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
14306783492963476045ull, 12653264875831356889ull, 10887434669785806501ull, 7221072982690633460ull, 9953585853856674407ull, 13497620366078753434ull, 18140292631504202243ull, 17311934738088402529ull, 6686302214424395771ull, 11193071888943695519ull, 10233795775801758543ull, 3362219552562939863ull,
8595401306696186761ull, 7753411262943026561ull, 12415218859476220947ull, 12517451587026875834ull, 3257008032900598499ull, 2187469039578904770ull, 657675168296710415ull, 8659969869470208989ull, 12526098871288378639ull, 12525853395769009329ull, 15388161689979551704ull, 7880966905416338909ull,
2911694411222711481ull, 6420652251792580406ull, 323544930728360053ull, 11718666476052241225ull, 2449132068789045592ull, 17993014181992530560ull, 15161788952257357966ull, 3788504801066818367ull, 1282111773460545571ull, 8849495164481705550ull, 8380852402060721190ull, 2161980224591127360ull,
2440151485689245146ull, 17521895002090134367ull, 13821005335130766955ull, 17513705631114265826ull, 17068447856797239529ull, 17964439003977043993ull, 5685000919538239429ull, 11615940660682589106ull, 2522854885180605258ull, 12584118968072796115ull, 17841258728624635591ull, 10821564568873127316ull};
15492826721047263190ull, 11728330187201910315ull, 8836021247773420868ull, 16777404051263952451ull, 5510875212538051896ull, 6173089941271892285ull, 2927757366422211339ull, 10340958981325008808ull, 8541987352684552425ull, 9739599543776434497ull, 15073950188101532019ull, 12084856431752384512ull,
4584713381960671270ull, 8807052963476652830ull, 54136601502601741ull, 4872702333905478703ull, 5551030319979516287ull, 12889366755535460989ull, 16329242193178844328ull, 412018088475211848ull, 10505784623379650541ull, 9758812378619434837ull, 7421979329386275117ull, 375240370024755551ull,
3331431125640721931ull, 15684937309956309981ull, 578521833432107983ull, 14379242000670861838ull, 17922409828154900976ull, 8153494278429192257ull, 15904673920630731971ull, 11217863998460634216ull, 3301540195510742136ull, 9937973023749922003ull, 3059102938155026419ull, 1895288289490976132ull,
5580912693628927540ull, 10064804080494788323ull, 9582481583369602410ull, 10186259561546797986ull, 247426333829703916ull, 13193193905461376067ull, 6386232593701758044ull, 17954717245501896472ull, 1531720443376282699ull, 2455761864255501970ull, 11234429217864304495ull, 4746959618548874102ull,
11921381764981422944ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
10318423381711320787ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
8291411502347000766ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
229948027109387563ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
9152521390190983261ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
7129306032690285515ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
15395989607365232011ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
8641397269074305925ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
17256848792241043600ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
6046475228902245682ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
12041608676381094092ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
12785542378683951657ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
14546032085337914034ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
3304199118235116851ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
16499627707072547655ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
10386478025625759321ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
13475579315436919170ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
16042710511297532028ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
1411266850385657080ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
9024840976168649958ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
14047056970978379368ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
838728605080212101ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull, 0ull,
13571697342473846203ull, 17477857865056504753ull, 15963032953523553760ull, 16033593225279635898ull, 14252634232868282405ull, 8219748254835277737ull, 7459165569491914711ull, 15855939513193752003ull, 16788866461340278896ull, 7102224659693946577ull, 3024718005636976471ull, 13695468978618890430ull,
8214202050877825436ull, 2670727992739346204ull, 16259532062589659211ull, 11869922396257088411ull, 3179482916972760137ull, 13525476046633427808ull, 3217337278042947412ull, 14494689598654046340ull, 15837379330312175383ull, 8029037639801151344ull, 2153456285263517937ull, 8301106462311849241ull,
13294194396455217955ull, 17394768489610594315ull, 12847609130464867455ull, 14015739446356528640ull, 5879251655839607853ull, 9747000124977436185ull, 8950393546890284269ull, 10765765936405694368ull, 14695323910334139959ull, 16366254691123000864ull, 15292774414889043182ull, 10910394433429313384ull,
17253424460214596184ull, 3442854447664030446ull, 3005570425335613727ull, 10859158614900201063ull, 9763230642109343539ull, 6647722546511515039ull, 909012944955815706ull, 18101204076790399111ull, 11588128829349125809ull, 15863878496612806566ull, 5201119062417750399ull, 176665553780565743ull};

#define ROUNDS_F 8
#define ROUNDS_P 22
Expand All @@ -67,18 +67,19 @@ __device__ __forceinline__ void apply_m_4(gl64_t *x)
inline void apply_m_4(GoldilocksField *x)
#endif
{
auto t0 = x[0] + x[1];
auto t1 = x[2] + x[3];
auto t2 = x[1] + x[1] + t1;
auto t3 = x[3] + x[3] + t0;
auto t4 = t1 + t1 + t1 + t1 + t3;
auto t5 = t0 + t0 + t0 + t0 + t2;
auto t6 = t3 + t5;
auto t7 = t2 + t4;
x[0] = t6;
x[1] = t5;
x[2] = t7;
x[3] = t4;
auto t01 = x[0] + x[1];
auto t23 = x[2] + x[3];
auto t0123 = t01 + t23;
auto t01123 = t0123 + x[1];
auto t01233 = t0123 + x[3];
auto new_x3 = t01233 + x[0] + x[0];
auto new_x1 = t01123 + x[2] + x[2];
auto new_x0 = t01123 + t01;
auto new_x2 = t01233 + t23;
x[0] = new_x0;
x[1] = new_x1;
x[2] = new_x2;
x[3] = new_x3;
}

#ifdef USE_CUDA
Expand Down
7 changes: 5 additions & 2 deletions wrappers/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ rayon = "1.8.1"
rustacuda = "0.1"
rustacuda_core = "0.1"
rustacuda_derive = "0.1"
plonky2 = {git="https://github.com/okx/plonky2.git", rev="9a917ba27c26aca6d0e5d9760e8575cd5fc8dd0a"}
plonky2_field = {git="https://github.com/okx/plonky2.git", rev="9a917ba27c26aca6d0e5d9760e8575cd5fc8dd0a"}
# plonky2 = {git="https://github.com/okx/plonky2.git", rev="9a917ba27c26aca6d0e5d9760e8575cd5fc8dd0a"}
# plonky2_field = {git="https://github.com/okx/plonky2.git", rev="9a917ba27c26aca6d0e5d9760e8575cd5fc8dd0a"}
plonky2 = { path = "../../../plonky2/plonky2" }
plonky2_field = { path = "../../../plonky2/field" }


[build-dependencies]
cc = "^1.0.70"
Expand Down
20 changes: 17 additions & 3 deletions wrappers/rust/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@ extern crate rustacuda;
// based on: https://github.com/matter-labs/z-prize-msm-gpu/blob/main/bellman-cuda-rust/cudart-sys/build.rs
#[cfg(not(feature = "no_cuda"))]
fn build_device_wrapper() {
let cuda_runtime_api_path = PathBuf::from("/usr/local/cuda/include")
let cuda_runtime_api_path = PathBuf::from("/usr/local/cuda-12.8/include")
.join("cuda_runtime_api.h")
.to_string_lossy()
.to_string();
let binding_path = PathBuf::from("src/device").join("bindings.rs");
println!("cargo:rustc-link-search=native={}", "/usr/local/cuda/lib64");
println!(
"cargo:rustc-link-search=native={}",
"/usr/local/cuda-12.8/targets/x86_64-linux/lib"
);
println!(
"cargo:rustc-link-search=native={}",
"/usr/lib/x86_64-linux-gnu"
);
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rerun-if-changed={}", cuda_runtime_api_path);
println!(
Expand Down Expand Up @@ -107,7 +114,14 @@ fn build_lib() {
println!("cargo:rustc-link-search={}", libdir.to_str().unwrap());

// Static lib
println!("cargo:rustc-link-search=native={}", "/usr/local/cuda/lib64");
println!(
"cargo:rustc-link-search=native={}",
"/usr/local/cuda-12.8/targets/x86_64-linux/lib"
);
println!(
"cargo:rustc-link-search=native={}",
"/usr/lib/x86_64-linux-gnu"
);
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=stdc++");
println!("cargo:rustc-link-lib=static=zeknox");
Expand Down
28 changes: 14 additions & 14 deletions wrappers/rust/run_rust_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ if [ -z "$NUM_OF_GPUS" ]; then
fi
echo "Running the tests on ${NUM_OF_GPUS} GPU(s)."

cargo test --features=gl64 --test device -- test_get_number_of_gpus --exact --nocapture
cargo test --features=gl64 --test device -- test_list_devices_info_rs --exact --nocapture
cargo test --features=gl64 --test ntt -- test_intt_gl64_consistency_with_plonky2 --exact --nocapture
cargo test --features=gl64 --test ntt -- test_ntt_batch_gl64_consistency_with_plonky2 --exact --nocapture
cargo test --features=gl64 --test ntt -- test_ntt_batch_intt_batch_gl64_self_consistency --exact --nocapture
cargo test --features=gl64 --test ntt -- test_intt_batch_gl64_consistency_with_plonky2 --exact --nocapture
cargo test --features=gl64 --test ntt -- test_ntt_on_device --exact --nocapture
cargo test --features=gl64 --test ntt -- test_ntt_batch_on_device --exact --nocapture
cargo test --features=gl64 --test ntt -- test_ntt_batch_with_coset --exact --nocapture
cargo test --features=gl64 --test ntt -- test_compute_batched_lde --exact --nocapture
cargo test --features=gl64 --test ntt -- test_compute_batched_lde_data_on_device --exact --nocapture
cargo test --features=gl64 --test ntt -- test_transpose_rev --exact --nocapture
cargo test --features=gl64 --test merkle_tree -- --exact --nocapture
cargo +nightly test --features=gl64 --test device -- test_get_number_of_gpus --exact --nocapture
cargo +nightly test --features=gl64 --test device -- test_list_devices_info_rs --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_intt_gl64_consistency_with_plonky2 --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_ntt_batch_gl64_consistency_with_plonky2 --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_ntt_batch_intt_batch_gl64_self_consistency --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_intt_batch_gl64_consistency_with_plonky2 --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_ntt_on_device --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_ntt_batch_on_device --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_ntt_batch_with_coset --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_compute_batched_lde --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_compute_batched_lde_data_on_device --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_transpose_rev --exact --nocapture
cargo +nightly test --features=gl64 --test merkle_tree -- --exact --nocapture

if [ $NUM_OF_GPUS -gt 1 ]; then
cargo test --features=gl64 --test ntt -- test_compute_batched_lde_multi_gpu_data_on_one_gpu --exact --nocapture
cargo +nightly test --features=gl64 --test ntt -- test_compute_batched_lde_multi_gpu_data_on_one_gpu --exact --nocapture
fi
2 changes: 2 additions & 0 deletions wrappers/rust/src/device/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub enum cudaError {
cudaErrorJitCompilationDisabled = 223,
cudaErrorUnsupportedExecAffinity = 224,
cudaErrorUnsupportedDevSideSync = 225,
cudaErrorContained = 226,
cudaErrorInvalidSource = 300,
cudaErrorFileNotFound = 301,
cudaErrorSharedObjectSymbolNotFound = 302,
Expand Down Expand Up @@ -109,6 +110,7 @@ pub enum cudaError {
cudaErrorInvalidPc = 718,
cudaErrorLaunchFailure = 719,
cudaErrorCooperativeLaunchTooLarge = 720,
cudaErrorTensorMemoryLeak = 721,
cudaErrorNotPermitted = 800,
cudaErrorNotSupported = 801,
cudaErrorSystemNotReady = 802,
Expand Down
44 changes: 37 additions & 7 deletions wrappers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ extern "C" {

fn init_cuda_degree(max_degree: usize);

fn clear_cuda_errors_all_devices();

fn compute_batched_ntt(
device_id: usize,
inout: *mut core::ffi::c_void,
Expand Down Expand Up @@ -162,16 +164,28 @@ pub fn lde_batch_multi_gpu<T>(
}
}

pub fn ntt_batch<T>(
device_id: usize,
inout: *mut T, // &mut [T],
log_n_size: usize,
cfg: NTTConfig,
) {
pub fn ntt_batch<T>(device_id: usize, inout: &mut [T], log_n_size: usize, cfg: NTTConfig) {
// println!("ntt_batch log n size: {log_n_size}");
let err = unsafe {
compute_batched_ntt(
device_id,
inout.as_mut_ptr() as *mut core::ffi::c_void,
log_n_size,
types::NTTDirection::Forward,
cfg,
)
};

if err.code != 0 {
panic!("{}", String::from(err));
}
}

/// NTT batch with raw pointer (for GPU-resident data)
pub fn ntt_batch_ptr<T>(device_id: usize, inout: *mut T, log_n_size: usize, cfg: NTTConfig) {
let err = unsafe {
compute_batched_ntt(
device_id,
// inout.as_mut_ptr() as *mut core::ffi::c_void,
inout as *mut core::ffi::c_void,
log_n_size,
types::NTTDirection::Forward,
Expand Down Expand Up @@ -249,3 +263,19 @@ pub fn init_cuda_degree_rs(max_degree: usize) {
init_cuda_degree(max_degree);
}
}

/// Clears CUDA error state across all devices.
///
/// This function should be called between tests or after failed operations
/// to prevent error state propagation. CUDA errors are "sticky" - once an error
/// occurs, it persists until explicitly cleared with cudaGetLastError().
///
/// This function:
/// - Iterates through all available GPUs
/// - Clears the error queue on each device
/// - Synchronizes all streams to ensure operations complete
pub fn clear_cuda_errors_rs() {
unsafe {
clear_cuda_errors_all_devices();
}
}
Loading