From 76c7d79caddaa06361d2adbafa2b426a762cf6e1 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Wed, 29 Oct 2025 17:25:42 -0300 Subject: [PATCH] Pass predict additional args to the predict function --- R/module.R | 8 +++++--- tests/testthat/test-module.R | 28 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/R/module.R b/R/module.R index 20d46cee..332baa88 100644 --- a/R/module.R +++ b/R/module.R @@ -396,10 +396,12 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), ) pars <- rlang::list2(...) - if (is.null(pars$stack)) + if (is.null(pars$stack)) { stack <- TRUE - else + } else { stack <- pars$stack + pars$stack <- NULL # pop from this list + } predict_fn <- if (is.null(ctx$model$predict)) ctx$model else ctx$model$predict @@ -416,7 +418,7 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), coro::loop(for(batch in ctx$data) { ctx$batch <- batch ctx$call_callbacks("on_predict_batch_begin") - ctx$pred[[length(ctx$pred) + 1]] <- do.call(predict_fn, list(ctx$input)) + ctx$pred[[length(ctx$pred) + 1]] <- do.call(predict_fn, rlang::list2(ctx$input, !!!pars)) ctx$call_callbacks("on_predict_batch_end") }) } diff --git a/tests/testthat/test-module.R b/tests/testthat/test-module.R index e4753626..dcb91785 100644 --- a/tests/testthat/test-module.R +++ b/tests/testthat/test-module.R @@ -127,6 +127,34 @@ test_that("predict works for modules", { }) +test_that("predict passes additional arguments", { + + base_model <- get_model() + model <- torch::nn_module( + inherit = base_model, + predict = function(x, scale = 1) { + self$forward(x) * scale + } + ) + dl <- get_dl() + + fitted <- model %>% + setup( + loss = torch::nn_mse_loss(), + optimizer = torch::optim_adam + ) %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = FALSE) + + pred <- predict(fitted, dl) + pred_scaled <- predict(fitted, dl, scale = 2) + + expect_equal( + as.array(pred_scaled$to(device = "cpu")), + 2 * as.array(pred$to(device = "cpu")) + ) +}) + test_that("predict can use a progress bar", { model <- get_model()