Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
}

chol_update_arma <- function(R, u, downdate = FALSE, eps = 1e-12) {
.Call(`_bgms_chol_update_arma`, R, u, downdate, eps)
}

get_explog_switch <- function() {
.Call(`_bgms_get_explog_switch`)
}
Expand All @@ -29,6 +33,10 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
}

sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)
}

compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) {
.Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda)
}
Expand Down
36 changes: 36 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,20 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// chol_update_arma
arma::mat chol_update_arma(arma::mat& R, arma::vec& u, bool downdate, double eps);
RcppExport SEXP _bgms_chol_update_arma(SEXP RSEXP, SEXP uSEXP, SEXP downdateSEXP, SEXP epsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< arma::mat& >::type R(RSEXP);
Rcpp::traits::input_parameter< arma::vec& >::type u(uSEXP);
Rcpp::traits::input_parameter< bool >::type downdate(downdateSEXP);
Rcpp::traits::input_parameter< double >::type eps(epsSEXP);
rcpp_result_gen = Rcpp::wrap(chol_update_arma(R, u, downdate, eps));
return rcpp_result_gen;
END_RCPP
}
// get_explog_switch
Rcpp::String get_explog_switch();
RcppExport SEXP _bgms_get_explog_switch() {
Expand Down Expand Up @@ -167,6 +181,26 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// sample_ggm
Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type);
RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP);
Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP);
Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP);
Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP);
Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP);
Rcpp::traits::input_parameter< const int >::type seed(seedSEXP);
Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP);
Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP);
rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type));
return rcpp_result_gen;
END_RCPP
}
// compute_Vn_mfm_sbm
arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda);
RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) {
Expand All @@ -185,11 +219,13 @@ END_RCPP
static const R_CallMethodDef CallEntries[] = {
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36},
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 34},
{"_bgms_chol_update_arma", (DL_FUNC) &_bgms_chol_update_arma, 4},
{"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0},
{"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1},
{"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1},
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},
{"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10},
{"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4},
{NULL, NULL, 0}
};
Expand Down
68 changes: 68 additions & 0 deletions src/adaptiveMetropolis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include <RcppArmadillo.h>
#include <stdexcept>

class AdaptiveProposal {

public:

AdaptiveProposal(size_t num_params, size_t adaption_window = 50, double target_accept = 0.44) {
proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD, need to tweak this somehow?
acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros);
adaptation_window_ = adaption_window;
target_accept_ = target_accept;
}

double get_proposal_sd(size_t param_index) const {
validate_index(param_index);
return proposal_sds_[param_index];
}

void update_proposal_sd(size_t param_index) {

if (!adapting_) {
return;
}

double current_sd = get_proposal_sd(param_index);
double observed_acceptance_probability = acceptance_counts_[param_index] / static_cast<double>(iterations_ + 1);
double rm_weight = std::pow(iterations_, -decay_rate_);

// Robbins-Monro update step
double updated_sd = current_sd + (observed_acceptance_probability - target_accept_) * rm_weight;
updated_sd = std::clamp(updated_sd, rm_lower_bound, rm_upper_bound);

proposal_sds_(param_index) = updated_sd;
}

void increment_accepts(size_t param_index) {
validate_index(param_index);
acceptance_counts_[param_index]++;
}

void increment_iteration() {
iterations_++;
if (iterations_ >= adaptation_window_) {
adapting_ = false;
}
}

private:
arma::vec proposal_sds_;
arma::ivec acceptance_counts_;
int iterations_ = 0,
adaptation_window_;
double target_accept_ = 0.44,
decay_rate_ = 0.75,
rm_lower_bound = 0.001,
rm_upper_bound = 2.0;
bool adapting_ = true;

void validate_index(size_t index) const {
if (index >= proposal_sds_.n_elem) {
throw std::out_of_range("Parameter index out of range");
}
}

};
Empty file added src/base_model.cpp
Empty file.
60 changes: 60 additions & 0 deletions src/base_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <RcppArmadillo.h>
#include <stdexcept>
#include <memory>

class BaseModel {
public:
virtual ~BaseModel() = default;

// Capability queries
virtual bool has_gradient() const { return false; }
virtual bool has_adaptive_mh() const { return false; }

// Core methods (to be overridden by derived classes)
virtual double logp(const arma::vec& parameters) = 0;

virtual arma::vec gradient(const arma::vec& parameters) {
if (!has_gradient()) {
throw std::runtime_error("Gradient not implemented for this model");
}
throw std::runtime_error("Gradient method must be implemented in derived class");
}

virtual std::pair<double, arma::vec> logp_and_gradient(
const arma::vec& parameters) {
if (!has_gradient()) {
throw std::runtime_error("Gradient not implemented for this model");
}
return {logp(parameters), gradient(parameters)};
}

// For Metropolis-Hastings (model handles parameter groups internally)
virtual void do_one_mh_step() {
throw std::runtime_error("do_one_mh_step method must be implemented in derived class");
}

virtual arma::vec get_vectorized_parameters() {
throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class");
}

virtual arma::ivec get_vectorized_indicator_parameters() {
throw std::runtime_error("get_vectorized_indicator_parameters method must be implemented in derived class");
}

// Return dimensionality of the parameter space
virtual size_t parameter_dimension() const = 0;

virtual void set_seed(int seed) {
throw std::runtime_error("set_seed method must be implemented in derived class");
}

virtual std::unique_ptr<BaseModel> clone() const {
throw std::runtime_error("clone method must be implemented in derived class");
}


protected:
BaseModel() = default;
};
32 changes: 32 additions & 0 deletions src/chainResultNew.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <string>
#include <RcppArmadillo.h>

class ChainResultNew {

public:
ChainResultNew() {}

bool error = false,
userInterrupt = false;
std::string error_msg;
int chain_id;

arma::mat samples;

void reserve(const size_t param_dim, const size_t n_iter) {
samples.set_size(param_dim, n_iter);
}
void store_sample(const size_t iter, const arma::vec& sample) {
samples.col(iter) = sample;
}

// arma::imat indicator_samples;

// other samples
// arma::ivec treedepth_samples;
// arma::ivec divergent_samples;
// arma::vec energy_samples;
// arma::imat allocation_samples;
};
129 changes: 129 additions & 0 deletions src/cholupdate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include "cholupdate.h"

extern "C" {

// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1876-L1883
static inline double hypote(double x, double y) {
/* stable computation of sqrt(x^2 + y^2) */
double t;
x = fabs(x);y=fabs(y);
if (y>x) { t = x;x = y; y = t;}
if (x==0) return(y); else t = y/x;
return(x*sqrt(1+t*t));
} /* hypote */

// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1956
void chol_up(double *R,double *u, int *n,int *up,double *eps) {
/* Rank 1 update of a cholesky factor. Works as follows:

[up=1] R'R + uu' = [u,R'][u,R']' = [u,R']Q'Q[u,R']', and then uses Givens rotations to
construct Q such that Q[u,R']' = [0,R1']'. Hence R1'R1 = R'R + uu'. The construction
operates from first column to last.

[up=0] uses an almost identical sequence, but employs hyperbolic rotations
in place of Givens. See Golub and van Loan (2013, 4e 6.5.4)

Givens rotations are of form [c,-s] where c = cos(theta), s = sin(theta).
[s,c]

Assumes R upper triangular, and that it is OK to use first two columns
below diagonal as temporary strorage for Givens rotations (the storage is
needed to ensure algorithm is column oriented).

For downdate returns a negative value in R[1] (R[1,0]) if not +ve definite.
*/
double c0,s0,*c,*s,z,*x,z0,*c1;
int j,j1,n1;
n1 = *n - 1;
if (*up) for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R */
z = *u; /* initial element of u */
x = R + *n * j; /* current column */
c = R + 2;s = R + *n + 2; /* Storage for first n-2 Givens rotations */
for (c1=c+j1;c<c1;c++,s++,x++) { /* apply previous Givens */
z0 = z;
z = *c * z - *s * *x;
*x = *s * z0 + *c * *x;
}
if (j) {
/* apply last computed Givens */
z0 = z;
z = c0 * z - s0 * *x;
*x = s0 * z0 + c0 * *x;
x++;
if (j<n1) {*c = c0; *s = s0;} /* store if needed for further columns */
}

/* now construct the next rotation u[j] <-> R[j,j] */
z0 = hypote(z,*x); /* sqrt(z^2+R[j,j]^2) */
c0 = *x/z0; s0 = z/z0; /* need to zero z */
/* now apply this rotation and this column is finished (so no need to update z) */
*x = s0 * z + c0 * *x;
} else for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R for down-dating */
z = *u; /* initial element of u */
x = R + *n * j; /* current column */
c = R + 2;s = R + *n + 2; /* Storage for first n-2 hyperbolic rotations */
for (c1=c+j1;c<c1;c++,s++,x++) { /* apply previous hyperbolic */
z0 = z;
z = *c * z - *s * *x;
*x = -*s * z0 + *c * *x;
}
if (j) {
/* apply last computed hyperbolic */
z0 = z;
z = c0 * z - s0 * *x;
*x = -s0 * z0 + c0 * *x;
x++;
if (j<n1) {*c = c0; *s = s0;} /* store if needed for further columns */
}

/* now construct the next hyperbolic rotation u[j] <-> R[j,j] */
z0 = z / *x; /* sqrt(z^2+R[j,j]^2) */
if (fabs(z0)>=1) { /* downdate not +ve def */
//Rprintf("j = %d d = %g ",j,z0);
if (*n>1) R[1] = -2.0;
return; /* signals error */
}
if (z0 > 1 - *eps) z0 = 1 - *eps;
c0 = 1/sqrt(1-z0*z0);s0 = c0 * z0;
/* now apply this rotation and this column is finished (so no need to update z) */
*x = -s0 * z + c0 * *x;
}

/* now zero c and s storage */
c = R + 2;s = R + *n + 2;
for (x = c + *n - 2;c<x;c++,s++) *c = *s = 0.0;
} /* chol_up */
}

