diff --git a/.gitmodules b/.gitmodules index 97ffbf03be..024d6c5791 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "stan"] path = stan - url = https://github.com/stan-dev/stan + url = https://github.com/bbbales2/stan diff --git a/src/cmdstan/arguments/arg_auto_e.hpp b/src/cmdstan/arguments/arg_auto_e.hpp new file mode 100644 index 0000000000..5cd7700622 --- /dev/null +++ b/src/cmdstan/arguments/arg_auto_e.hpp @@ -0,0 +1,17 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP +#define CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP + +#include + +namespace cmdstan { + + class arg_auto_e: public unvalued_argument { + public: + arg_auto_e() { + _name = "auto_e"; + _description = "Euclidean manifold that chooses between dense/diagonal metric at warmup"; + } + }; + +} +#endif diff --git a/src/cmdstan/arguments/arg_metric.hpp b/src/cmdstan/arguments/arg_metric.hpp index 01dc7a52ac..b24415fabd 100644 --- a/src/cmdstan/arguments/arg_metric.hpp +++ b/src/cmdstan/arguments/arg_metric.hpp @@ -2,6 +2,7 @@ #define CMDSTAN_ARGUMENTS_ARG_METRIC_HPP #include +#include #include #include #include @@ -17,6 +18,7 @@ class arg_metric : public list_argument { _values.push_back(new arg_unit_e()); _values.push_back(new arg_diag_e()); _values.push_back(new arg_dense_e()); + _values.push_back(new arg_auto_e()); _default_cursor = 1; _cursor = _default_cursor; diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index bd527ca953..ded0d2e451 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -424,7 +425,8 @@ int command(int argc, const char *argv[]) { "The number of warmup samples (num_warmup) must be greater than " "zero if adaptation is enabled."); return_code = stan::services::error_codes::CONFIG; - } else if (engine->value() == "nuts" && metric->value() == "dense_e" + +} else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e") && adapt_engaged == false && metric_supplied == false) { int max_depth = dynamic_cast( dynamic_cast( @@ -436,7 +438,7 @@ int command(int argc, const char *argv[]) { num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, interrupt, logger, init_writer, sample_writer, diagnostic_writer); - } else if (engine->value() == "nuts" && metric->value() == "dense_e" + } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e") && adapt_engaged == false && metric_supplied == true) { int max_depth = dynamic_cast( dynamic_cast( @@ -504,7 +506,63 @@ int command(int argc, const char *argv[]) { stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, logger, init_writer, sample_writer, diagnostic_writer); - } else if (engine->value() == "nuts" && metric->value() == "diag_e" + } else if (engine->value() == "nuts" && metric->value() == "auto_e" + && adapt_engaged == true && metric_supplied == false) { + int max_depth = dynamic_cast( + dynamic_cast( + algo->arg("hmc")->arg("engine")->arg("nuts")) + ->arg("max_depth")) + ->value(); + double delta + = dynamic_cast(adapt->arg("delta"))->value(); + double gamma + = dynamic_cast(adapt->arg("gamma"))->value(); + double kappa + = dynamic_cast(adapt->arg("kappa"))->value(); + double t0 = dynamic_cast(adapt->arg("t0"))->value(); + unsigned int init_buffer + = dynamic_cast(adapt->arg("init_buffer")) + ->value(); + unsigned int term_buffer + = dynamic_cast(adapt->arg("term_buffer")) + ->value(); + unsigned int window + = dynamic_cast(adapt->arg("window"))->value(); + return_code = stan::services::sample::hmc_nuts_auto_e_adapt( + model, *init_context, random_seed, id, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, + term_buffer, window, interrupt, logger, init_writer, sample_writer, + diagnostic_writer); + } else if (engine->value() == "nuts" && metric->value() == "auto_e" + && adapt_engaged == true && metric_supplied == true) { + int max_depth = dynamic_cast( + dynamic_cast( + algo->arg("hmc")->arg("engine")->arg("nuts")) + ->arg("max_depth")) + ->value(); + double delta + = dynamic_cast(adapt->arg("delta"))->value(); + double gamma + = dynamic_cast(adapt->arg("gamma"))->value(); + double kappa + = dynamic_cast(adapt->arg("kappa"))->value(); + double t0 = dynamic_cast(adapt->arg("t0"))->value(); + unsigned int init_buffer + = dynamic_cast(adapt->arg("init_buffer")) + ->value(); + unsigned int term_buffer + = dynamic_cast(adapt->arg("term_buffer")) + ->value(); + unsigned int window + = dynamic_cast(adapt->arg("window"))->value(); + return_code = stan::services::sample::hmc_nuts_auto_e_adapt( + model, *init_context, *metric_context, random_seed, id, init_radius, + num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize, + stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer, + term_buffer, window, interrupt, logger, init_writer, sample_writer, + diagnostic_writer); + } else if (engine->value() == "nuts" && metric->value() == "diag_e" && adapt_engaged == false && metric_supplied == false) { categorical_argument *base = dynamic_cast( algo->arg("hmc")->arg("engine")->arg("nuts")); diff --git a/stan b/stan index 2624416e4f..c4f30f6cfb 160000 --- a/stan +++ b/stan @@ -1 +1 @@ -Subproject commit 2624416e4fc712bc042dd4d3c7632aa22051020e +Subproject commit c4f30f6cfb49c1929ecc2a0c9264b36e8755cec8