diff --git a/nimbleHMC/R/HMC_samplers.R b/nimbleHMC/R/HMC_samplers.R index 599f897..e720564 100644 --- a/nimbleHMC/R/HMC_samplers.R +++ b/nimbleHMC/R/HMC_samplers.R @@ -879,6 +879,7 @@ sampler_NUTS <- nimbleFunction( warningCodes <- array(0, c(max(numWarnings,1), 2)) numDivergences <- 0 numTimesMaxTreeDepth <- 0 + init_epsilon_next_iter <- FALSE ADreset <- 1 ## nimbleLists treebranchNL <- treebranchNL_NUTS ## reference input to buildtree @@ -913,6 +914,14 @@ sampler_NUTS <- nimbleFunction( mu <<- log(10*epsilon) ## curiously, Stan sets this for the first round *before* init_stepsize if(initializeEpsilon & adaptive) initEpsilon() } + if(init_epsilon_next_iter) { + if(initializeEpsilon) initEpsilon() + Hbar <<- 0 + logEpsilonBar <<- 0 + stepsizeCounter <<- 0 + mu <<- log(10*epsilon) + init_epsilon_next_iter <<- FALSE + } timesRan <<- timesRan + 1 if(printTimesRan) print('============ times ran = ', timesRan) if(printEpsilon) print('epsilon = ', epsilon) @@ -1003,16 +1012,7 @@ sampler_NUTS <- nimbleFunction( if(adaptEpsilon) adapt_stepsize(accept_prob) update <- FALSE if(adaptM) update <- adapt_M() - if(update & adaptEpsilon) { - if(initializeEpsilon) { - inverseTransformStoreCalculate(state_sample$q) ## defensively ensure model states are up to date. - initEpsilon() - } - Hbar <<- 0 - logEpsilonBar <<- 0 - stepsizeCounter <<- 0 - mu <<- log(10*epsilon) - } + if(update & adaptEpsilon) init_epsilon_next_iter <<- TRUE } inverseTransformStoreCalculate(state_sample$q) nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE) @@ -1273,6 +1273,7 @@ sampler_NUTS <- nimbleFunction( Hbar <<- 0 logEpsilonBar <<- 0 stepsizeCounter <<- 0 + init_epsilon_next_iter <<- FALSE setSize(warmupSamples, adaptWindow_size, d, fillZeros = FALSE) } } @@ -1306,6 +1307,7 @@ sampler_NUTS <- nimbleFunction( warningInd <<- 0 M <<- Morig sqrtM <<- sqrt(M) + ADreset <<- 1 ## the adapt_* variables are initialized in before_chain() adaptWindow_size <<- 0 adapt_initBuffer <<- 0 @@ -1314,7 +1316,7 @@ sampler_NUTS <- nimbleFunction( adaptWindow_counter <<- 0 adaptWindow_iter <<- 0 stepsizeCounter <<- 0 - ADreset <<- 1 + init_epsilon_next_iter <<- FALSE } ), buildDerivs = list(