diff --git a/birdman/builder.py b/birdman/builder.py
new file mode 100644
index 0000000..e5a7f05
--- /dev/null
+++ b/birdman/builder.py
@@ -0,0 +1,341 @@
+from pathlib import Path
+from pkg_resources import resource_filename
+
+import biom
+from jinja2 import Template
+import numpy as np
+import pandas as pd
+from patsy import dmatrix
+
+from birdman import SingleFeatureModel, TableModel
+
+J2_DIR = Path(resource_filename("birdman", "jinja2"))
+SF_TEMPLATE = J2_DIR / "negative_binomial_single.j2.stan"
+FULL_TEMPLATE = J2_DIR / "negative_binomial_full.j2.stan"
+
+
+def create_single_feature_model(
+    table: biom.Table,
+    metadata: pd.DataFrame,
+    stan_file_path: Path,
+    fixed_effects: list = None,
+    random_effects: list = None,
+    beta_prior: float = 5.0,
+    group_var_prior: float = 1.0,
+    inv_disp_sd_prior: float = 0.5
+) -> SingleFeatureModel:
+    """Build SingleFeatureModel.
+
+    Renders the single-feature Stan template to ``stan_file_path`` and returns
+    a SingleFeatureModel subclass that is instantiated with a feature ID.
+
+    :param table: Feature table (features x samples)
+    :type table: biom.table.Table
+
+    :param metadata: Metadata for design matrix
+    :type metadata: pd.DataFrame
+
+    :param stan_file_path: Path to save rendered Stan file
+    :type stan_file_path: pathlib.Path
+
+    :param fixed_effects: List of fixed effects to include in model, defaults
+        to None (intercept-only model)
+    :type fixed_effects: list
+
+    :param random_effects: List of random effects to include in model,
+        defaults to None (no random effects)
+    :type random_effects: list
+
+    :param beta_prior: Standard deviation for normally distributed prior values
+        of beta, defaults to 5.0
+    :type beta_prior: float
+
+    :param group_var_prior: Standard deviation for normally distributed prior
+        values of random effects, defaults to 1.0
+    :type group_var_prior: float
+
+    :param inv_disp_sd_prior: Standard deviation for lognormally distributed
+        prior values of 1/phi, defaults to 0.5
+    :type inv_disp_sd_prior: float
+
+    :returns: SingleFeatureModel with specified fixed and random effects
+    :rtype: birdman.model_base.SingleFeatureModel
+
+    :raises ValueError: If table and metadata sample IDs differ
+    """
+    metadata, dmat, random_effects = _prepare_design(
+        table, metadata, SF_TEMPLATE, stan_file_path,
+        fixed_effects, random_effects
+    )
+
+    class _SingleFeatureModel(SingleFeatureModel):
+        def __init__(self, feature_id: str):
+            super().__init__(table=table, feature_id=feature_id,
+                             model_path=stan_file_path)
+            self.feature_id = feature_id
+            values = table.data(
+                id=feature_id,
+                axis="observation",
+                dense=True
+            ).astype(int)
+
+            A = np.log(1 / table.shape[0])
+
+            param_dict = {
+                "y": values,
+                "x": dmat,
+                "p": dmat.shape[1],
+                "depth": np.log(table.sum("sample")),
+                "A": A,
+                "B_p": beta_prior,
+                "inv_disp_sd": inv_disp_sd_prior,
+                "re_p": group_var_prior
+            }
+
+            group_maps, self.re_dict = _encode_groups(
+                metadata, random_effects, self.sample_names
+            )
+            param_dict.update(group_maps)
+
+            self.add_parameters(param_dict)
+
+            self.specify_model(
+                params=["beta_var", "inv_disp"],
+                dims={
+                    "beta_var": ["covariate"],
+                    "log_lhood": ["tbl_sample"],
+                    "y_predict": ["tbl_sample"],
+                    "inv_disp": []
+                },
+                coords={
+                    "covariate": dmat.columns,
+                    "tbl_sample": self.sample_names,
+                },
+                include_observed_data=True,
+                posterior_predictive="y_predict",
+                log_likelihood="log_lhood"
+            )
+
+    return _SingleFeatureModel
+
+
+def create_table_model(
+    table: biom.Table,
+    metadata: pd.DataFrame,
+    stan_file_path: Path,
+    fixed_effects: list = None,
+    random_effects: list = None,
+    beta_prior: float = 5.0,
+    group_var_prior: float = 1.0,
+    inv_disp_sd_prior: float = 0.5
+):
+    """Build TableModel.
+
+    Renders the full-table Stan template to ``stan_file_path`` and returns a
+    TableModel subclass that is instantiated without arguments.
+
+    :param table: Feature table (features x samples)
+    :type table: biom.table.Table
+
+    :param metadata: Metadata for design matrix
+    :type metadata: pd.DataFrame
+
+    :param stan_file_path: Path to save rendered Stan file
+    :type stan_file_path: pathlib.Path
+
+    :param fixed_effects: List of fixed effects to include in model, defaults
+        to None (intercept-only model)
+    :type fixed_effects: list
+
+    :param random_effects: List of random effects to include in model,
+        defaults to None (no random effects)
+    :type random_effects: list
+
+    :param beta_prior: Standard deviation for normally distributed prior values
+        of beta, defaults to 5.0
+    :type beta_prior: float
+
+    :param group_var_prior: Standard deviation for normally distributed prior
+        values of random effects, defaults to 1.0
+    :type group_var_prior: float
+
+    :param inv_disp_sd_prior: Standard deviation for lognormally distributed
+        prior values of 1/phi, defaults to 0.5
+    :type inv_disp_sd_prior: float
+
+    :returns: TableModel with specified fixed and random effects
+    :rtype: birdman.model_base.TableModel
+
+    :raises ValueError: If table and metadata sample IDs differ
+    """
+    metadata, dmat, random_effects = _prepare_design(
+        table, metadata, FULL_TEMPLATE, stan_file_path,
+        fixed_effects, random_effects
+    )
+
+    class _TableModel(TableModel):
+        def __init__(self):
+            super().__init__(table=table, model_path=stan_file_path)
+
+            A = np.log(1 / table.shape[0])
+
+            param_dict = {
+                "x": dmat,
+                "p": dmat.shape[1],
+                "depth": np.log(table.sum("sample")),
+                "A": A,
+                "B_p": beta_prior,
+                "inv_disp_sd": inv_disp_sd_prior,
+                "re_p": group_var_prior
+            }
+
+            group_maps, self.re_dict = _encode_groups(
+                metadata, random_effects, self.sample_names
+            )
+            param_dict.update(group_maps)
+
+            self.add_parameters(param_dict)
+
+            self.specify_model(
+                params=["beta_var", "inv_disp"],
+                dims={
+                    "beta_var": ["covariate", "feature_alr"],
+                    "log_lhood": ["tbl_sample", "feature"],
+                    "y_predict": ["tbl_sample", "feature"],
+                    "inv_disp": ["feature"]
+                },
+                coords={
+                    "covariate": dmat.columns,
+                    "tbl_sample": self.sample_names,
+                    "feature": table.ids("observation"),
+                    "feature_alr": table.ids("observation")[1:]
+                },
+                include_observed_data=True,
+                posterior_predictive="y_predict",
+                log_likelihood="log_lhood"
+            )
+
+    return _TableModel
+
+
+def _prepare_design(
+    table: biom.Table,
+    metadata: pd.DataFrame,
+    template_path: Path,
+    stan_file_path: Path,
+    fixed_effects: list = None,
+    random_effects: list = None
+) -> tuple:
+    """Validate inputs, build the design matrix and write the Stan file.
+
+    :param table: Feature table (features x samples)
+    :type table: biom.table.Table
+
+    :param metadata: Metadata for design matrix
+    :type metadata: pd.DataFrame
+
+    :param template_path: Path to Jinja2 template file
+    :type template_path: pathlib.Path
+
+    :param stan_file_path: Path to save rendered Stan file
+    :type stan_file_path: pathlib.Path
+
+    :param fixed_effects: List of fixed effects, None for intercept only
+    :type fixed_effects: list
+
+    :param random_effects: List of random effects, None for no random effects
+    :type random_effects: list
+
+    :returns: Metadata ordered like the table samples, design matrix and
+        list of random effects
+    :rtype: tuple
+
+    :raises ValueError: If table and metadata sample IDs differ
+    """
+    sample_ids = table.ids()
+    if not set(sample_ids) == set(metadata.index):
+        raise ValueError("Sample IDs must match!")
+
+    # Defaults are None: treat as "no effects" instead of failing on iteration
+    fixed_effects = list(fixed_effects) if fixed_effects else []
+    random_effects = list(random_effects) if random_effects else []
+
+    # Counts are read in table order, so the design matrix rows must follow
+    # the table's sample order rather than whatever order metadata came in
+    metadata = metadata.loc[sample_ids]
+
+    # "1" is the patsy formula for an intercept-only design
+    fe_formula = " + ".join(fixed_effects) if fixed_effects else "1"
+    dmat = dmatrix(fe_formula, metadata, return_type="dataframe")
+
+    stanfile = _render_stanfile(template_path, metadata, random_effects)
+    with open(stan_file_path, "w") as f:
+        f.write(stanfile)
+
+    return metadata, dmat, random_effects
+
+
+def _encode_groups(
+    metadata: pd.DataFrame,
+    random_effects: list,
+    sample_names
+) -> tuple:
+    """Encode grouping variables as 1-based indices for Stan.
+
+    :param metadata: Metadata containing the grouping columns
+    :type metadata: pd.DataFrame
+
+    :param random_effects: List of random effects to encode
+    :type random_effects: list
+
+    :param sample_names: Sample IDs giving the order of the returned maps
+    :type sample_names: Sequence
+
+    :returns: Dictionary of ``<group>_map`` Stan data entries and dictionary
+        of sorted levels for each group
+    :rtype: tuple
+    """
+    group_maps = dict()
+    re_dict = dict()
+
+    for group_var in random_effects:
+        group_var_series = metadata[group_var].loc[sample_names]
+        # Category codes start at 0 but Stan indexing starts at 1
+        group_maps[f"{group_var}_map"] = (
+            group_var_series.astype("category").cat.codes + 1
+        )
+        re_dict[group_var] = np.sort(group_var_series.unique())
+
+    return group_maps, re_dict
+
+
+def _render_stanfile(
+    template_path: Path,
+    metadata: pd.DataFrame,
+    random_effects: list = None
+) -> str:
+    """Render Stan file for the given random effects.
+
+    :param template_path: Path to Jinja2 template file
+    :type template_path: pathlib.Path
+
+    :param metadata: Metadata for design matrix
+    :type metadata: pd.DataFrame
+
+    :param random_effects: List of random effects to include in model,
+        defaults to None (no random effects)
+    :type random_effects: list
+
+    :returns: Rendered Jinja2 template
+    :rtype: str
+    """
+    # Map each grouping variable to its number of levels for the template
+    re_dict = {
+        group: len(metadata[group].unique())
+        for group in (random_effects or [])
+    }
+
+    with open(template_path, "r") as f:
+        stanfile = Template(f.read()).render({"re_dict": re_dict})
+
+    return stanfile
diff --git a/birdman/jinja2/negative_binomial_full.j2.stan b/birdman/jinja2/negative_binomial_full.j2.stan
new file mode 100644
index 0000000..405c235
--- /dev/null
+++ b/birdman/jinja2/negative_binomial_full.j2.stan
@@ -0,0 +1,84 @@
+data {
+  int N;                        // number of samples
+  int D;                        // number of features
+  int p;                        // number of covariates
+  real A;                       // mean intercept
+  vector[N] depth;              // log sequencing depths of microbes
+  matrix[N, p] x;               // covariate matrix
+  array[N, D] int y;            // observed microbe abundances
+
+  real B_p;                     // stdev for beta normal prior
+  real inv_disp_sd;             // stdev for inv disp lognormal prior
+
+  // Random Effects
+  real re_p;                    // stdev for random effect normal prior
+  {% for re_name, num_factors in re_dict.items() %}
+  array[N] int {{ re_name }}_map;
+  {%- endfor %}
+  // End Random Effects
+}
+
+parameters {
+  row_vector[D-1] beta_0;
+  matrix[p-1, D-1] beta_x;
+  vector<lower=0>[D] inv_disp;  // lognormal prior requires positive support
+
+  // Random Effects
+  {%- for re_name, num_factors in re_dict.items() %}
+  matrix[{{ num_factors }}, D-1] {{ re_name }}_eff;
+  {%- endfor %}
+  // End Random Effects
+}
+
+transformed parameters {
+  matrix[p, D-1] beta_var = append_row(beta_0, beta_x);
+  matrix[N, D-1] lam = x * beta_var;
+  matrix[N, D] lam_clr;
+
+  // Random Effects
+  for (n in 1:N) {
+    lam[n] += depth[n];
+    {%- for re_name, num_factors in re_dict.items() %}
+    lam[n] += {{ re_name }}_eff[{{ re_name }}_map[n]];
+    {%- endfor %}
+  }
+  // End Random Effects
+
+  lam_clr = append_col(to_vector(rep_array(0, N)), lam);
+}
+
+model {
+  inv_disp ~ lognormal(0., inv_disp_sd);
+
+  beta_0 ~ normal(A, B_p);
+  for (i in 1:D-1){
+    for (j in 1:p-1){
+      beta_x[j, i] ~ normal(0., B_p);
+    }
+    // Random Effects
+    {%- for re_name, num_factors in re_dict.items() %}
+    for (j in 1:{{ num_factors }}) {
+      {{ re_name }}_eff[j, i] ~ normal(0, re_p);
+    }
+    {%- endfor %}
+    // End Random Effects
+  }
+
+  for (n in 1:N){
+    for (i in 1:D){
+      target += neg_binomial_2_log_lpmf(y[n, i] | lam_clr[n, i], inv(inv_disp[i]));
+    }
+  }
+}
+
+generated quantities {
+  array[N, D] int y_predict;
+  array[N, D] real log_lhood;
+
+  for (n in 1:N){
+    for (i in 1:D){
+      y_predict[n, i] = neg_binomial_2_log_rng(lam_clr[n, i], inv(inv_disp[i]));
+      log_lhood[n, i] = neg_binomial_2_log_lpmf(y[n, i] | lam_clr[n, i], inv(inv_disp[i]));
+    }
+  }
+}
diff --git a/birdman/jinja2/negative_binomial_single.j2.stan b/birdman/jinja2/negative_binomial_single.j2.stan
new file mode 100644
index 0000000..bb99414
--- /dev/null
+++ b/birdman/jinja2/negative_binomial_single.j2.stan
@@ -0,0 +1,67 @@
+data {
+  int N;                        // number of samples
+  int p;                        // number of covariates
+  real A;                       // mean intercept
+  vector[N] depth;              // log sequencing depths of microbes
+  matrix[N, p] x;               // covariate matrix
+  array[N] int y;               // observed microbe abundances
+
+  real B_p;                     // stdev for beta normal prior
+  real inv_disp_sd;             // stdev for inv disp lognormal prior
+
+  // Random Effects
+  real re_p;                    // stdev for random effect normal prior
+  {% for re_name, num_factors in re_dict.items() %}
+  array[N] int {{ re_name }}_map;
+  {%- endfor %}
+  // End Random Effects
+}
+
+parameters {
+  real beta_0;
+  vector[p-1] beta_x;
+  real<lower=0> inv_disp;       // lognormal prior requires positive support
+
+  // Random Effects
+  {%- for re_name, num_factors in re_dict.items() %}
+  vector[{{ num_factors }}] {{ re_name }}_eff;
+  {%- endfor %}
+  // End Random Effects
+}
+
+transformed parameters {
+  vector[p] beta_var = append_row(beta_0, beta_x);
+  vector[N] lam = x * beta_var + depth;
+
+  // Random Effects
+  for (n in 1:N) {
+    {%- for re_name, num_factors in re_dict.items() %}
+    lam[n] += {{ re_name }}_eff[{{ re_name }}_map[n]];
+    {%- endfor %}
+  }
+  // End Random Effects
+}
+
+model {
+  inv_disp ~ lognormal(0., inv_disp_sd);
+  beta_0 ~ normal(A, B_p);
+  beta_x ~ normal(0, B_p);
+
+  // Random Effects
+  {%- for re_name, num_factors in re_dict.items() %}
+  {{ re_name }}_eff ~ normal(0, re_p);
+  {%- endfor %}
+  // End Random Effects
+
+  y ~ neg_binomial_2_log(lam, inv(inv_disp));
+}
+
+generated quantities {
+  vector[N] log_lhood;
+  vector[N] y_predict;
+
+  for (n in 1:N){
+    y_predict[n] = neg_binomial_2_log_rng(lam[n], inv(inv_disp));
+    log_lhood[n] = neg_binomial_2_log_lpmf(y[n] | lam[n], inv(inv_disp));
+  }
+}
diff --git a/tests/test_builder.py b/tests/test_builder.py
new file mode 100644
index 0000000..e9521ba
--- /dev/null
+++ b/tests/test_builder.py
@@ -0,0 +1,107 @@
+import logging
+from pathlib import Path
+import tempfile
+
+import numpy as np
+from birdman.builder import create_single_feature_model, create_table_model
+
+
+def test_create_sf_model(table_biom, metadata):
+    rng = np.random.default_rng(42)
+
+    metadata = metadata.copy()
+    N = metadata.shape[0]
+
+    # Add pseudo-random effects
+    rand_eff_1 = rng.choice([1, 2, 3], N)
+    rand_eff_2 = rng.choice([1, 2, 3, 5], N)
+    metadata["group_1"] = rand_eff_1
+    metadata["group_2"] = rand_eff_2
+
+    with tempfile.TemporaryDirectory() as f:
+        stan_file_path = Path(f) / "model.stan"
+
+        BuiltModel = create_single_feature_model(
+            table_biom,
+            metadata,
+            fixed_effects=["host_common_name"],
+            random_effects=["group_1", "group_2"],
+            stan_file_path=stan_file_path
+        )
+
+        cmdstanpy_logger = logging.getLogger("cmdstanpy")
+        cmdstanpy_logger.disabled = True
+
+        first_feat = table_biom.ids("observation")[0]
+        model = BuiltModel(first_feat)
+        model.compile_model()
+        model.fit_model(method="vi")
+
+    inf = model.to_inference()
+
+    post = inf.posterior
+    data_vars = set(post.data_vars.keys())
+    assert data_vars == {"beta_var", "inv_disp"}
+
+    exp_covariates = ("Intercept", "host_common_name[T.long-tailed macaque]")
+    assert (post.coords["covariate"] == exp_covariates).all()
+
+    inf_groups = set(inf.groups())
+    assert inf_groups == {"posterior", "posterior_predictive",
+                          "log_likelihood", "observed_data"}
+
+    group_1_eff = model.fit.stan_variable("group_1_eff")
+    group_2_eff = model.fit.stan_variable("group_2_eff")
+
+    assert group_1_eff.shape == (500, 3)
+    assert group_2_eff.shape == (500, 4)
+
+
+def test_create_full_model(table_biom, metadata):
+    rng = np.random.default_rng(42)
+
+    metadata = metadata.copy()
+    N = metadata.shape[0]
+
+    # Add pseudo-random effects
+    rand_eff_1 = rng.choice([1, 2, 3], N)
+    rand_eff_2 = rng.choice([1, 2, 3, 5], N)
+    metadata["group_1"] = rand_eff_1
+    metadata["group_2"] = rand_eff_2
+
+    with tempfile.TemporaryDirectory() as f:
+        stan_file_path = Path(f) / "model.stan"
+
+        BuiltModel = create_table_model(
+            table_biom,
+            metadata,
+            fixed_effects=["host_common_name"],
+            random_effects=["group_1", "group_2"],
+            stan_file_path=stan_file_path
+        )
+
+        cmdstanpy_logger = logging.getLogger("cmdstanpy")
+        cmdstanpy_logger.disabled = True
+
+        model = BuiltModel()
+        model.compile_model()
+        model.fit_model(method="vi")
+
+    inf = model.to_inference()
+
+    post = inf.posterior
+    data_vars = set(post.data_vars.keys())
+    assert data_vars == {"beta_var", "inv_disp"}
+
+    exp_covariates = ("Intercept", "host_common_name[T.long-tailed macaque]")
+    assert (post.coords["covariate"] == exp_covariates).all()
+
+    inf_groups = set(inf.groups())
+    assert inf_groups == {"posterior", "posterior_predictive",
+                          "log_likelihood", "observed_data"}
+
+    group_1_eff = model.fit.stan_variable("group_1_eff")
+    group_2_eff = model.fit.stan_variable("group_2_eff")
+
+    assert group_1_eff.shape == (500, 3, 27)
+    assert group_2_eff.shape == (500, 4, 27)