From aeb4efc9ff4856271ec581905c752ec18d5f3678 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Wed, 22 Oct 2025 14:27:12 +0900 Subject: [PATCH 1/2] Add tensorboard support --- scripts/train.py | 31 ++++++++++++++++++++++++++++--- setup.cfg | 1 + 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 4f6a148d..51cf2450 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp +from torch.utils.tensorboard import SummaryWriter EPS = float(jnp.finfo(float).eps) DEFAULT_OUTPUT_NAME = 'weights.txt' @@ -233,7 +234,8 @@ def update(w: jax.Array, scores: jax.Array, rows: jax.Array, cols: jax.Array, def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset], features: typing.List[str], iters: int, weights_filename: str, - log_filename: str, out_span: int) -> jax.Array: + log_filename: str, out_span: int, + tensorboard_log_dir: typing.Optional[str]) -> jax.Array: """Trains an AdaBoost binary classifier. Args: @@ -244,10 +246,14 @@ def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset], weights_filename (str): A file path to write the learned weights. log_filename (str): A file path to log the accuracy along with training. out_span (int): Iteration span to output metics and weights. + tensorboard_log_dir (Optional[str]): A file path to log data for + TensorBoard. Returns: scores (jax.Array): The contribution scores. """ + writer = SummaryWriter( + log_dir=tensorboard_log_dir) if tensorboard_log_dir else None with open(weights_filename, 'w') as f: f.write('') with open(log_filename, 'w') as f: @@ -290,6 +296,11 @@ def output_progress(t: int) -> None: metrics_train.recall, metrics_train.fscore, )) + if writer: + writer.add_scalar('accuracy/train', metrics_train.accuracy, t) + writer.add_scalar('precision/train', metrics_train.precision, t) + writer.add_scalar('recall/train', metrics_train.recall, t) + writer.add_scalar('fscore/train', metrics_train.fscore, t) if dataset_val: pred_test = pred(scores, dataset_val.X_rows, dataset_val.X_cols, N_test) @@ -306,8 +317,14 @@ def output_progress(t: int) -> None: metrics_test.recall, metrics_test.fscore, )) - + if writer: + writer.add_scalar('accuracy/test', metrics_test.accuracy, t) + writer.add_scalar('precision/test', metrics_test.precision, t) + writer.add_scalar('recall/test', metrics_test.recall, t) + writer.add_scalar('fscore/test', metrics_test.fscore, t) f.write('\n') + if writer: + writer.add_histogram('weight', w, t) for t in range(iters): w, scores, best_feature_index, score = update(w, scores, @@ -320,6 +337,8 @@ def output_progress(t: int) -> None: output_progress(t + 1) if len(feature_score_buffer) > 0: output_progress(t + 1) + if writer: + writer.close() return scores @@ -364,6 +383,11 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: default=DEFAULT_OUT_SPAN) parser.add_argument( '--val-data', help='File path for the encoded validation data.', type=str) + parser.add_argument( + '--tensorboard', + help='Log directory for TensorBoard.', + type=str, + default=None) if test is None: return parser.parse_args() else: @@ -379,11 +403,12 @@ def main() -> None: iterations = int(args.iter) out_span = int(args.out_span) val_data: typing.Optional[str] = args.val_data + tensorboard_log_dir: typing.Optional[str] = args.tensorboard dataset_train, features, dataset_val = preprocess(data_filename, feature_thres, val_data) fit(dataset_train, dataset_val, features, iterations, weights_filename, - log_filename, out_span) + log_filename, out_span, tensorboard_log_dir) print('Training done. Export the model by passing %s to build_model.py' % (weights_filename)) diff --git a/setup.cfg b/setup.cfg index f926a3e5..26863d64 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ dev = types-regex types-setuptools yapf + tensorboard jaxcpu = jax==0.8.0 From 97955f5ad23f34b188723098af304637c65d8633 Mon Sep 17 00:00:00 2001 From: Shuhei Iitsuka Date: Wed, 22 Oct 2025 05:44:44 +0000 Subject: [PATCH 2/2] Add some more dependencies --- scripts/tests/test_train.py | 2 +- setup.cfg | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index 89c6e431..a6474a24 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -219,7 +219,7 @@ def test_fit(self) -> None: iters = 5 out_span = 2 scores = train.fit(dataset, dataset, features, iters, weights_file_path, - log_file_path, out_span) + log_file_path, out_span, None) with open(weights_file_path) as f: weights = [ line.split('\t') for line in f.read().splitlines() if line.strip() diff --git a/setup.cfg b/setup.cfg index 26863d64..56eaac4d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,12 +32,13 @@ dev = mypy==1.18.2 pytest regex + tensorboard toml + torch twine types-regex types-setuptools yapf - tensorboard jaxcpu = jax==0.8.0