Merged
18 changes: 14 additions & 4 deletions r-package/balnet/R/balnet.R
@@ -7,9 +7,16 @@
#' @param W Treatment vector (0: control, 1: treated).
#' @param target The target estimand. Default is ATE.
#' @param sample.weights Optional sample weights. If `NULL` (default), then each unit receives the same weight.
#' @param max.imbalance Optional upper bound on the covariate imbalance.
#' For lasso penalization (`alpha = 1`), there is a one-to-one correspondence between the penalty parameter
#' \eqn{\lambda} and the maximum allowable covariate imbalance.
#' When supplied, `max.imbalance` is used to adjust the lambda sequence (via `lambda.min.ratio`) so that the
#' generated lambda sequence ends at the specified imbalance level.
#' @param nlambda Number of values for `lambda`, if generated automatically. Default is 100.
#' @param lambda.min.ratio Ratio between smallest and largest value of lambda. Default is 1e-2.
#' @param lambda Optional `lambda` sequence. By default, the `lambda` sequence is constructed automatically using `nlambda` and `lambda.min.ratio`.
#' @param lambda Optional `lambda` sequence.
#' By default, the `lambda` sequence is constructed automatically using `nlambda` and `lambda.min.ratio`
#' (or `max.imbalance`, if specified).
#' @param penalty.factor Penalty factor per feature. Default is 1 (i.e., each feature receives the same penalty).
#' @param groups An optional list of group indices for group penalization.
#' @param alpha Elastic net mixing parameter. Default is 1 (lasso). 0 is ridge.
@@ -57,6 +64,7 @@ balnet <- function(
W,
target = c("ATE", "ATT", "treated", "control"),
sample.weights = NULL,
max.imbalance = NULL,
nlambda = 100L,
lambda.min.ratio = 1e-2,
lambda = NULL,
@@ -92,7 +100,6 @@ balnet <- function(
} else if (is.null(sample.weights)) {
sample.weights <- rep_len(1, nrow(X))
}
lambda.in <- validate_lambda(lambda)
if (is.character(standardize) && standardize == "inplace") {
inplace <- TRUE
standardize <- TRUE
@@ -101,6 +108,7 @@ balnet <- function(
} else {
stop("Invalid standardize option.")
}
lambda.in <- validate_lambda(lambda)
colnames <- if (is.null(colnames(X))) make.names(1:ncol(X)) else colnames(X)
validate_groups(groups, ncol(X), colnames)

@@ -111,6 +119,8 @@ balnet <- function(
inplace = inplace,
n_threads = num.threads
)
lambda.min.ratio <- get_lambda_min_ratio(lambda.min.ratio, max.imbalance, stan$X, W, sample.weights, target, alpha)

if (target == "ATT") {
target_scale = sum(sample.weights) / sum(sample.weights * W) # "n / n_1"
} else {
@@ -128,7 +138,7 @@ balnet <- function(
target_scale = target_scale,
lambda = lambda.in[[1]],
lmda_path_size = nlambda,
min_ratio = lambda.min.ratio,
min_ratio = lambda.min.ratio[[1]],
penalty = penalty.factor,
groups = groups,
alpha = alpha,
@@ -149,7 +159,7 @@ balnet <- function(
target_scale = target_scale,
lambda = lambda.in[[2]],
lmda_path_size = nlambda,
min_ratio = lambda.min.ratio,
min_ratio = lambda.min.ratio[[2]],
penalty = penalty.factor,
groups = groups,
alpha = alpha,
31 changes: 31 additions & 0 deletions r-package/balnet/R/utils.R
@@ -71,6 +71,37 @@ standardize <- function(
list(X = X, center = center, scale = scale)
}

get_lambda_min_ratio <- function(lambda.min.ratio, max.imbalance, X.stan, W, sample.weights, target, alpha) {
if (is.null(max.imbalance)) {
out <- c(lambda.min.ratio, lambda.min.ratio)
} else {
if (alpha < 1) {
stop("Setting max.imbalance is only possible with lasso (alpha = 1).")
}
if (max.imbalance <= 0) {
stop("max.imbalance should be > 0.")
}
lambda.min.ratio0 <- lambda.min.ratio1 <- lambda.min.ratio
if (target %in% c("ATE", "ATT", "control")) {
stats0 <- col_stats(X.stan, weights = (1 - W) * sample.weights)
lambda0.max <- max(abs(stats0$center)) # Note, this assumes X.stan is standardized.
if (max.imbalance < lambda0.max) {
lambda.min.ratio0 <- max.imbalance / lambda0.max
}
}
if (target %in% c("ATE", "treated")) {
stats1 <- col_stats(X.stan, weights = W * sample.weights)
lambda1.max <- max(abs(stats1$center))
if (max.imbalance < lambda1.max) {
lambda.min.ratio1 <- max.imbalance / lambda1.max
}
}
out <- c(lambda.min.ratio0, lambda.min.ratio1)
}

out
}
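
To illustrate the mapping that `get_lambda_min_ratio` performs for each arm, here is a minimal base-R sketch. It assumes `col_stats(...)$center` is a weighted column mean (this diff does not confirm that), uses plain `scale()` as a stand-in for the package's `standardize()`, and the helper name `ratio_for_arm` is hypothetical.

```r
# Simulated data; all names below are illustrative, not part of balnet.
set.seed(42)
n <- 500
X <- matrix(rnorm(n * 4), nrow = n)
W <- rbinom(n, 1, plogis(0.8 * X[, 1]))
sample.weights <- rep_len(1, n)

# Assumption: plain centering/scaling stands in for the package's standardize().
X.stan <- scale(X)

# Hypothetical helper: lambda.min.ratio for one arm, given arm-specific weights.
ratio_for_arm <- function(X.stan, arm.weights, max.imbalance, default = 1e-2) {
  # Weighted column means (assumed equivalent to col_stats(...)$center).
  center <- colSums(X.stan * arm.weights) / sum(arm.weights)
  lambda.max <- max(abs(center))
  # Only override the default when the target lies below lambda.max.
  if (max.imbalance < lambda.max) max.imbalance / lambda.max else default
}

max.imbalance <- 0.05
ratio0 <- ratio_for_arm(X.stan, (1 - W) * sample.weights, max.imbalance)  # control arm
ratio1 <- ratio_for_arm(X.stan, W * sample.weights, max.imbalance)        # treated arm
c(control = ratio0, treated = ratio1)
```

Each arm gets its own ratio because the largest absolute weighted mean generally differs between the control and treated groups, which is why the diff indexes `lambda.min.ratio[[1]]` and `lambda.min.ratio[[2]]` separately.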

validate_lambda <- function(lambda) {
if (is.character(lambda)) {
stop("Unsupported lambda argument.")
11 changes: 10 additions & 1 deletion r-package/balnet/man/balnet.Rd


4 changes: 2 additions & 2 deletions r-package/balnet/site/get-started.Rmd
@@ -59,13 +59,13 @@ max(abs(smd.baseline))

Since the smallest value of $\lambda$ attained for the treated arm is approximately $\lambda_{\min} \approx 0.21$, this indicates that the closest we can bring the standardized treated covariate means to the overall means is an absolute SMD of about 0.21.

This interpretation of $\lambda$ provides a convenient way to target a desired level of imbalance. Users can compute $\lambda^{\max}$ for their dataset and then choose `lambda.min.ratio` to reflect an acceptable fraction of this maximum imbalance. For example, if $\lambda^{\max} = 10$, the default setting `lambda.min.ratio = 0.01` corresponds to a target maximum absolute SMD of $10 \times 0.01 = 0.1$. The algorithm then attempts to compute the full regularization path, stopping gracefully if further reductions in imbalance are not achievable (in cases where balance remains approximate, users may wish to augment IPW estimation with an outcome model).
This interpretation of $\lambda$ provides a convenient way to target a desired level of imbalance, available through the option `max.imbalance`. For lasso penalization, `balnet` then adjusts the generated $\lambda$ sequence so that it terminates at this value. The algorithm then attempts to compute the full regularization path, stopping gracefully if further reductions in imbalance are not achievable. Alternatively, users may compute $\lambda^{\max}$ (e.g., the maximum absolute unweighted SMD) for their dataset and then choose `lambda.min.ratio` to reflect an acceptable fraction of this maximum imbalance. For example, if $\lambda^{\max} = 10$, the default setting `lambda.min.ratio = 0.01` corresponds to a target maximum absolute SMD of $10 \times 0.01 = 0.1$.
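
As a worked version of the arithmetic above, the sketch below builds a log-spaced sequence from $\lambda^{\max}$ down to $\lambda^{\max} \times$ `lambda.min.ratio`. The log spacing is an assumption borrowed from `glmnet`-style paths, not something this document confirms for `balnet`, and the variable names are illustrative.

```r
lambda.max <- 10
nlambda <- 100

# Default ratio: the path ends at 10 * 0.01 = 0.1.
lambda.default <- exp(seq(log(lambda.max), log(lambda.max * 1e-2), length.out = nlambda))

# Targeting max.imbalance = 0.05 instead implies a ratio of 0.05 / 10 = 0.005.
max.imbalance <- 0.05
ratio <- max.imbalance / lambda.max
lambda.target <- exp(seq(log(lambda.max), log(lambda.max * ratio), length.out = nlambda))

range(lambda.default)
range(lambda.target)
```

Under these assumptions, specifying the imbalance target directly and specifying the equivalent ratio lead to the same endpoint. The option simply saves the user from computing $\lambda^{\max}$ by hand.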

> *Note*: Setting `lambda = 0` to try to achieve exact balance is not recommended; `glmnet` likewise advises against it. `balnet` works best by using warm starts and gradually decreasing regularization, a strategy similar to barrier methods in convex optimization. This approach helps the algorithm converge reliably and improves performance on real-world datasets where achieving covariate balance can be difficult.

## Plotting path diagnostics

`balnet` provides default plotting methods for visualizing regularization path diagnostics. Calling `plot` without additional arguments produces a summary of key metrics along the path, indexed by $\lambda$ on the log scale.
`balnet` provides default plotting methods for visualizing regularization path diagnostics. Calling `plot` without additional arguments produces a summary of metrics along the path, indexed by $\lambda$ on the log scale.

```{r}
plot(fit)