From 03348b04056875ccbded8ad2bd4ed595576ab4f7 Mon Sep 17 00:00:00 2001 From: perrydv Date: Mon, 17 Jun 2024 20:23:23 -0700 Subject: [PATCH] Defer initEpsilon() to start of next iteration --- nimbleHMC/R/HMC_samplers.R | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/nimbleHMC/R/HMC_samplers.R b/nimbleHMC/R/HMC_samplers.R index 56b8c23..816970b 100644 --- a/nimbleHMC/R/HMC_samplers.R +++ b/nimbleHMC/R/HMC_samplers.R @@ -896,6 +896,7 @@ sampler_NUTS <- nimbleFunction( if(d > 1) if(length(M) != d) stop('length of NUTS sampler M must match length of NUTS target nodes', call. = FALSE) if(maxTreeDepth < 1) stop('NUTS maxTreeDepth must be at least one', call. = FALSE) hmc_checkWarmup(warmupMode, warmup, 'NUTS') + init_epsilon_next_iter <- FALSE }, run = function() { ## No-U-Turn Sampler based on Stan @@ -910,6 +911,14 @@ sampler_NUTS <- nimbleFunction( mu <<- log(10*epsilon) ## curiously, Stan sets this for the first round *before* init_stepsize if(initializeEpsilon & adaptive) initEpsilon() } + if(init_epsilon_next_iter) { + if(initializeEpsilon) initEpsilon() + Hbar <<- 0 + logEpsilonBar <<- 0 + stepsizeCounter <<- 0 + mu <<- log(10*epsilon) + init_epsilon_next_iter <<- FALSE + } timesRan <<- timesRan + 1 if(printTimesRan) print('============ times ran = ', timesRan) if(printEpsilon) print('epsilon = ', epsilon) @@ -1003,11 +1012,7 @@ sampler_NUTS <- nimbleFunction( update <- FALSE if(adaptM) update <- adapt_M() if(update & adaptEpsilon) { - if(initializeEpsilon) initEpsilon() - Hbar <<- 0 - logEpsilonBar <<- 0 - stepsizeCounter <<- 0 - mu <<- log(10*epsilon) + init_epsilon_next_iter <<- TRUE } } }, @@ -1265,6 +1270,7 @@ sampler_NUTS <- nimbleFunction( Hbar <<- 0 logEpsilonBar <<- 0 stepsizeCounter <<- 0 + init_epsilon_next_iter <<- FALSE setSize(warmupSamples, adaptWindow_size, d, fillZeros = FALSE) } } @@ -1306,6 +1312,7 @@ sampler_NUTS <- nimbleFunction( adaptWindow_counter <<- 0 adaptWindow_iter <<- 0 stepsizeCounter <<- 0 + init_epsilon_next_iter <<- FALSE } ), buildDerivs = list(