5 changes: 1 addition & 4 deletions DESCRIPTION
@@ -62,14 +62,11 @@ Suggests:
workflows,
xgboost,
yardstick
VignetteBuilder:
quarto
VignetteBuilder: knitr
Config/testthat/edition: 3
Config/testthat/parallel: false
Config/testthat/start-first: interface, explain, params
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Language: en-US
SystemRequirements: Quarto command line tool
(<https://github.com/quarto-dev/quarto-cli>).
4 changes: 4 additions & 0 deletions README.Rmd
@@ -51,6 +51,7 @@ pak::pak("mlverse/tabnet")
Here we show a **binary classification** example on the `attrition` dataset, using a **recipe** to specify the dataset input.

```{r model-fit}
#| fig.alt: "A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
library(tabnet)
suppressPackageStartupMessages(library(recipes))
library(yardstick)
@@ -92,13 +93,15 @@ cbind(test, predict(fit, test, type = "prob")) %>%
TabNet has an intrinsic explainability feature through the visualization of attention maps, either **aggregated**:

```{r model-explain}
#| fig.alt: "An expainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask agggregate."
explain <- tabnet_explain(fit, test)
autoplot(explain)
```

or at **each layer** through the `type = "steps"` option:

```{r step-explain}
#| fig.alt: "An small-multiple expainability plot for each step of the Tabnet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis."
autoplot(explain, type = "steps")
```

@@ -107,6 +110,7 @@ autoplot(explain, type = "steps")
For cases where a substantial part of your dataset has no outcome, TabNet offers a self-supervised training step that lets the model capture the predictors' intrinsic features and their interactions, ahead of the supervised task.

```{r step-pretrain}
#| fig.alt: "A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded."
pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2)
autoplot(pretrain)
```
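
The pretrained network would then typically seed the supervised fit. A minimal sketch, assuming `tabnet_fit()` accepts the result of `tabnet_pretrain()` through its `tabnet_model` argument (hyper-parameter values are illustrative):

``` r
# Warm-start the supervised fit from the pretrained network
# (assumes `tabnet_model` accepts a tabnet_pretrain() result)
fit_finetuned <- tabnet_fit(rec, train,
                            tabnet_model = pretrain,
                            epochs = 30, valid_split = 0.1,
                            learn_rate = 5e-3)
autoplot(fit_finetuned)
```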
9 changes: 5 additions & 4 deletions README.md
@@ -19,6 +19,7 @@ status](https://www.r-pkg.org/badges/version/tabnet)](https://CRAN.R-project.org
An R implementation of [TabNet: Attentive Interpretable Tabular
Learning](https://arxiv.org/abs/1908.07442) [(Sercan O. Arik, Tomas
Pfister)](https://doi.org/10.48550/arXiv.1908.07442).

The code in this repository started as an R port, using the
[torch](https://github.com/mlverse/torch) package of
[dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
@@ -79,7 +80,7 @@ fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3)
autoplot(fit)
```

<img src="man/figures/README-model-fit-1.png" width="100%" />
<img src="man/figures/README-model-fit-1.png" alt="A training loss line-plot along training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded." width="100%" />

The plot gives you immediate insight into model over-fitting and,
if any, the model checkpoints available before the
@@ -125,15 +126,15 @@ explain <- tabnet_explain(fit, test)
autoplot(explain)
```

<img src="man/figures/README-model-explain-1.png" width="100%" />
<img src="man/figures/README-model-explain-1.png" alt="An expainability plot showing for each variable of the test-set on the y axis the importance along each observation on the x axis. The value is a mask agggregate." width="100%" />

or at **each layer** through the `type = "steps"` option:

``` r
autoplot(explain, type = "steps")
```

<img src="man/figures/README-step-explain-1.png" width="100%" />
<img src="man/figures/README-step-explain-1.png" alt="An small-multiple expainability plot for each step of the Tabnet network. Each plot shows for each variable of the test-set on the y axis the importance along each observation on the x axis." width="100%" />

## Self-supervised pretraining

@@ -147,7 +148,7 @@ pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate
autoplot(pretrain)
```

<img src="man/figures/README-step-pretrain-1.png" width="100%" />
<img src="man/figures/README-step-pretrain-1.png" alt="A training loss line-plot along pre-training epochs. Both validation loss and training loss are shown. Training loss line includes regular dots at epochs where a checkpoint is recorded." width="100%" />

The example here is a toy one, as the `train` dataset does actually
contain outcomes. The vignette
Binary file removed man/figures/README-unnamed-chunk-2-1.png
Binary file removed man/figures/README-unnamed-chunk-4-1.png
Binary file removed man/figures/README-unnamed-chunk-5-1.png
25 changes: 13 additions & 12 deletions vignettes/aum_loss.qmd → vignettes/aum_loss.Rmd
@@ -1,11 +1,12 @@
---
title: "Using ROC AUM loss for imbalanced binary classification"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Using ROC AUM loss for imbalanced binary classification}
%\VignetteEngine{quarto::html}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
format:
html:
editor_options:
markdown:
fig-width: 9
fig-height: 6
fig-cap-location: "top"
@@ -24,7 +25,7 @@ library(tabnet)
suppressPackageStartupMessages(library(tidymodels))
library(modeldata)
data("lending_club", package = "modeldata")
set.seed(20250409)
set.seed(20250809)
```
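
As a preview of the comparison below, a hypothetical sketch of a TabNet fit using the ROC AUM loss, assuming the package exposes it as `nn_aum_loss()` and that `tabnet_fit()` forwards a `loss` argument; `train` and the `Class` outcome stand in for the `lending_club` split prepared later in the vignette:

``` r
# Hypothetical sketch: TabNet fit with the ROC AUM loss on an
# imbalanced binary outcome (`nn_aum_loss()`, the `loss` argument,
# and the `train` split are assumptions, not shown in this diff)
fit_aum <- tabnet_fit(Class ~ ., train,
                      epochs = 30,
                      loss = nn_aum_loss(),
                      valid_split = 0.1)
```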

::: callout-note
@@ -115,8 +116,8 @@ We can now `fit()` each model and plot the precision-recall curve on the test-se

```{r}
#| label: "vanilia_models_fitting"
#| layout-ncol: 2
#| fig-cap:
#| layout.ncol: 2
#| fig.cap:
#| - "Tabnet, no case-weight, default loss"
#| - "XGBoost, no case-weight"
#|
@@ -150,8 +151,8 @@ Let's proceed

```{r}
#| label: "case-weights_prediction"
#| layout-ncol: 2
#| fig-cap:
#| layout.ncol: 2
#| fig.cap:
#| - "Tabnet, with case-weight, default loss"
#| - "XGBoost, with case-weight"
#|
@@ -194,8 +195,8 @@ Now let's compare the result on the PR curve with the default loss side by side:

```{r}
#| label: "AUM_model_pr_curve"
#| layout-ncol: 2
#| fig-cap:
#| layout.ncol: 2
#| fig.cap:
#| - "Tabnet, no case-weight, default loss"
#| - "Tabnet, no case-weight, ROC_AUM loss"
#|
@@ -218,8 +219,8 @@ Nothing prevent us to use both features, as they are independent. That is what w

```{r}
#| label: "AUM_and_case-weights_prediction"
#| layout-ncol: 2
#| fig-cap:
#| layout.ncol: 2
#| fig.cap:
#| - "Tabnet, with case-weight, default loss"
#| - "Tabnet, with case-weight, ROC_AUM loss"
#|
16 changes: 11 additions & 5 deletions vignettes/interpretation.Rmd
@@ -105,13 +105,14 @@ Let's fit a TabNet model to the `syn2` dataset and analyze the
interpretation metrics.

```{r}
fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 10, verbose = TRUE, device = "cpu")
fit_syn2 <- tabnet_fit(y ~ ., syn2, epochs = 45, learn_rate = 0.06, device = "cpu")
```

In the feature importance plot we can see that, as expected, features
`V03-V06` are by far the most important ones.

```{r}
#| fig.alt: "A variable importance plot of the fitted model on syn2 dataset showing V03 then V06, V04, v10 and V5 as the 5 most important features, in that order."
vip::vip(fit_syn2)
```
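
The same importances can be read as numbers rather than a plot; a small sketch, assuming `vip::vi()` dispatches on tabnet models the same way `vip::vip()` does:

``` r
# Tabular view of the variable importances
# (assumes vip::vi() has a method for tabnet_fit objects)
vip::vi(fit_syn2)
```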

Expand All @@ -121,10 +122,11 @@ colors represent the importance of the feature in predicting the value
for each observation.

```{r}
#| fig.alt: "A tabnet explaination plot of the fitted model on syn2 dataset. The plot shows numerous important observations in V03 then V06 and V04. the other variables are shown with low importance points or sparse observations with importance"
library(tidyverse)
ex_syn2 <- tabnet_explain(fit_syn2, syn2)

autoplot(ex_syn2)
autoplot(ex_syn2, quantile = 0.99)
```

We can see that the region between V03 and V06 concentrates most of the
@@ -136,6 +138,7 @@ Next, we can visualize the attention masks for each step in the
architecture.

```{r}
#| fig.alt: "3 tabnet explaination plots, one for each step of the fitted model on syn2 dataset. The Step 1 plot shows numerous important observations with V02 and V03 having high importance, step 2 plot highlight the importance of V03 and V06. Third step plot highlight V03 and V10 as important variables"
autoplot(ex_syn2, type="steps")
```

@@ -152,27 +155,29 @@ create the response variable and we expect to see this in the masks.
First we fit the model for 50 epochs.

```{r}
fit_syn4 <- tabnet_fit(y ~ ., syn4, epochs = 10, verbose = TRUE, device = "cpu")
fit_syn4 <- tabnet_fit(y ~ ., syn4, epochs = 50, device = "cpu", learn_rate = 0.08)
```

In the feature importance plot we have, as expected, strong importance
for `V10`, and the other features that are used conditionally - either
`V01-V02` or `V05-V06`.

```{r}
#| fig.alt: "A variable importance plot of the fitted model on syn4 dataset"
vip::vip(fit_syn4)
```

Now let's visualize the attention masks. Notice that we arranged the
dataset by `V10` so we can easily visualize the interaction effects.

We also trimmed to the 99th percentile so the colors shows the
We also trimmed to the 98th percentile so the colors show the
importance even if there are strong outliers.

```{r}
#| fig.alt: "A tabnet explaination plot of the fitted model on syn4 dataset. The plot shows numerous important observations in V06 for low values of V10, and importance of V01 and V02 for high values of V10."
ex_syn4 <- tabnet_explain(fit_syn4, arrange(syn4, V10))

autoplot(ex_syn4, quantile=.995)
autoplot(ex_syn4, quantile=.98)
```

From the figure we see that V10 is important for all observations. We
@@ -183,6 +188,7 @@ the important ones.
We can also visualize the masks at each step in the architecture.

```{r}
#| fig.alt: "3 tabnet explaination plots, one for each step of the fitted model on syn4 dataset."
autoplot(ex_syn4, type="steps", quantile=.995)
```
