
Commit 6427429

Add multi-process logger utility for status monitoring (#254)
* Add logger utility for detailed status monitoring
* Fix `flake8` errors
* Clean up info messages
* Leave PPO rollout progress bar
* Update logging doc and add tip
* Clarify `README.md` logging docs
* Adopt Hugging Face logging API
* Run pre-commit
* Remove redundant verbosity setters
* Toggle rollout bar positioning for suppressed verbosity levels
1 parent b70bc92 commit 6427429

File tree

6 files changed: +466 −28 lines changed

README.md

Lines changed: 35 additions & 0 deletions
@@ -73,6 +73,41 @@ For more usage see the [NeMo README](./trlx/trainer/nemo)
 python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
 ```

+## Logging
+
+trlX uses the standard Python `logging` library to log training information to the console. The default logger is set to the `INFO` level, which means that `INFO`, `WARNING`, `ERROR`, and `CRITICAL` level messages will be printed to standard output.
+
+To change the log level directly, you can use the verbosity setter. For example, to set the log level to `WARNING` use:
+
+```python
+import trlx
+
+trlx.logging.set_verbosity(trlx.logging.WARNING)
+```
+
+This will suppress `INFO` level messages, but still print `WARNING`, `ERROR`, and `CRITICAL` level messages.
+
+You can also control logging verbosity by setting the `TRLX_VERBOSITY` environment variable to one of the standard logging [level names](https://docs.python.org/3/library/logging.html#logging-levels):
+
+* `CRITICAL` (`trlx.logging.CRITICAL`)
+* `ERROR` (`trlx.logging.ERROR`)
+* `WARNING` (`trlx.logging.WARNING`)
+* `INFO` (`trlx.logging.INFO`)
+* `DEBUG` (`trlx.logging.DEBUG`)
+
+```sh
+export TRLX_VERBOSITY=WARNING
+```
+
+By default, [`tqdm`](https://tqdm.github.io/docs/tqdm/) progress bars are used to display training progress. You can disable them by calling `trlx.logging.disable_progress_bar()` and re-enable them with `trlx.logging.enable_progress_bar()`.
+
+Messages can be formatted with greater detail by calling `trlx.logging.enable_explicit_format()`. This will inject call-site information into each log message, which may be helpful for debugging.
+
+```sh
+[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...
+```
+
+> 💡 Tip: To reduce the amount of logging output, you might find it helpful to change log levels of third-party libraries used by trlX. For example, try adding `transformers.logging.set_verbosity_error()` to the top of your trlX scripts to silence verbose messages from the `transformers` library (see their [logging docs](https://huggingface.co/docs/transformers/main_classes/logging#logging) for more details).

 ## Contributing

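The controls documented above compose. A minimal sketch of a training script that uses them together (the model name and toy `reward_fn` here are illustrative, not part of this commit):

```python
import trlx

# Configure logging before training starts, using the helpers documented above.
trlx.logging.set_verbosity(trlx.logging.INFO)  # equivalent to: export TRLX_VERBOSITY=INFO
trlx.logging.enable_explicit_format()          # prefix each message with call-site and rank info
trlx.logging.disable_progress_bar()            # drop tqdm bars, e.g. when logs go to a file

# Toy reward for illustration only: longer samples score higher.
trainer = trlx.train(
    "gpt2",
    reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],
)
```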
trlx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from .trlx import train
+from .utils import logging

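This one-line re-export is what makes `trlx.logging` available as a public attribute. A small sketch of using it from user code (assuming `get_logger` is part of the Hugging Face-style logging API this commit adopts):

```python
import trlx

# Module-level logger for a user script; it honors trlX's verbosity settings.
logger = trlx.logging.get_logger(__name__)
logger.info("Starting a custom experiment")
```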
trlx/orchestrator/offline_orchestrator.py

Lines changed: 23 additions & 13 deletions
@@ -1,11 +1,16 @@
+import os
 from typing import List, Union

 import numpy as np
 import torch
+from rich.console import Console
+from rich.table import Table

+import trlx.utils.logging as logging
 from trlx.orchestrator import Orchestrator, register_orchestrator
 from trlx.pipeline.offline_pipeline import ILQLRolloutStorage
-from trlx.utils import print_rank_0
+
+logger = logging.get_logger(__name__)


 def tokenize_dialogue(  # noqa: C901
@@ -60,6 +65,8 @@ def make_experience(self, samples, rewards, max_length=2048):
         """
         Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer
         """
+        logger.info("Collecting rollouts")
+
         if self.trainer.tokenizer:
             samples = [tokenize_dialogue(s, self.trainer.tokenizer, max_length) for s in samples]

@@ -84,26 +91,29 @@ def make_experience(self, samples, rewards, max_length=2048):
             all_actions_ixs.append(torch.hstack(actions_ixs))
             all_states_ixs.append(states_ixs)

-        if self.trainer.tokenizer:
+        if self.trainer.tokenizer and os.environ.get("RANK", "0") == "0":
+            logger.info("Logging sample example")
             prompt = self.trainer.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]])
             response = self.trainer.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :])
-            print_rank_0("[Sample example]")
-            print_rank_0("Prompt: ", prompt)
-            print_rank_0("Response: ", response)
-            print_rank_0("Reward: ", rewards[0])
+            columns = ["Prompt", "Response", "Reward"]
+            table = Table(*columns, title="Sample Example", show_lines=True)
+            table.add_row(prompt, response, str(rewards[0]))
+            Console().print(table)

         sample_lengths = np.array(list(map(len, all_input_ids)))
         output_lengths = np.array(list(map(len, all_actions_ixs)))
         prompt_lengths = sample_lengths - output_lengths
         returns = torch.tensor(rewards, dtype=float)

-        def string_stats(name: str, xs: np.array):
-            return f"[Mean {name}] {xs.mean():.2f} ∈ [{min(xs)}, {max(xs)}]"
-
-        print_rank_0(string_stats("prompt length", prompt_lengths))
-        print_rank_0(string_stats("output length", output_lengths))
-        print_rank_0(string_stats("sample length", sample_lengths))
-        print_rank_0(string_stats("return", returns))
+        if os.environ.get("RANK", "0") == "0":
+            logger.info("Logging experience string statistics")
+            columns = ["Prompt Length", "Output Length", "Sample Length"]
+            table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True)
+            row = []
+            for lengths in [prompt_lengths, output_lengths, sample_lengths]:
+                row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]")
+            table.add_row(*row)
+            Console().print(table)

         returns = (returns - returns.mean()) / (returns.std() + 1e-30)
         rewards = [torch.zeros(len(x)) for x in all_actions_ixs]

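The changes above replace `print_rank_0` calls with `rich` tables that only the main process renders. The same pattern in isolation, with made-up data (the `RANK` check mirrors the one used in the diff):

```python
import os

from rich.console import Console
from rich.table import Table

# Render the table only on the main process (RANK unset or "0"); otherwise every
# worker in a distributed run would print its own copy.
if os.environ.get("RANK", "0") == "0":
    table = Table("Prompt", "Response", "Reward", title="Sample Example", show_lines=True)
    table.add_row("What is 2 + 2?", "4", "1.0")
    Console().print(table)
```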
trlx/orchestrator/ppo_orchestrator.py

Lines changed: 22 additions & 1 deletion
@@ -1,9 +1,11 @@
+import os
 from time import time

 import ray
 import torch
 import torch.nn.functional as F

+import trlx.utils.logging as logging
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.ppo_types import PPORLElement
 from trlx.orchestrator import Orchestrator, register_orchestrator
@@ -12,6 +14,8 @@
 from trlx.utils import Clock
 from trlx.utils.modeling import RunningMoments, logprobs_from_logits

+logger = logging.get_logger(__name__)
+

 @register_orchestrator
 class PPOOrchestrator(Orchestrator):
@@ -55,9 +59,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
         Takes `num_rollouts` prompts from `pipeline`, samples model and computes the
         KL against a reference model. It then appends PPOElements to trainer's `store`
         """
+        logger.info("Collecting rollouts")
+        tbar = logging.tqdm(
+            total=num_rollouts,
+            disable=os.environ.get("RANK", 0) != "0",
+            desc=f"[rollout 0 / {num_rollouts}]",
+            # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress
+            # bars (e.g. loss progress in trainers)
+            position=logging.get_verbosity() >= logging.WARNING,
+            # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels
+            leave=logging.get_verbosity() < logging.WARNING,
+        )
+
         ppo_rl_elements = []
         stats = {}
         clock = Clock()
+
         while len(ppo_rl_elements) < num_rollouts:
             # Get next batch in prompt dataset and refresh if exhausted
             try:
@@ -198,6 +215,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs)
             rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)]

+            rollout_count = 0
             for ix in range(n):
                 if len(rewards[ix]) == 0 or len(all_logprobs[ix]) == 0:
                     continue
@@ -213,8 +231,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
                         rewards=rewards[ix],
                     )
                 )
-
+                rollout_count += 1
             exp_time = clock.tick()
+            tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]")
+            tbar.update(min(rollout_count, num_rollouts))
+        tbar.close()

         stats["kl_ctl_value"] = self.trainer.kl_ctl.value
         stats["time/exp"] = exp_time

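The rollout progress bar added above derives its `position` and `leave` arguments from the current verbosity, so it neither covers a higher-priority bar nor lingers when logging is suppressed. A sketch of that pattern using plain `tqdm` and standard `logging` levels as stand-ins for trlX's wrappers:

```python
import logging
import time

from tqdm import tqdm

verbosity = logging.INFO  # imagine this came from trlx.logging.get_verbosity()

bar = tqdm(
    total=100,
    desc="[rollout 0 / 100]",
    # Push the bar down one row when verbosity is WARNING or above, so it does not
    # overlap a higher-priority bar (e.g. the trainer's loss bar) pinned at position 0.
    position=int(verbosity >= logging.WARNING),
    # Keep the finished bar on screen only while INFO/DEBUG output is enabled.
    leave=verbosity < logging.WARNING,
)
for step in range(100):
    time.sleep(0.01)  # stand-in for sampling one rollout
    bar.set_description(f"[rollout {step + 1} / 100]")
    bar.update()
bar.close()
```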
trlx/trainer/accelerate_base_trainer.py

Lines changed: 45 additions & 14 deletions
@@ -13,9 +13,9 @@
 from ray.air.checkpoint import Checkpoint
 from rich.console import Console
 from rich.table import Table
-from tqdm import tqdm
 from transformers import AutoTokenizer

+import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
 from trlx.trainer import BaseRLTrainer, register_trainer
 from trlx.utils import (
@@ -24,7 +24,6 @@
     get_git_tag,
     get_optimizer_class,
     get_scheduler_class,
-    print_rank_0,
     significant,
 )
 from trlx.utils.modeling import (
@@ -35,6 +34,8 @@
     parse_delta_kwargs,
 )

+logger = logging.get_logger(__name__)
+

 @register_trainer
 class AccelerateRLTrainer(BaseRLTrainer):
@@ -116,6 +117,8 @@ def setup_model(self):
         """
         Returns a model derived from an instance's TRLConfig
         """
+        logger.info(f"Initializing model: {self.config.model.model_path}")
+
         # Retrieves model equipped for ppo, ilql, etc
         model = self.get_arch(self.config)
         if self.config.model.model_arch_type == "seq2seq":
@@ -279,16 +282,30 @@ def add_eval_pipeline(self, eval_pipeline):

     def evaluate(self):  # noqa: C901
         """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
-        stats = {}
-        table = []
+        logger.info("Evaluating model")

         # Do multiple evaluations over a single list in `gen_kwargs` if present
         if self.generate_sweep_kwarg is not None:
             gen_sweep_arg, gen_sweep_values = self.generate_sweep_kwarg
         else:
             gen_sweep_values = [None]

-        for gen_sweep_value in gen_sweep_values:
+        desc = [
+            f"generation sweep 0/{len(gen_sweep_values)}",
+            f"eval batch 0/{len(self.eval_dataloader)}",
+        ]
+        tbar = logging.tqdm(
+            total=len(self.eval_dataloader) * len(gen_sweep_values),
+            desc=f"[{' | '.join(desc)}]",
+            disable=not self.accelerator.is_main_process,
+            position=0,
+            leave=True,
+        )
+
+        stats = {}
+        table = []
+
+        for i_sweep, gen_sweep_value in enumerate(gen_sweep_values):
             # A dedicated suffix for wandb logging
             if gen_sweep_value is not None:
                 sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}"
@@ -299,7 +316,7 @@ def evaluate(self): # noqa: C901
             all_prompts = []
             prompt_sizes = []
             generate_time = time()
-            for prompts in self.eval_dataloader:
+            for i_prompt, prompts in enumerate(self.eval_dataloader):
                 if self.generate_sweep_kwarg:
                     samples = self.generate_eval(**prompts, **{gen_sweep_arg: gen_sweep_value})
                 else:
@@ -326,6 +343,14 @@ def evaluate(self): # noqa: C901
                     torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(len(prompts.input_ids))
                 )

+                desc = [
+                    f"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}",
+                    f"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}",
+                ]
+                tbar.set_description(f"[{' | '.join(desc)}]")
+                tbar.update()
+            tbar.close()
+
             stats["time/generate"] = time() - generate_time

             samples = self.accelerator.gather(torch.vstack(all_samples))
@@ -340,6 +365,7 @@ def evaluate(self): # noqa: C901

                 # in online setting, compute the reward for validation
                 if self.reward_fn:
+                    logger.info("Computing rewards")
                     rewards = torch.tensor(
                         self.reward_fn(
                             samples=str_samples,
@@ -357,6 +383,7 @@ def evaluate(self): # noqa: C901

                 # additionally log any other metrics
                 if self.metric_fn:
+                    logger.info("Computing metrics")
                     metric_time = time()
                     metrics = self.metric_fn(
                         samples=str_samples,
@@ -385,6 +412,7 @@ def evaluate(self): # noqa: C901
                 table.append(list(zip(*columns_data)))

         # Log and display evaluation metrics
+        logger.info("Summarizing evaluation")
         if self.accelerator.is_main_process:
             rows = sum(list(map(list, zip(*table))), [])

@@ -395,30 +423,30 @@ def evaluate(self): # noqa: C901
                     table_title += f" {k}: {significant(x)}"

             rich_table = Table(*columns, title=table_title, show_lines=True)
-
             for ix in range(max(min(3, len(rows)), len(gen_sweep_values))):
                 rich_table.add_row(*[str(significant(x)) for x in rows[ix]])
+            Console().print(rich_table)

             if not ray.is_initialized():
                 if self.config.train.tracker == "wandb":
                     import wandb

                     stats["samples"] = wandb.Table(columns, rows)

-            Console().print(rich_table)
-
         self.nth_evaluation += 1
         return stats

     def learn(self):  # noqa: C901
         """
         Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
         """
+        logger.info("Starting training")
+
         self.generate_sweep_kwarg = None
         for k, v in self.config.method.gen_kwargs.items():
             if isinstance(v, list):
                 if self.generate_sweep_kwarg is not None:
-                    print_rank_0("Only a single sweep is allowed, {k} is going to be set to {v[0]}")
+                    logger.info("Only a single sweep is allowed, {k} is going to be set to {v[0]}")
                     self.generate_kwargs[k] = v[0]
                 else:
                     self.generate_sweep_kwarg = (k, v)
@@ -440,10 +468,12 @@ def learn(self): # noqa: C901
         results = self.evaluate()
         self.accelerator.log(results, step=self.iter_count)

-        tbar = tqdm(
+        tbar = logging.tqdm(
             initial=self.iter_count,
             total=self.total_steps,
             disable=not self.accelerator.is_local_main_process,
+            position=0,
+            leave=True,
         )

         best_reward = -float("inf")
@@ -491,7 +521,7 @@ def learn(self): # noqa: C901
                                 torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX)
                             if do_save:
                                 best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint"
-                                print_rank_0(f"saving the best state so far into {best_path}")
+                                logger.info(f"Saving the best state so far into {best_path}")
                                 self.save(best_path)

                         # Report the metrics to Ray Tune.
@@ -505,8 +535,8 @@ def learn(self): # noqa: C901
                     if not ray.is_initialized():
                         self.accelerator.log(stats, step=self.iter_count)

-                    desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
-                    tbar.set_description(desc)
+                    desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
+                    tbar.set_description(f"[{desc}]")
                     tbar.update()

                     if self.iter_count >= self.total_steps:
@@ -516,6 +546,7 @@ def learn(self): # noqa: C901
                 self.post_backward_callback()

             self.post_epoch_callback()
+        tbar.close()

     @abstractmethod
     def get_arch(self, config: TRLConfig):

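The new logging module itself (the sixth changed file, presumably `trlx/utils/logging.py`) is not shown in this excerpt. As a rough sketch only, not the actual implementation, a rank-aware `get_logger` in the spirit of the log format documented above might look like:

```python
import logging
import os
import sys


def get_logger(name: str) -> logging.Logger:
    """Return a logger whose messages carry the process rank.

    Sketch only: the real trlx.utils.logging follows the Hugging Face logging API
    (set_verbosity, enable_explicit_format, ...) and is not reproduced here.
    """
    rank = os.environ.get("RANK", "0")
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(
            logging.Formatter(
                f"[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] [RANK {rank}] %(message)s"
            )
        )
        logger.addHandler(handler)
        # Default to INFO, overridable via the TRLX_VERBOSITY environment variable.
        logger.setLevel(os.environ.get("TRLX_VERBOSITY", "INFO").upper())
    return logger


logger = get_logger(__name__)
logger.info("Collecting rollouts")
```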
0 commit comments
