Description
Hi,
Thank you for your code and your contribution; it is certainly useful!
Following the pre-training instructions mentioned in various issues, I successfully managed to pre-train from a pre-tokenized mixture in a multi-GPU (or multi-node) setup. However, when attempting to resume training from a specific checkpoint, I noticed that the startup time before training actually begins increases significantly.
After reviewing the codebase, I found that Composer iterates over all samples until it reaches the sample corresponding to the last completed iteration (source). While this might not be an issue for smaller datasets, it becomes quite slow in my case because of the sequence packing implementation and the large token count involved. I also came across a related issue in the Composer GitHub repository, but I wasn't able to find whether a fix or workaround was ever proposed.
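For context, my understanding of the slow part is roughly the following (a simplified Python sketch of the behavior described above, not Composer's actual code):

    # Simplified sketch of the resume "fast-forward": every batch before the
    # checkpointed step is materialized and discarded, so each skipped batch
    # still pays the full sequence-packing / collation cost.
    def fast_forward(dataloader, resumed_batch_idx: int):
        it = iter(dataloader)
        for _ in range(resumed_batch_idx):
            next(it)  # batch is built, then thrown away
        return it  # actual training resumes from here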
Is there a recommended approach to speed up the process?
The specific YAML config that I used was:
data_local: custom_dataset
data_remote: # If blank, files must be present in data_local

max_seq_len: 512
tokenizer_name: custom_tokenizer
mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance
count_padding_tokens: false

# Run Name
run_name: modernbert-base-pretrain

# Model
model:
  name: flex_bert
  pretrained_model_name: bert-base-uncased # has to be set to bert-base-uncased legacy reasons
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true # save some time by not computing metrics on the training set
  model_config:
    vocab_size: 256000
    init_method: full_megatron
    num_hidden_layers: 22
    hidden_size: 768
    intermediate_size: 1152
    num_attention_heads: 12 # to have head size of 64
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    normalization: layernorm
    norm_kwargs:
      eps: 1e-5
      bias: false
    hidden_act: gelu
    head_pred_act: gelu
    activation_function: gelu # better safe than sorry
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 10000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 6
  sequence_packing: true
  batch_size_warmup_min_size: ${device_train_microbatch_size}
  batch_size_warmup_tokens: 50_000_000_000tok

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: val
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: false
    mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison
    streaming: false
  drop_last: false
  num_workers: 3
  sequence_packing: false

# Optimization
scheduler:
  name: warmup_stable_decay
  t_warmup: 3_000_000_000tok
  alpha_f: 0.00 # Linearly decay to 0.02x the full LR by the end of the training duration
  t_decay: 0tok

optimizer:
  name: decoupled_stableadamw
  lr: 8e-4 # Peak learning rate
  betas:
  - 0.9
  - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-5 # Amount of weight decay regularization
  filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases
  log_grad_norm: true

max_duration: 1_719_000_000_000tok
eval_interval: 100ba
global_train_batch_size: 8192
device_train_microbatch_size: 64
global_eval_batch_size: 8192
device_eval_batch_size: 64

# System
seed: 17
precision: amp_bf16

# Logging
progress_bar: true
log_to_console: true
console_log_interval: 10ba

callbacks:
  speed_monitor:
    window_size: 20
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 100
  packing_efficiency:
    log_interval: 5

# W&B logging
# loggers:
#   wandb:
#     project: modernbert
#     entity: ???

# Checkpoint to local filesystem or remote object store
save_interval: 100ba
save_num_checkpoints_to_keep: 10 # Important, this cleans up checkpoints saved to DISK
save_folder: checkpoints/{run_name}

# Load from local filesystem or remote object store to
load_path: /pathto/ep0-ba4400-rank0.pt
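For a rough sense of scale under this config (assuming, on my side, that the ba4400 in the checkpoint filename is the number of completed batches and that every skipped batch is a full global batch, i.e. ignoring the batch size warmup):

    # Back-of-the-envelope estimate of what the resume has to iterate over.
    completed_batches = 4_400            # from load_path: ep0-ba4400-rank0.pt
    global_train_batch_size = 8_192      # from the YAML above
    max_seq_len = 512

    skipped_samples = completed_batches * global_train_batch_size
    skipped_tokens = skipped_samples * max_seq_len  # upper bound; packed sequences are at most max_seq_len long

    print(f"{skipped_samples:,} samples")            # 36,044,800
    print(f"{skipped_tokens:,} tokens (upper bound)")  # 18,454,937,600

So the dataloader has to churn through tens of millions of packed samples before the first real training step, which matches the startup delay I observe.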
With the following packages:
Package Version Build
--------------------------- -------------- -----
aiohappyeyeballs 2.6.1
aiohttp 3.10.5
aiosignal 1.3.1
antlr4-python3-runtime 4.9.3
anyio 4.4.0
appdirs 1.4.4
argcomplete 3.5.0
arrow 1.3.0
async-timeout 4.0.3
attrs 24.2.0
azure-core 1.30.2
azure-identity 1.17.1
azure-storage-blob 12.22.0
azure-storage-file-datalake 12.16.0
backoff 2.2.1
bcrypt 4.2.0
binaryornot 0.4.4
boto3 1.35.2
botocore 1.35.2
Brotli 1.1.0
cachetools 5.5.0
catalogue 2.0.10
certifi 2024.7.4
cffi 1.17.0
chardet 5.2.0
charset-normalizer 3.3.2
circuitbreaker 2.1.3
click 8.1.7
colorama 0.4.6
contourpy 1.3.0
cookiecutter 2.6.0
coolname 2.2.0
cramjam 2.11.0
cryptography 41.0.7
cycler 0.12.1
datasets 2.21.0
dill 0.3.8
docker-pycreds 0.4.0
einops 0.8.0
evaluate 0.4.1
exceptiongroup 1.3.0
execnet 2.1.1
filelock 3.15.4
flash-attn 2.7.3
fonttools 4.53.1
frozenlist 1.7.0
fsspec 2024.6.1
gitdb 4.0.11
GitPython 3.1.43
google-api-core 2.19.1
google-auth 2.33.0
google-cloud-core 2.4.1
google-cloud-storage 2.10.0
google-crc32c 1.7.1
google-resumable-media 2.7.2
googleapis-common-protos 1.63.2
gql 3.5.0
graphql-core 3.2.3
grpcio 1.62.2
h2 4.1.0
hpack 4.0.0
huggingface-hub 0.24.6
hyperframe 6.0.1
idna 3.7
importlib_metadata 8.4.0
importlib_resources 6.5.2
iniconfig 2.0.0
isodate 0.6.1
Jinja2 3.1.4
jmespath 1.0.1
kiwisolver 1.4.5
lightning-utilities 0.11.6
llvmlite 0.43.0
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.2
mdurl 0.1.2
mosaicml 0.24.1
mosaicml-cli 0.6.41
mosaicml-streaming 0.11.0
mpmath 1.3.0
msal 1.30.0
msal-extensions 1.1.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.2.1
ninja 1.11.1.1
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvtx-cu12 12.1.105
oci 2.132.0
omegaconf 2.3.0
packaging 24.1
pandas 2.2.2
paramiko 3.4.1
pathtools 0.1.2
pillow 10.3.0
pip 25.2
platformdirs 4.2.2
pluggy 1.5.0
portalocker 2.10.1
prompt-toolkit 3.0.36
proto-plus 1.26.1
protobuf 4.25.3
psutil 6.0.0
py-cpuinfo 9.0.0
pyarrow 17.0.0
pyasn1 0.6.0
pyasn1_modules 0.4.0
pycairo 1.26.1
pycparser 2.22
Pygments 2.18.0
PyJWT 2.9.0
PyNaCl 1.5.0
pyOpenSSL 23.3.0
pyparsing 3.1.2
PySide6 6.7.2
PySide6_Addons 6.7.2
PySide6_Essentials 6.7.2
PySocks 1.7.1
pytest 8.3.2
pytest-xdist 3.6.1
python-dateutil 2.9.0
python-slugify 8.0.4
python-snappy 0.7.2
pytorch-ranger 0.1.1
pytz 2024.1
pyu2f 0.1.5
PyYAML 6.0.2
questionary 2.0.1
regex 2024.7.24
requests 2.32.3
responses 0.18.0
rich 13.7.1
rsa 4.9
ruamel.yaml 0.18.6
ruamel.yaml.clib 0.2.8
s3transfer 0.10.2
safetensors 0.4.4
sentencepiece 0.2.1
sentry-sdk 2.13.0
setproctitle 1.3.3
setuptools 72.2.0
shellingham 1.5.4
shiboken6 6.7.2
six 1.16.0
smmap 5.0.0
sniffio 1.3.1
sympy 1.13.2
tabulate 0.9.0
termcolor 2.4.0
text-unidecode 1.3
tokenizers 0.21.4
tomli 2.0.1
torch 2.4.0
torch-optimi 0.2.1
torch-optimizer 0.3.0
torchaudio 2.4.0
torchmetrics 1.4.0.post0
torchvision 0.19.0
tornado 6.4.1
tqdm 4.66.5
transformers 4.48.0
triton 3.0.0 1
typer 0.12.4
types-python-dateutil 2.9.0.20250708
typing_extensions 4.12.2
tzdata 2025.2
urllib3 1.26.20
validators 0.33.0
wandb 0.16.6
wcwidth 0.2.13
websockets 11.0.3
wheel 0.44.0
xxhash 3.5.0
yarl 1.9.4
zipp 3.20.0
zstandard 0.23.0
zstd 1.5.5.1