diff --git a/DESCRIPTION b/DESCRIPTION
index 55de0f42..d227d0a5 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -62,8 +62,7 @@ Suggests:
     workflows,
     xgboost,
     yardstick
-VignetteBuilder:
-    quarto
+VignetteBuilder: knitr
 Config/testthat/edition: 3
 Config/testthat/parallel: false
 Config/testthat/start-first: interface, explain, params
@@ -71,5 +70,3 @@ Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.3.2
 Language: en-US
-SystemRequirements: Quarto command line tool
-    ().
diff --git a/README.Rmd b/README.Rmd
index bd894650..26a8e862 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -51,6 +51,7 @@ pak::pak("mlverse/tabnet")
 Here we show a **binary classification** example of the `attrition` dataset, using a **recipe** for dataset input specification.
 ```{r model-fit}
+#| fig.alt: "A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
 library(tabnet)
 suppressPackageStartupMessages(library(recipes))
 library(yardstick)
@@ -92,6 +93,7 @@ cbind(test, predict(fit, test, type = "prob")) %>%
 TabNet has intrinsic explainability feature through the visualization of attention map, either **aggregated**:
 ```{r model-explain}
+#| fig.alt: "An explainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask aggregate."
 explain <- tabnet_explain(fit, test)
 autoplot(explain)
 ```
@@ -99,6 +101,7 @@ autoplot(explain)
 or at **each layer** through the `type = "steps"` option:
 ```{r step-explain}
+#| fig.alt: "A small-multiple explainability plot for each step of the TabNet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis."
 autoplot(explain, type = "steps")
 ```
@@ -107,6 +110,7 @@ autoplot(explain, type = "steps")
 ## Self-supervised pretraining
 For cases when a consistent part of your dataset has no outcome, TabNet offers a self-supervised training step allowing to model to capture predictors intrinsic features and predictors interactions, upfront the supervised task.
 ```{r step-pretrain}
+#| fig.alt: "A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
 pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2)
 autoplot(pretrain)
 ```
diff --git a/README.md b/README.md
index 754a4f96..1c57068c 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ status](https://www.r-pkg.org/badges/version/tabnet)](https://CRAN.R-project.org
 An R implementation of: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas Pfister)](https://doi.org/10.48550/arXiv.1908.07442).
+
 The code in this repository started by an R port using the [torch](https://github.com/mlverse/torch) package of [dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
@@ -79,7 +80,7 @@ fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3)
 autoplot(fit)
 ```
-
+A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded.
 The plots gives you an immediate insight about model over-fitting, and if any, the available model checkpoints available before the
@@ -125,7 +126,7 @@ explain <- tabnet_explain(fit, test)
 autoplot(explain)
 ```
-
+An explainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask aggregate.
 or at **each layer** through the `type = "steps"` option:
@@ -133,7 +134,7 @@ or at **each layer** through the `type = "steps"` option:
 autoplot(explain, type = "steps")
 ```
-
+A small-multiple explainability plot for each step of the TabNet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis.
 ## Self-supervised pretraining
@@ -147,7 +148,7 @@ pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate
 autoplot(pretrain)
 ```
-
+A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded.
 The example here is a toy example as the `train` dataset does actually contain outcomes. The vignette
diff --git a/man/figures/README-unnamed-chunk-2-1.png b/man/figures/README-unnamed-chunk-2-1.png
deleted file mode 100644
index be616afd..00000000
Binary files a/man/figures/README-unnamed-chunk-2-1.png and /dev/null differ
diff --git a/man/figures/README-unnamed-chunk-4-1.png b/man/figures/README-unnamed-chunk-4-1.png
deleted file mode 100644
index ab8b1ca1..00000000
Binary files a/man/figures/README-unnamed-chunk-4-1.png and /dev/null differ
diff --git a/man/figures/README-unnamed-chunk-5-1.png b/man/figures/README-unnamed-chunk-5-1.png
deleted file mode 100644
index 9fd520b9..00000000
Binary files a/man/figures/README-unnamed-chunk-5-1.png and /dev/null differ
diff --git a/vignettes/aum_loss.qmd b/vignettes/aum_loss.Rmd
similarity index 96%
rename from vignettes/aum_loss.qmd
rename to vignettes/aum_loss.Rmd
index e5d0c66d..bc5d9931 100644
--- a/vignettes/aum_loss.qmd
+++ b/vignettes/aum_loss.Rmd
@@ -1,11 +1,12 @@
 ---
 title: "Using ROC AUM loss for imbalanced binary classification"
+output: rmarkdown::html_vignette
 vignette: >
   %\VignetteIndexEntry{Using ROC AUM loss for imbalanced binary classification}
-  %\VignetteEngine{quarto::html}
+  %\VignetteEngine{knitr::rmarkdown}
   %\VignetteEncoding{UTF-8}
-format:
-  html:
+editor_options:
+  markdown:
     fig-width: 9
     fig-height: 6
     fig-cap-location: "top"
@@ -24,7 +25,7 @@ library(tabnet)
 suppressPackageStartupMessages(library(tidymodels))
 library(modeldata)
 data("lending_club", package = "modeldata")
-set.seed(20250409)
+set.seed(20250809)
 ```
 ::: callout-note
@@ -115,8 +116,8 @@ We can now `fit()` each model and plot the precision-recall curve on the test-se
 ```{r}
 #| label: "vanilia_models_fitting"
-#| layout-ncol: 2
-#| fig-cap:
+#| layout.ncol: 2
+#| fig.cap:
 #| - "Tabnet, no case-weight, default loss"
 #| - "XGBoost, no case-weight"
 #|
@@ -150,8 +151,8 @@ Let's proceed
 ```{r}
 #| label: "case-weights_prediction"
-#| layout-ncol: 2
-#| fig-cap:
+#| layout.ncol: 2
+#| fig.cap:
 #| - "Tabnet, with case-weight, default loss"
 #| - "XGBoost, with case-weight"
 #|
@@ -194,8 +195,8 @@ Now let's compare the result on the PR curve with the default loss side by side:
 ```{r}
 #| label: "AUM_model_pr_curve"
-#| layout-ncol: 2
-#| fig-cap:
+#| layout.ncol: 2
+#| fig.cap:
 #| - "Tabnet, no case-weight, default loss"
 #| - "Tabnet, no case-weight, ROC_AUM loss"
 #|
@@ -218,8 +219,8 @@ Nothing prevent
 us to use both features, as they are independent. That is what w
 ```{r}
 #| label: "AUM_and_case-weights_prediction"
-#| layout-ncol: 2
-#| fig-cap:
+#| layout.ncol: 2
+#| fig.cap:
 #| - "Tabnet, with case-weight, default loss"
 #| - "Tabnet, with case-weight, ROC_AUM loss"
 #|
diff --git a/vignettes/interpretation.Rmd b/vignettes/interpretation.Rmd
index 215a946f..fd69ed7f 100644
--- a/vignettes/interpretation.Rmd
+++ b/vignettes/interpretation.Rmd
@@ -105,13 +105,14 @@ Let's fit a TabNet model to the `syn2` dataset and analyze the
 interpretation metrics.
 ```{r}
-fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 10, verbose = TRUE, device = "cpu")
+fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 45, learn_rate = 0.06, device = "cpu")
 ```
 In the feature importance plot we can see that, as expected, features `V03-V06` are by far the most important ones.
 ```{r}
+#| fig.alt: "A variable importance plot of the fitted model on the syn2 dataset showing V03 then V06, V04, V10 and V05 as the 5 most important features, in that order."
 vip::vip(fit_syn2)
 ```
@@ -121,10 +122,11 @@ colors represent the importance of the feature in predicting the value
 for each observation.
 ```{r}
+#| fig.alt: "A TabNet explanation plot of the fitted model on the syn2 dataset. The plot shows numerous important observations in V03 then V06 and V04. The other variables are shown with low-importance points or sparse observations with importance."
 library(tidyverse)
 ex_syn2 <- tabnet_explain(fit_syn2, syn2)
-autoplot(ex_syn2)
+autoplot(ex_syn2, quantile = 0.99)
 ```
 We can see that the region between V03 and V06 concentrates most of the
@@ -136,6 +138,7 @@ Next, we can visualize the attention masks for each step in the
 architecture.
 ```{r}
+#| fig.alt: "3 TabNet explanation plots, one for each step of the fitted model on the syn2 dataset. The step 1 plot shows numerous important observations with V02 and V03 having high importance, the step 2 plot highlights the importance of V03 and V06, and the third step plot highlights V03 and V10 as important variables."
 autoplot(ex_syn2, type="steps")
 ```
@@ -152,7 +155,7 @@ create the response variable and we expect to see this in the masks.
-First we fit the model for 10 epochs.
+First we fit the model for 50 epochs.
 ```{r}
-fit_syn4 <- tabnet_fit(y ~ ., syn4, epochs = 10, verbose = TRUE, device = "cpu")
+fit_syn4 <- tabnet_fit(y ~ ., syn4, epochs = 50, device = "cpu", learn_rate = 0.08)
 ```
 In the feature importance plot we have, as expected, strong importance
@@ -160,19 +163,21 @@ for `V10`, and the other features that are used conditionally - either
 `V01-V02` or `V05-V06`.
 ```{r}
+#| fig.alt: "A variable importance plot of the fitted model on the syn4 dataset."
 vip::vip(fit_syn4)
 ```
 Now let's visualize the attention masks. Notice that we arranged the
 dataset by `V10` so we can easily visualize the interaction effects.
-We also trimmed to the 99th percentile so the colors shows the
+We also trimmed to the 98th percentile so the colors show the
 importance even if there are strong outliers.
 ```{r}
+#| fig.alt: "A TabNet explanation plot of the fitted model on the syn4 dataset. The plot shows numerous important observations in V06 for low values of V10, and importance of V01 and V02 for high values of V10."
 ex_syn4 <- tabnet_explain(fit_syn4, arrange(syn4, V10))
-autoplot(ex_syn4, quantile=.995)
+autoplot(ex_syn4, quantile=.98)
 ```
 From the figure we see that V10 is important for all observations. We
@@ -183,6 +188,7 @@ the important ones.
 We can also visualize the masks at each step in the architecture.
 ```{r}
+#| fig.alt: "3 TabNet explanation plots, one for each step of the fitted model on the syn4 dataset."
 autoplot(ex_syn4, type="steps", quantile=.995)
 ```
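
For readers following along outside the patch itself, below is a minimal sketch of the explainability workflow that the updated `interpretation.Rmd` chunks exercise. It only uses calls that appear in the diff above, and it assumes the `syn2` data frame defined earlier in that vignette:

```r
library(tabnet)
library(ggplot2)  # provides the autoplot() generic used for tabnet objects

# Fit on the synthetic dataset from the vignette, with the hyper-parameters
# introduced by this patch (syn2 is assumed to be built by the vignette)
fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 45, learn_rate = 0.06, device = "cpu")

# Variable importance of the fitted model
vip::vip(fit_syn2)

# Aggregated attention masks; `quantile` trims extreme mask values so that
# strong outliers do not dominate the colour scale
ex_syn2 <- tabnet_explain(fit_syn2, syn2)
autoplot(ex_syn2, quantile = 0.99)

# One panel per decision step of the network
autoplot(ex_syn2, type = "steps")
```

The same pattern applies to the `syn4` example in the vignette, swapping in the data arranged by `V10` and the 0.98 quantile used there.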