diff --git a/nimbleHMC/R/HMC_samplers.R b/nimbleHMC/R/HMC_samplers.R index 7b658e6..599f897 100644 --- a/nimbleHMC/R/HMC_samplers.R +++ b/nimbleHMC/R/HMC_samplers.R @@ -76,6 +76,7 @@ sampler_langevin <- nimbleFunction( empirSamp <- matrix(0, nrow = adaptInterval, ncol = d) timesRan <- 0 timesAdapted <- 0 + ADreset <- 1 ## checks if(!isTRUE(nimbleOptions('enableDerivs'))) stop('must enable NIMBLE derivatives, set nimbleOptions(enableDerivs = TRUE)', call. = FALSE) if(!isTRUE(model$modelDef[['buildDerivs']])) stop('must set buildDerivs = TRUE when building model', call. = FALSE) @@ -97,7 +98,8 @@ sampler_langevin <- nimbleFunction( methods = list( jacobian = function(q = double(2)) { values(model, target) <<- q[1:d, 1] - derivsOutput <- nimDerivs(model$calculate(calcNodes), order = 1, wrt = target) + derivsOutput <- nimDerivs(model$calculate(calcNodes), order = 1, wrt = target, reset = ADreset) + if(ADreset == 1) ADreset <<- 0 grad[1:d, 1] <<- derivsOutput$jacobian[1, 1:d] returnType(double(2)) return(grad) @@ -120,6 +122,7 @@ sampler_langevin <- nimbleFunction( timesAdapted <<- 0 scaleVec <<- matrix(1, nrow = d, ncol = 1) epsilonVec <<- scale * scaleVec + ADreset <<- 1 } ) ) @@ -394,6 +397,7 @@ sampler_NUTS_classic <- nimbleFunction( sqrtM <- sqrt(M) numDivergences <- 0 numTimesMaxTreeDepth <- 0 + ADreset <- 1 ## nimbleLists qpNLDef <- nimbleList(q = double(1), p = double(1)) btNLDef <- nimbleList(q1 = double(1), p1 = double(1), q2 = double(1), p2 = double(1), q3 = double(1), n = double(), s = double(), a = double(), na = double()) @@ -486,12 +490,14 @@ sampler_NUTS_classic <- nimbleFunction( return(ans) }, gradient_aux = function(qArg = double(1)) { - derivsOutput <- nimDerivs(calcLogProb(qArg), order = 1, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes) + derivsOutput <- nimDerivs(calcLogProb(qArg), order = 1, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes, reset = ADreset) + if(ADreset == 1) ADreset <<- 0 returnType(double(1)) return(derivsOutput$jacobian[1, 1:d]) }, gradient = function(qArg = double(1)) { - derivsOutput <- nimDerivs(gradient_aux(qArg), order = 0, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes) + derivsOutput <- nimDerivs(gradient_aux(qArg), order = 0, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes, reset = ADreset) + if(ADreset == 1) ADreset <<- 0 grad <<- derivsOutput$value }, leapfrog = function(qArg = double(1), pArg = double(1), eps = double(), first = double(), v = double()) { @@ -702,6 +708,7 @@ sampler_NUTS_classic <- nimbleFunction( sqrtM <<- sqrt(M) warmupIntervalNumber <<- 1 warmupIntervalCount <<- 0 + ADreset <<- 1 } ), buildDerivs = list( @@ -872,6 +879,7 @@ sampler_NUTS <- nimbleFunction( warningCodes <- array(0, c(max(numWarnings,1), 2)) numDivergences <- 0 numTimesMaxTreeDepth <- 0 + ADreset <- 1 ## nimbleLists treebranchNL <- treebranchNL_NUTS ## reference input to buildtree stateNL <- stateNL_NUTS ## system state (p, q, H, lp, gr_lp) @@ -1033,12 +1041,14 @@ sampler_NUTS <- nimbleFunction( return(ans) }, gradient_aux = function(qArg = double(1)) { - derivsOutput <- nimDerivs(calcLogProb(qArg), order = 1, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes) + derivsOutput <- nimDerivs(calcLogProb(qArg), order = 1, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes, reset = ADreset) + if(ADreset == 1) ADreset <<- 0 returnType(double(1)) return(derivsOutput$jacobian[1, 1:d]) }, gradient = function(qArg = double(1)) { - derivsOutput <- nimDerivs(gradient_aux(qArg), order = 0, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes) + derivsOutput <- nimDerivs(gradient_aux(qArg), order = 0, wrt = nimDerivs_wrt, model = model, updateNodes = nimDerivs_updateNodes, constantNodes = nimDerivs_constantNodes, reset = ADreset) + if(ADreset == 1) ADreset <<- 0 returnType(double(1)) return(derivsOutput$value) }, @@ -1304,6 +1314,7 @@ sampler_NUTS <- nimbleFunction( adaptWindow_counter <<- 0 adaptWindow_iter <<- 0 stepsizeCounter <<- 0 + ADreset <<- 1 } ), buildDerivs = list(