diff --git a/r-package/balnet/R/balnet.R b/r-package/balnet/R/balnet.R index db23138..81bc0ce 100644 --- a/r-package/balnet/R/balnet.R +++ b/r-package/balnet/R/balnet.R @@ -7,9 +7,16 @@ #' @param W Treatment vector (0: control, 1: treated). #' @param target The target estimand. Default is ATE. #' @param sample.weights Optional sample weights. If `NULL` (default), then each unit receives the same weight. +#' @param max.imbalance Optional upper bound on the covariate imbalance. +#' For lasso penalization (`alpha = 1`), there is a one-to-one correspondence between the penalty parameter +#' \eqn{\lambda} and the maximum allowable covariate imbalance. +#' When supplied, `max.imbalance` is used to adjust the lambda sequence (via `lambda.min.ratio`) so that the +#' generated lambda sequence ends at the specified imbalance level. #' @param nlambda Number of values for `lambda`, if generated automatically. Default is 100. #' @param lambda.min.ratio Ratio between smallest and largest value of lambda. Default is 1e-2. -#' @param lambda Optional `lambda` sequence. By default, the `lambda` sequence is constructed automatically using `nlambda` and `lambda.min.ratio`. +#' @param lambda Optional `lambda` sequence. +#' By default, the `lambda` sequence is constructed automatically using `nlambda` and `lambda.min.ratio` +#' (or `max.imbalance`, if specified). #' @param penalty.factor Penalty factor per feature. Default is 1 (i.e, each feature recieves the same penalty). #' @param groups An optional list of group indices for group penalization. #' @param alpha Elastic net mixing parameter. Default is 1 (lasso). 0 is ridge. 
@@ -57,6 +64,7 @@ balnet <- function( W, target = c("ATE", "ATT", "treated", "control"), sample.weights = NULL, + max.imbalance = NULL, nlambda = 100L, lambda.min.ratio = 1e-2, lambda = NULL, @@ -92,7 +100,6 @@ balnet <- function( } else if (is.null(sample.weights)) { sample.weights <- rep_len(1, nrow(X)) } - lambda.in <- validate_lambda(lambda) if (is.character(standardize) && standardize == "inplace") { inplace <- TRUE standardize <- TRUE @@ -101,6 +108,7 @@ balnet <- function( } else { stop("Invalid standardize option.") } + lambda.in <- validate_lambda(lambda) colnames <- if (is.null(colnames(X))) make.names(1:ncol(X)) else colnames(X) validate_groups(groups, ncol(X), colnames) @@ -111,6 +119,8 @@ balnet <- function( inplace = inplace, n_threads = num.threads ) + lambda.min.ratio <- get_lambda_min_ratio(lambda.min.ratio, max.imbalance, stan$X, W, sample.weights, target, alpha) + if (target == "ATT") { target_scale = sum(sample.weights) / sum(sample.weights * W) # "n / n_1" } else { @@ -128,7 +138,7 @@ balnet <- function( target_scale = target_scale, lambda = lambda.in[[1]], lmda_path_size = nlambda, - min_ratio = lambda.min.ratio, + min_ratio = lambda.min.ratio[[1]], penalty = penalty.factor, groups = groups, alpha = alpha, @@ -149,7 +159,7 @@ balnet <- function( target_scale = target_scale, lambda = lambda.in[[2]], lmda_path_size = nlambda, - min_ratio = lambda.min.ratio, + min_ratio = lambda.min.ratio[[2]], penalty = penalty.factor, groups = groups, alpha = alpha, diff --git a/r-package/balnet/R/utils.R b/r-package/balnet/R/utils.R index 13f875e..2e56d81 100644 --- a/r-package/balnet/R/utils.R +++ b/r-package/balnet/R/utils.R @@ -71,6 +71,37 @@ standardize <- function( list(X = X, center = center, scale = scale) } +get_lambda_min_ratio <- function(lambda.min.ratio, max.imbalance, X.stan, W, sample.weights, target, alpha) { + if (is.null(max.imbalance)) { + out <- c(lambda.min.ratio, lambda.min.ratio) + } else { + if (alpha < 1) { + stop("Setting max.imbalance 
is only possible with lasso (alpha = 1).") + } + if (max.imbalance <= 0) { + stop("max.imbalance should be > 0.") + } + lambda.min.ratio0 <- lambda.min.ratio1 <- lambda.min.ratio + if (target %in% c("ATE", "ATT", "control")) { + stats0 <- col_stats(X.stan, weights = (1 - W) * sample.weights) + lambda0.max <- max(abs(stats0$center)) # Note, this assumes X.stan is standardized. + if (max.imbalance < lambda0.max) { + lambda.min.ratio0 <- max.imbalance / lambda0.max + } + } + if (target %in% c("ATE", "treated")) { + stats1 <- col_stats(X.stan, weights = W * sample.weights) + lambda1.max <- max(abs(stats1$center)) + if (max.imbalance < lambda1.max) { + lambda.min.ratio1 <- max.imbalance / lambda1.max + } + } + out <- c(lambda.min.ratio0, lambda.min.ratio1) + } + + out +} + validate_lambda <- function(lambda) { if (is.character(lambda)) { stop("Unsupported lambda argument.") diff --git a/r-package/balnet/man/balnet.Rd b/r-package/balnet/man/balnet.Rd index 4b398e8..857770b 100644 --- a/r-package/balnet/man/balnet.Rd +++ b/r-package/balnet/man/balnet.Rd @@ -9,6 +9,7 @@ balnet( W, target = c("ATE", "ATT", "treated", "control"), sample.weights = NULL, + max.imbalance = NULL, nlambda = 100L, lambda.min.ratio = 0.01, lambda = NULL, @@ -32,11 +33,19 @@ balnet( \item{sample.weights}{Optional sample weights. If \code{NULL} (default), then each unit receives the same weight.} +\item{max.imbalance}{Optional upper bound on the covariate imbalance. +For lasso penalization (\code{alpha = 1}), there is a one-to-one correspondence between the penalty parameter +\eqn{\lambda} and the maximum allowable covariate imbalance. +When supplied, \code{max.imbalance} is used to adjust the lambda sequence (via \code{lambda.min.ratio}) so that the +generated lambda sequence ends at the specified imbalance level.} + \item{nlambda}{Number of values for \code{lambda}, if generated automatically. Default is 100.} \item{lambda.min.ratio}{Ratio between smallest and largest value of lambda. 
Default is 1e-2.} -\item{lambda}{Optional \code{lambda} sequence. By default, the \code{lambda} sequence is constructed automatically using \code{nlambda} and \code{lambda.min.ratio}.} +\item{lambda}{Optional \code{lambda} sequence. +By default, the \code{lambda} sequence is constructed automatically using \code{nlambda} and \code{lambda.min.ratio} +(or \code{max.imbalance}, if specified).} \item{penalty.factor}{Penalty factor per feature. Default is 1 (i.e, each feature recieves the same penalty).} diff --git a/r-package/balnet/site/get-started.Rmd b/r-package/balnet/site/get-started.Rmd index a2dfa46..98d25b3 100644 --- a/r-package/balnet/site/get-started.Rmd +++ b/r-package/balnet/site/get-started.Rmd @@ -59,13 +59,13 @@ max(abs(smd.baseline)) Since the smallest value of $\lambda$ attained for the treated arm is approximately $\lambda_{\min} \approx 0.21$, this indicates that the closest we can bring the standardized treated covariate means to the overall means is an absolute SMD of about 0.21. -This interpretation of $\lambda$ provides a convenient way to target a desired level of imbalance. Users can compute $\lambda^{\max}$ for their dataset and then choose `lambda.min.ratio` to reflect an acceptable fraction of this maximum imbalance. For example, if $\lambda^{\max} = 10$, the default setting `lambda.min.ratio = 0.01` corresponds to a target maximum absolute SMD of $10 \times 0.01 = 0.1$. The algorithm then attempts to compute the full regularization path, stopping gracefully if further reductions in imbalance are not achievable (in cases where balance remains approximate, users may wish to augment IPW estimation with an outcome model). +This interpretation of $\lambda$ provides a convenient way to target a desired level of imbalance: set the option `max.imbalance` to the largest acceptable absolute SMD. For lasso penalization, `balnet` then adjusts the generated $\lambda$ sequence so that it terminates at this value.
The algorithm then attempts to compute the full regularization path, stopping gracefully if further reductions in imbalance are not achievable. Alternatively, users may compute $\lambda^{\max}$ (e.g., the maximum absolute unweighted SMD) for their dataset and then choose `lambda.min.ratio` to reflect an acceptable fraction of this maximum imbalance. For example, if $\lambda^{\max} = 10$, the default setting `lambda.min.ratio = 0.01` corresponds to a target maximum absolute SMD of $10 \times 0.01 = 0.1$. > *Note*: Setting lambda = 0 to try to achieve exact balance is not recommended, just as `glmnet` advises against it. `balnet` works best by using warm starts and gradually decreasing regularization, a strategy similar to barrier methods in convex optimization. This approach helps the algorithm converge reliably and improves performance on real-world datasets where achieving covariate balance can be difficult. ## Plotting path diagnostics -`balnet` provides default plotting methods for visualizing regularization path diagnostics. Calling `plot` without additional arguments produces a summary of key metrics along the path, indexed by $\lambda$ on the log scale. +`balnet` provides default plotting methods for visualizing regularization path diagnostics. Calling `plot` without additional arguments produces a summary of metrics along the path, indexed by $\lambda$ on the log scale. ```{r} plot(fit)