diff --git a/CMakeLists.txt b/CMakeLists.txt index 101712bd4a..eea9b89fb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -137,6 +137,15 @@ celeritas_setup_option(CELERITAS_CORE_RNG hipRAND CELERITAS_USE_HIP) celeritas_define_options(CELERITAS_CORE_RNG "Celeritas runtime random number generator") +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# CELERITAS_RESEED +# Random number generator selection +#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +celeritas_setup_option(CELERITAS_RESEED trackslot) +celeritas_setup_option(CELERITAS_RESEED track) +celeritas_define_options(CELERITAS_RESEED + "RNG reseeding strategy") + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # CELERITAS_CORE_GEO # Runtime geometry selection diff --git a/src/accel/LocalTransporter.cc b/src/accel/LocalTransporter.cc index 85a33b6848..ca8df217dd 100644 --- a/src/accel/LocalTransporter.cc +++ b/src/accel/LocalTransporter.cc @@ -293,6 +293,10 @@ void LocalTransporter::Push(G4Track& g4track) */ track.event_id = EventId{0}; + // Add track id and initialize step counter + track.track_id = g4track.GetTrackId(); + track.step_count = g4track.GetCurrentStepNumber(); + buffer_.push_back(track); buffer_accum_.energy += track.energy.value(); if (buffer_.size() >= auto_flush_) diff --git a/src/celeritas/global/CoreTrackView.hh b/src/celeritas/global/CoreTrackView.hh index 9ea8dc0159..61fa47a4c1 100644 --- a/src/celeritas/global/CoreTrackView.hh +++ b/src/celeritas/global/CoreTrackView.hh @@ -162,6 +162,11 @@ CoreTrackView::operator=(TrackInitializer const& init) // Initialize the simulation state this->sim() = init.sim; +// Initializer RNG state +#if CELERITAS_RESEED == CELERITAS_RESEED_TRACK + this->rng() = init.rng; +#endif + // Initialize the particle attributes this->particle() = init.particle; diff --git a/src/celeritas/phys/Primary.hh b/src/celeritas/phys/Primary.hh index 7556f232ff..aa7d751d4d 100644 --- a/src/celeritas/phys/Primary.hh +++ b/src/celeritas/phys/Primary.hh @@ -27,7 +27,9 @@ struct Primary Real3 direction{0, 0, 0}; real_type time{}; EventId event_id; - PrimaryId primary_id; + unsigned int geant_track_id{}; + unsigned int geant_step_count{}; + PrimaryId primary_id; //! todo: remove this in the future. real_type weight{1.0}; }; diff --git a/src/celeritas/phys/PrimaryGenerator.cc b/src/celeritas/phys/PrimaryGenerator.cc index 1a70eac72d..eafff2ee4a 100644 --- a/src/celeritas/phys/PrimaryGenerator.cc +++ b/src/celeritas/phys/PrimaryGenerator.cc @@ -183,6 +183,8 @@ auto PrimaryGenerator::operator()() -> result_type p.direction = sample_dir_(rng_); p.time = 0; p.event_id = EventId{event_count_}; + p.primary_id = PrimaryId{event_count_ * primaries_per_event_ + i}; + p.geant_track_id = p.primary_id.get(); } ++event_count_; return result; diff --git a/src/celeritas/track/TrackInitData.hh b/src/celeritas/track/TrackInitData.hh index 8b2c32de71..aac4201dbe 100644 --- a/src/celeritas/track/TrackInitData.hh +++ b/src/celeritas/track/TrackInitData.hh @@ -11,6 +11,7 @@ #include "corecel/data/Collection.hh" #include "corecel/data/CollectionAlgorithms.hh" #include "corecel/data/CollectionBuilder.hh" +#include "corecel/random/data/RngData.hh" #include "corecel/sys/Device.hh" #include "corecel/sys/ThreadId.hh" #include "geocel/Types.hh" @@ -69,6 +70,7 @@ struct TrackInitializer SimTrackInitializer sim; GeoTrackInitializer geo; ParticleTrackInitializer particle; + RngStateInitializer rng; //! True if assigned and valid explicit CELER_FUNCTION operator bool() const diff --git a/src/celeritas/track/detail/ProcessPrimariesExecutor.hh b/src/celeritas/track/detail/ProcessPrimariesExecutor.hh index e8d856f733..43b4e9834a 100644 --- a/src/celeritas/track/detail/ProcessPrimariesExecutor.hh +++ b/src/celeritas/track/detail/ProcessPrimariesExecutor.hh @@ -9,6 +9,10 @@ #include "corecel/Assert.hh" #include "corecel/Macros.hh" #include "corecel/cont/Span.hh" +#include "corecel/random/engine/RanluxppRngEngine.hh" +#include "corecel/random/engine/RngEngine.hh" +#include "corecel/random/engine/SplitMix64.hh" +#include "corecel/random/engine/XorwowRngEngine.hh" #include "celeritas/Quantities.hh" #include "celeritas/Types.hh" #include "celeritas/global/CoreTrackData.hh" @@ -45,6 +49,23 @@ struct ProcessPrimariesExecutor // Create track initializers from primaries inline CELER_FUNCTION void operator()(ThreadId tid) const; + + private: + // Fill the RNG state initializer for the Xorwow engine + inline CELER_FUNCTION void fillRngStateInitializer( + unsigned int seed, + unsigned int event_id, + unsigned int geant_track_id, + unsigned int geant_step_id, + XorwowRngEngine::RngStateInitializer_t& rng_init) const; + + // Fill the RNG state initializer for the Ranluxpp engine + void fillRngStateInitializer( + unsigned int seed, + unsigned int event_id, + unsigned int geant_track_id, + unsigned int geant_step_id, + RanluxppRngEngine::RngStateInitializer_t& rng_init) const; }; //---------------------------------------------------------------------------// @@ -72,11 +93,85 @@ CELER_FUNCTION void ProcessPrimariesExecutor::operator()(ThreadId tid) const ti.particle.particle_id = primary.particle_id; ti.particle.energy = primary.energy; +// Set the RNG state initializer appropriately dispatched on RNG type +#if CELERITAS_RESEED == CELERITAS_RESEED_TRACK + this->fillRngStateInitializer(params->rng.get_seed(), + ti.sim.event_id.get(), + primary.geant_track_id, + primary.geant_step_count, + ti.rng); +#endif + // Store the initializer size_type idx = counters->num_initializers - primaries.size() + tid.get(); state->init.initializers[ItemId(idx)] = ti; } +//---------------------------------------------------------------------------// +/*! + * Fill a XorwowRngEngine state initializer + */ +CELER_FUNCTION void ProcessPrimariesExecutor::fillRngStateInitializer( + unsigned int seed, + unsigned int event_id, + unsigned int geant_track_id, + unsigned int geant_step_id, + XorwowRngEngine::RngStateInitializer_t& rng_init) const +{ + // Initialize SplitMix64 with the seed XORed with the track id + SplitMix64 rng(seed ^ geant_track_id); + + // Fill first two state values + std::uint64_t val = rng(); + rng_init.xorstate[0] = static_cast(val); + rng_init.xorstate[1] = static_cast(val >> 32); + + // XOR with event id + rng.xor_state(event_id); + val = rng(); + rng_init.xorstate[2] = static_cast(val); + rng_init.xorstate[3] = static_cast(val >> 32); + + // XOR with step id + rng.xor_state(geant_step_id); + val = rng(); + rng_init.xorstate[4] = static_cast(val); + rng_init.weylstate = static_cast(val >> 32); +} + +//---------------------------------------------------------------------------// +/*! + * Fill a Ranluxpp state initializer + */ +CELER_FUNCTION +void ProcessPrimariesExecutor::fillRngStateInitializer( + unsigned int seed, + unsigned int event_id, + unsigned int geant_track_id, + unsigned int geant_step_id, + RanluxppRngEngine::RngStateInitializer_t& rng_init) const +{ + // Initialize SplitMix64 with the seed XORed with the track id + SplitMix64 rng(seed ^ geant_track_id); + + // Fill first three state values + rng_init.value.number[0] = rng(); + rng_init.value.number[1] = rng(); + rng_init.value.number[2] = rng(); + + // XOR with event id and fill next three values + rng.xor_state(event_id); + rng_init.value.number[3] = rng(); + rng_init.value.number[4] = rng(); + rng_init.value.number[5] = rng(); + + // XOR with step id and fill next three values + rng.xor_state(geant_step_id); + rng_init.value.number[6] = rng(); + rng_init.value.number[7] = rng(); + rng_init.value.number[8] = rng(); +} + //---------------------------------------------------------------------------// } // namespace detail } // namespace celeritas diff --git a/src/celeritas/track/detail/ProcessSecondariesExecutor.hh b/src/celeritas/track/detail/ProcessSecondariesExecutor.hh index f8c7d88025..a90cce9380 100644 --- a/src/celeritas/track/detail/ProcessSecondariesExecutor.hh +++ b/src/celeritas/track/detail/ProcessSecondariesExecutor.hh @@ -84,8 +84,11 @@ ProcessSecondariesExecutor::operator()(TrackSlotId tid) const // initialized in this slot TrackId const track_id{sim.track_id()}; - for (auto const& secondary : track.physics_step().secondaries()) + for (unsigned int secondary_idx : + celeritas::range(track.physics_step().secondaries().size())) { + auto const& secondary + = track.physics_step().secondaries()[secondary_idx]; if (secondary) { CELER_ASSERT(secondary.energy > zero_quantity() @@ -108,6 +111,9 @@ ProcessSecondariesExecutor::operator()(TrackSlotId tid) const ti.geo.dir = secondary.direction; ti.particle.particle_id = secondary.particle_id; ti.particle.energy = secondary.energy; +#if CELERITAS_RESEED == CELERITAS_RESEED_TRACK + ti.rng = track.rng().branch(); +#endif CELER_ASSERT(ti); if (sim.track_id() == track_id && sim.status() != TrackStatus::alive diff --git a/src/corecel/CMakeLists.txt b/src/corecel/CMakeLists.txt index d13c478b97..3fc7d30e7f 100644 --- a/src/corecel/CMakeLists.txt +++ b/src/corecel/CMakeLists.txt @@ -39,6 +39,7 @@ celeritas_generate_option_config(CELERITAS_CORE_RNG) celeritas_generate_option_config(CELERITAS_OPENMP) celeritas_generate_option_config(CELERITAS_REAL_TYPE) celeritas_generate_option_config(CELERITAS_UNITS) +celeritas_generate_option_config(CELERITAS_RESEED) #----------------------------------------------------------------------------# # Detailed build configuration for reproducibility/provenance diff --git a/src/corecel/Config.hh.in b/src/corecel/Config.hh.in index 73d6eb9fde..a141c1c32b 100644 --- a/src/corecel/Config.hh.in +++ b/src/corecel/Config.hh.in @@ -61,6 +61,8 @@ @CELERITAS_CORE_RNG_CONFIG@ +@CELERITAS_RESEED_CONFIG@ + //---------------------------------------------------------------------------// // System-specific properties for Celeritas //---------------------------------------------------------------------------// diff --git a/src/corecel/random/data/RanluxppRngData.hh b/src/corecel/random/data/RanluxppRngData.hh index 00f3629173..5a3a4403ba 100644 --- a/src/corecel/random/data/RanluxppRngData.hh +++ b/src/corecel/random/data/RanluxppRngData.hh @@ -36,6 +36,10 @@ struct RanluxppRngParamsDataImpl RanluxppArray9 advance_sequence; //// FUNCTIONS //// + + //! Get the seed as a 64-bit unsigned integer + std::uint64_t get_seed() const { return seed; } + //! Whether the data is assigned. explicit CELER_CONSTEXPR_FUNCTION operator bool() const { diff --git a/src/corecel/random/data/RngData.hh b/src/corecel/random/data/RngData.hh index 6b08b1af7c..04e7b6ccd9 100644 --- a/src/corecel/random/data/RngData.hh +++ b/src/corecel/random/data/RngData.hh @@ -28,6 +28,7 @@ template using RngParamsData = XorwowRngParamsData; template using RngStateData = XorwowRngStateData; +using RngStateInitializer = XorwowRngStateInitializer; } // namespace celeritas #elif (CELERITAS_CORE_RNG == CELERITAS_CORE_RNG_RANLUXPP) # include "RanluxppRngData.hh" @@ -37,6 +38,7 @@ template using RngParamsData = RanluxppRngParamsData; template using RngStateData = RanluxppRngStateData; +using RngStateInitializer = RanluxppRngStateInitializer; } // namespace celeritas #endif // IWYU pragma: end_exports diff --git a/src/corecel/random/data/XorwowRngData.hh b/src/corecel/random/data/XorwowRngData.hh index 969c982578..3c80e3ad2b 100644 --- a/src/corecel/random/data/XorwowRngData.hh +++ b/src/corecel/random/data/XorwowRngData.hh @@ -56,6 +56,9 @@ struct XorwowRngParamsData return 8 * sizeof(XorwowUInt); } + //! Retrieve the seed as a 64-bit unsigned integer + std::uint64_t get_seed() const { return seed[0]; } + //! Whether the data is assigned explicit CELER_FUNCTION operator bool() const { return true; } diff --git a/src/corecel/random/engine/SplitMix64.hh b/src/corecel/random/engine/SplitMix64.hh index c1624558f4..99a424d038 100644 --- a/src/corecel/random/engine/SplitMix64.hh +++ b/src/corecel/random/engine/SplitMix64.hh @@ -28,6 +28,9 @@ class SplitMix64 // Produce a random number inline CELER_FUNCTION std::uint64_t operator()(); + // XOR this state with another + inline CELER_FUNCTION void xor_state(std::uint64_t state); + private: // SplitMix64 State std::uint64_t state_; @@ -58,5 +61,14 @@ CELER_FUNCTION std::uint64_t SplitMix64::operator()() return z ^ (z >> 31); } +//---------------------------------------------------------------------------// +/*! + * Perform a XOR operation of this state with another state. + */ +CELER_FUNCTION void SplitMix64::xor_state(std::uint64_t state) +{ + state_ ^= state; +} + //---------------------------------------------------------------------------// } // namespace celeritas diff --git a/test/celeritas/global/StepperTestBase.cc b/test/celeritas/global/StepperTestBase.cc index a358f71099..e0476898dc 100644 --- a/test/celeritas/global/StepperTestBase.cc +++ b/test/celeritas/global/StepperTestBase.cc @@ -86,6 +86,11 @@ auto StepperTestBase::run(StepperInterface& step, { // Perform first step auto primaries = this->make_primaries(num_primaries); + for (auto i : celeritas::range(primaries.size())) + { + primaries[i].geant_track_id = i; + } + StepperResult counts; CELER_TRY_HANDLE(counts = step(make_span(primaries)), LogContextException{this->output_reg().get()}); diff --git a/test/celeritas/io/EventIOTestBase.cc b/test/celeritas/io/EventIOTestBase.cc index c8ece564a2..886f15eb46 100644 --- a/test/celeritas/io/EventIOTestBase.cc +++ b/test/celeritas/io/EventIOTestBase.cc @@ -131,6 +131,8 @@ void EventIOTestBase::write_test_event(Writer& write_event) const Real3{1, 0, 0}, 5.67e-9 * units::second, EventId{0}, + 0, + 0, PrimaryId{}}; Primary proton{proton_id, MevEnergy{2.34}, @@ -138,6 +140,8 @@ void EventIOTestBase::write_test_event(Writer& write_event) const Real3{0, 1, 0}, 5.78e-9 * units::second, EventId{0}, + 0, + 0, PrimaryId{}}; std::vector primaries{gamma, proton, gamma, proton}; primaries[1].position = from_cm(Real3{-3, -4, 5}); diff --git a/test/celeritas/phys/PrimaryGenerator.test.cc b/test/celeritas/phys/PrimaryGenerator.test.cc index 340eaee1d9..8ace61062a 100644 --- a/test/celeritas/phys/PrimaryGenerator.test.cc +++ b/test/celeritas/phys/PrimaryGenerator.test.cc @@ -67,6 +67,7 @@ TEST_F(PrimaryGeneratorTest, basic) std::vector particle_id; std::vector event_id; + std::vector primary_id; for (size_type i = 0; i < inp.num_events; ++i) { @@ -81,6 +82,7 @@ TEST_F(PrimaryGeneratorTest, basic) EXPECT_TRUE(is_soft_unit_vector(p.direction)); particle_id.push_back(p.particle_id.unchecked_get()); event_id.push_back(p.event_id.unchecked_get()); + primary_id.push_back(p.primary_id.unchecked_get()); } } auto primaries = generate_primaries(); @@ -88,9 +90,11 @@ TEST_F(PrimaryGeneratorTest, basic) static int const expected_particle_id[] = {0, 1, 0, 0, 1, 0}; static int const expected_event_id[] = {0, 0, 0, 1, 1, 1}; + static int const expected_primary_id[] = {0, 1, 2, 3, 4, 5}; EXPECT_VEC_EQ(expected_particle_id, particle_id); EXPECT_VEC_EQ(expected_event_id, event_id); + EXPECT_VEC_EQ(expected_primary_id, primary_id); } TEST_F(PrimaryGeneratorTest, options) diff --git a/test/corecel/random/SplitMix64.test.cc b/test/corecel/random/SplitMix64.test.cc index 26649eb02b..1a15aa3a11 100644 --- a/test/corecel/random/SplitMix64.test.cc +++ b/test/corecel/random/SplitMix64.test.cc @@ -14,7 +14,7 @@ namespace test { //---------------------------------------------------------------------------// -TEST(SplitMix64Test, host) +TEST(SplitMix64Test, draw_rng) { celeritas::SplitMix64 sm(12345); @@ -25,6 +25,24 @@ TEST(SplitMix64Test, host) EXPECT_EQ(9350289611492784363ul, sm()); } +TEST(SplitMix64Test, xor) +{ + std::uint64_t state1 = 12345; + std::uint64_t state2 = 98765; + + // Create a test RNG for testing XOR function + celeritas::SplitMix64 test_rng(state1); + test_rng.xor_state(state2); + + // Create a pre-xored RNG + celeritas::SplitMix64 ref_rng(state1 ^ state2); + + for (auto i : celeritas::range(10)) + { + EXPECT_EQ(ref_rng(), test_rng()); + } +} + //---------------------------------------------------------------------------// } // namespace test } // namespace celeritas