Resuming takes too much time with load_path activated #246

@dtamayo-nlp

Hi,

Thank you for your code and your contribution, it is certainly useful!

Following the pre-training instructions mentioned in various issues, I successfully managed to pre-train from a pre-tokenized mixture in a multi-GPU (or multi-node) setup. However, when resuming training from a specific checkpoint, I noticed that the startup time before training actually begins increases significantly.

After reviewing the codebase, I found that Composer iterates over all samples until it reaches the sample corresponding to the last iteration (source). While this might not be an issue for smaller datasets, it becomes quite slow in my case because of the sequence packing implementation and the large token count involved. I also came across a related issue in the Composer GitHub repository, but I was not able to find whether a fix or workaround was proposed. A simplified sketch of what I believe this fast-forward amounts to is shown below.
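
To make the cost concrete, here is a minimal sketch of my reading of that fast-forward (assumed behavior, not Composer's actual code): every batch up to the checkpointed position is materialized and then discarded, so with sequence packing each skipped batch still pays the full packing cost.

```python
import itertools

# Minimal sketch of the resume fast-forward as I understand it (assumed
# behavior, not Composer's actual implementation). Resuming at batch N
# means building and discarding N batches, each of which still runs the
# sequence-packing collation before being thrown away.
def spin_dataloader(dataloader, resume_batch: int):
    it = iter(dataloader)
    for _ in itertools.islice(it, resume_batch):
        pass  # batch is fully constructed (packed), then discarded
    return it  # training continues from batch `resume_batch`
```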

Is there a recommended approach to speed up the process?
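
For context, one candidate workaround I came across (untested with this repo) is Composer's `spin_dataloaders` Trainer flag, which skips the fast-forward entirely; the trade-off is that the post-resume sample order is no longer deterministic unless the dataset restores its own position. A sketch, assuming the Trainer is constructed directly as in the repo's training script, with `model` and `train_loader` standing in for the objects that script builds:

```python
from composer import Trainer

# Sketch under the assumption that the installed Composer version
# (mosaicml 0.24.1 here) exposes the `spin_dataloaders` flag.
trainer = Trainer(
    model=model,                    # ComposerModel built by the script
    train_dataloader=train_loader,  # packed dataloader from the config
    max_duration="1719000000000tok",
    load_path="/pathto/ep0-ba4400-rank0.pt",
    spin_dataloaders=False,         # skip iterating to the checkpointed sample
)
trainer.fit()
```

If the repo's training script does not expose this flag through the YAML, it would presumably need a one-line change where the Trainer is instantiated.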

The specific YAML I used was:

data_local: custom_dataset
data_remote: # If blank, files must be present in data_local

max_seq_len: 512
tokenizer_name: custom_tokenizer
mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance
count_padding_tokens: false

# Run Name
run_name: modernbert-base-pretrain

# Model
model:
  name: flex_bert
  pretrained_model_name: bert-base-uncased # has to be set to bert-base-uncased for legacy reasons
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true # save some time by not computing metrics on the training set
  model_config:
    vocab_size: 256000
    init_method: full_megatron
    num_hidden_layers: 22
    hidden_size: 768
    intermediate_size: 1152
    num_attention_heads: 12 # to have head size of 64
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    normalization: layernorm
    norm_kwargs:
      eps: 1e-5
      bias: false
    hidden_act: gelu
    head_pred_act: gelu
    activation_function: gelu # better safe than sorry
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 10000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 6
  sequence_packing: true
  batch_size_warmup_min_size: ${device_train_microbatch_size}
  batch_size_warmup_tokens: 50_000_000_000tok

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: val
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: false
    mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison
    streaming: false
  drop_last: false
  num_workers: 3
  sequence_packing: false

# Optimization
scheduler:
  name: warmup_stable_decay
  t_warmup: 3_000_000_000tok
  alpha_f: 0.00 # Linearly decay to 0x the full LR by the end of the training duration
  t_decay: 0tok

optimizer:
  name: decoupled_stableadamw
  lr: 8e-4 # Peak learning rate
  betas:
  - 0.9
  - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-5 # Amount of weight decay regularization
  filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases
  log_grad_norm: true

max_duration: 1_719_000_000_000tok
eval_interval: 100ba
global_train_batch_size: 8192
device_train_microbatch_size: 64
global_eval_batch_size: 8192
device_eval_batch_size: 64

# System
seed: 17
precision: amp_bf16

# Logging
progress_bar: true
log_to_console: true
console_log_interval: 10ba

callbacks:
  speed_monitor:
    window_size: 20
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 100
  packing_efficiency:
    log_interval: 5

# W&B logging
# loggers:
#   wandb:
#     project: modernbert
#     entity: ???

# Checkpoint to local filesystem or remote object store
save_interval: 100ba
save_num_checkpoints_to_keep: 10  # Important, this cleans up checkpoints saved to DISK
save_folder: checkpoints/{run_name}

# Load from local filesystem or remote object store to resume training
load_path: /pathto/ep0-ba4400-rank0.pt
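
One more detail that may matter: both loaders above set `streaming: false`. As far as I can tell, mosaicml-streaming's StreamingDataset checkpoints its own sample position through `state_dict`/`load_state_dict`, which is what allows fast mid-epoch resumption without replaying batches, whereas a plain map-style dataset has no such state to restore. A quick check (hypothetical snippet reusing names from the config above):

```python
from streaming import StreamingDataset

# Hypothetical check: datasets that expose state_dict()/load_state_dict()
# can save and restore their sample position in the checkpoint instead of
# relying on the dataloader spin.
ds = StreamingDataset(local="custom_dataset", split="train", shuffle=True)
print(hasattr(ds, "state_dict"))  # True for mosaicml-streaming datasets
```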

With the following packages:

Package                     Version        Build
--------------------------- -------------- -----
aiohappyeyeballs            2.6.1
aiohttp                     3.10.5
aiosignal                   1.3.1
antlr4-python3-runtime      4.9.3
anyio                       4.4.0
appdirs                     1.4.4
argcomplete                 3.5.0
arrow                       1.3.0
async-timeout               4.0.3
attrs                       24.2.0
azure-core                  1.30.2
azure-identity              1.17.1
azure-storage-blob          12.22.0
azure-storage-file-datalake 12.16.0
backoff                     2.2.1
bcrypt                      4.2.0
binaryornot                 0.4.4
boto3                       1.35.2
botocore                    1.35.2
Brotli                      1.1.0
cachetools                  5.5.0
catalogue                   2.0.10
certifi                     2024.7.4
cffi                        1.17.0
chardet                     5.2.0
charset-normalizer          3.3.2
circuitbreaker              2.1.3
click                       8.1.7
colorama                    0.4.6
contourpy                   1.3.0
cookiecutter                2.6.0
coolname                    2.2.0
cramjam                     2.11.0
cryptography                41.0.7
cycler                      0.12.1
datasets                    2.21.0
dill                        0.3.8
docker-pycreds              0.4.0
einops                      0.8.0
evaluate                    0.4.1
exceptiongroup              1.3.0
execnet                     2.1.1
filelock                    3.15.4
flash-attn                  2.7.3
fonttools                   4.53.1
frozenlist                  1.7.0
fsspec                      2024.6.1
gitdb                       4.0.11
GitPython                   3.1.43
google-api-core             2.19.1
google-auth                 2.33.0
google-cloud-core           2.4.1
google-cloud-storage        2.10.0
google-crc32c               1.7.1
google-resumable-media      2.7.2
googleapis-common-protos    1.63.2
gql                         3.5.0
graphql-core                3.2.3
grpcio                      1.62.2
h2                          4.1.0
hpack                       4.0.0
huggingface-hub             0.24.6
hyperframe                  6.0.1
idna                        3.7
importlib_metadata          8.4.0
importlib_resources         6.5.2
iniconfig                   2.0.0
isodate                     0.6.1
Jinja2                      3.1.4
jmespath                    1.0.1
kiwisolver                  1.4.5
lightning-utilities         0.11.6
llvmlite                    0.43.0
markdown-it-py              3.0.0
MarkupSafe                  2.1.5
matplotlib                  3.9.2
mdurl                       0.1.2
mosaicml                    0.24.1
mosaicml-cli                0.6.41
mosaicml-streaming          0.11.0
mpmath                      1.3.0
msal                        1.30.0
msal-extensions             1.1.0
multidict                   6.0.5
multiprocess                0.70.16
networkx                    3.2.1
ninja                       1.11.1.1
numba                       0.60.0
numpy                       1.26.4
nvidia-cublas-cu12          12.1.3.1
nvidia-cuda-cupti-cu12      12.1.105
nvidia-cuda-nvrtc-cu12      12.1.105
nvidia-cuda-runtime-cu12    12.1.105
nvidia-cudnn-cu12           9.1.0.70
nvidia-cufft-cu12           11.0.2.54
nvidia-curand-cu12          10.3.2.106
nvidia-cusolver-cu12        11.4.5.107
nvidia-cusparse-cu12        12.1.0.106
nvidia-nccl-cu12            2.20.5
nvidia-nvjitlink-cu12       12.9.86
nvidia-nvtx-cu12            12.1.105
oci                         2.132.0
omegaconf                   2.3.0
packaging                   24.1
pandas                      2.2.2
paramiko                    3.4.1
pathtools                   0.1.2
pillow                      10.3.0
pip                         25.2
platformdirs                4.2.2
pluggy                      1.5.0
portalocker                 2.10.1
prompt-toolkit              3.0.36
proto-plus                  1.26.1
protobuf                    4.25.3
psutil                      6.0.0
py-cpuinfo                  9.0.0
pyarrow                     17.0.0
pyasn1                      0.6.0
pyasn1_modules              0.4.0
pycairo                     1.26.1
pycparser                   2.22
Pygments                    2.18.0
PyJWT                       2.9.0
PyNaCl                      1.5.0
pyOpenSSL                   23.3.0
pyparsing                   3.1.2
PySide6                     6.7.2
PySide6_Addons              6.7.2
PySide6_Essentials          6.7.2
PySocks                     1.7.1
pytest                      8.3.2
pytest-xdist                3.6.1
python-dateutil             2.9.0
python-slugify              8.0.4
python-snappy               0.7.2
pytorch-ranger              0.1.1
pytz                        2024.1
pyu2f                       0.1.5
PyYAML                      6.0.2
questionary                 2.0.1
regex                       2024.7.24
requests                    2.32.3
responses                   0.18.0
rich                        13.7.1
rsa                         4.9
ruamel.yaml                 0.18.6
ruamel.yaml.clib            0.2.8
s3transfer                  0.10.2
safetensors                 0.4.4
sentencepiece               0.2.1
sentry-sdk                  2.13.0
setproctitle                1.3.3
setuptools                  72.2.0
shellingham                 1.5.4
shiboken6                   6.7.2
six                         1.16.0
smmap                       5.0.0
sniffio                     1.3.1
sympy                       1.13.2
tabulate                    0.9.0
termcolor                   2.4.0
text-unidecode              1.3
tokenizers                  0.21.4
tomli                       2.0.1
torch                       2.4.0
torch-optimi                0.2.1
torch-optimizer             0.3.0
torchaudio                  2.4.0
torchmetrics                1.4.0.post0
torchvision                 0.19.0
tornado                     6.4.1
tqdm                        4.66.5
transformers                4.48.0
triton                      3.0.0          1
typer                       0.12.4
types-python-dateutil       2.9.0.20250708
typing_extensions           4.12.2
tzdata                      2025.2
urllib3                     1.26.20
validators                  0.33.0
wandb                       0.16.6
wcwidth                     0.2.13
websockets                  11.0.3
wheel                       0.44.0
xxhash                      3.5.0
yarl                        1.9.4
zipp                        3.20.0
zstandard                   0.23.0
zstd                        1.5.5.1