// for internal use
void cholesky_update(arma::mat& R, arma::vec& u, double eps) {
int n = R.n_cols;
int up = 1;
chol_up(R.memptr(), u.memptr(), &n, &up, &eps);
}

void cholesky_downdate(arma::mat& R, arma::vec& u, double eps) {
int n = R.n_cols;
int up = 0;
chol_up(R.memptr(), u.memptr(), &n, &up, &eps);
}

// for testing
// [[Rcpp::export]]
arma::mat chol_update_arma(arma::mat& R, arma::vec& u, bool downdate = false, double eps = 1e-12) {
if (R.n_rows != R.n_cols)
Rcpp::stop("R must be square");
if (u.n_elem != R.n_cols)
Rcpp::stop("length(u) must match dimension of R");

if (downdate)
cholesky_downdate(R, u, eps);
else
cholesky_update(R, u, eps);

return R;
int n = R.n_cols;
int up = downdate ? 0 : 1;
chol_up(R.memptr(), u.memptr(), &n, &up, &eps);
return R;
}
6 changes: 6 additions & 0 deletions src/cholupdate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#include <RcppArmadillo.h>

void cholesky_update( arma::mat& R, arma::vec& u, double eps = 1e-12);
void cholesky_downdate(arma::mat& R, arma::vec& u, double eps = 1e-12);
Loading