From 18ac21dedff9d60bf4299bb72410f0b960ad6f09 Mon Sep 17 00:00:00 2001 From: Brandon Ban Date: Sun, 14 Dec 2025 15:26:28 +0800 Subject: [PATCH 1/2] Add gradient accumulation support and update README with usage instructions --- F2LLM/arguments.py | 2 + F2LLM/run.py | 9 +- F2LLM/smoke_test_accumulation.py | 161 ++++++++++++++++++++++++++++ F2LLM/test_gradient_accumulation.py | 31 ++++++ F2LLM/utils.py | 33 +++--- README.md | 23 ++++ 6 files changed, 245 insertions(+), 14 deletions(-) create mode 100644 F2LLM/smoke_test_accumulation.py create mode 100644 F2LLM/test_gradient_accumulation.py diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py index b967c8f..44ba330 100644 --- a/F2LLM/arguments.py +++ b/F2LLM/arguments.py @@ -21,6 +21,8 @@ class Args: warmup_steps: int = 100 # embedding-related settings num_hard_neg: int = 7 + # gradient accumulation to simulate larger effective batch size + gradient_accumulation_steps: int = 1 # train steps take precedence over epochs, set to -1 to disable train_steps: int = -1 train_epochs: int = 5 diff --git a/F2LLM/run.py b/F2LLM/run.py index e40b707..1ab254d 100644 --- a/F2LLM/run.py +++ b/F2LLM/run.py @@ -115,7 +115,10 @@ def __iter__(self): # determine training steps override_train_step = False if args.train_steps < 0: - args.train_steps = sum(len(v) for v in train_loaders.values()) * args.train_epochs + # interpret train_steps as optimization steps (after accumulation) + total_micro_batches = sum(len(v) for v in train_loaders.values()) * args.train_epochs + accum = max(1, getattr(args, 'gradient_accumulation_steps', 1)) + args.train_steps = total_micro_batches // accum override_train_step = True accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************") @@ -145,7 +148,9 @@ def __iter__(self): # if training on multiple GPUs, length of dataloader would have changed if override_train_step: - args.train_steps = len(train_dataloader) * args.train_epochs + total_micro_batches = len(train_dataloader) * args.train_epochs + accum = max(1, getattr(args, 'gradient_accumulation_steps', 1)) + args.train_steps = total_micro_batches // accum accelerator.print(f"******************************** Training step after prepare: {args.train_steps} ********************************") diff --git a/F2LLM/smoke_test_accumulation.py b/F2LLM/smoke_test_accumulation.py new file mode 100644 index 0000000..a4da5ed --- /dev/null +++ b/F2LLM/smoke_test_accumulation.py @@ -0,0 +1,161 @@ +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader +from accelerate import Accelerator +from tqdm import tqdm + +# Minimal tokenizer-like object +class DummyTokenizer: + def __init__(self, pad_token_id=0): + self.pad_token_id = pad_token_id + +# Dummy model implementing required interface +class DummyModel: + def __init__(self, hidden_size=32, tokenizer=None, device="cpu"): + self.tokenizer = tokenizer or DummyTokenizer() + self.lm = nn.Sequential( + nn.Embedding(30522, hidden_size), + nn.Linear(hidden_size, hidden_size) + ) + self._device = torch.device(device) + self.lm.to(self._device) + + def set_device(self): + self._device = next(self.lm.parameters()).device + + @property + def device(self): + return self._device + + def forward(self, batch): + input_ids = batch['input_ids'].to(self.device) # [bs_total, seq] + attention_mask = batch['attention_mask'].to(self.device) + bs = batch['bs'] + # Compute simple pooled features + emb = self.lm[0](input_ids) # [bs_total, seq, h] + pooled = (emb * 
attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp_min(1.0) + # split back into query/passages/negatives + num_hard = 1 # keep simple + q = pooled[:bs] + p = pooled[bs:2*bs] + negs = pooled[2*bs:2*bs+bs*num_hard].view(bs, num_hard, -1) + return { + 'query_passage_features': q.unsqueeze(1), # [bs,1,h] + 'passage_passage_features': p.unsqueeze(1), # [bs,1,h] + 'negative_passage_features': negs # [bs,num_hard,h] + } + +class SyntheticDataset(Dataset): + def __init__(self, length=64, seq_len=16, vocab=100): + self.length = length + self.seq_len = seq_len + self.vocab = vocab + + def __len__(self): + return self.length + + def __getitem__(self, idx): + def rand_ids(): + return [torch.randint(1, self.vocab, ()).item() for _ in range(self.seq_len)] + return { + 'query_input_ids': rand_ids(), + 'passage_input_ids': rand_ids(), + 'negative_1_input_ids': rand_ids(), + 'dataset_name': 'msmarco' + } + +def _stack(input_ids, max_len, pad_id): + data = [ids[:max_len] for ids in input_ids] + lens = [len(x) for x in data] + tensor = torch.tensor(sum(data, [])) + chunks = tensor.split(lens) + return chunks + +def collate_fn(batch_raw, max_seq_length=32, tokenizer=None): + tokenizer = tokenizer or DummyTokenizer() + num_hard_neg = 1 + input_ids = _stack( + [s['query_input_ids'] for s in batch_raw]+ + [s['passage_input_ids'] for s in batch_raw]+ + [s[f'negative_1_input_ids'] for s in batch_raw], + max_seq_length, + tokenizer.pad_token_id + ) + seqlens = torch.tensor([ids.size(0) for ids in input_ids]) + # pad to batch + input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attention_masks = input_ids.ne(tokenizer.pad_token_id).long() + return { + 'input_ids': input_ids, + 'seq_lens': seqlens, + 'attention_mask': attention_masks, + 'bs': len(batch_raw), + 'dataset_name': batch_raw[0]['dataset_name'] + } + +# Minimal loss helpers adapted from utils.py +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +def inbatch_loss(q, c, criterion, accelerator, temperature=0.05): + bs = q.size(0) + a_norm = F.normalize(q, p=2, dim=-1) + b_cross = accelerator.gather(c) + b_norm = F.normalize(b_cross, p=2, dim=-1) + logits = torch.matmul(a_norm, b_norm.t()) / temperature + labels = torch.arange(bs, device=logits.device) + bs * accelerator.process_index + loss_bs = criterion(logits, labels) + return loss_bs.mean() + +def hard_loss(q, c, negs, criterion, accelerator, temperature=0.05): + if negs is None: + return torch.tensor(0.0, device=q.device) + bs = q.size(0) + a = F.normalize(q, p=2, dim=-1) + hard = torch.concat([c.unsqueeze(1), negs], dim=1) + hard = F.normalize(hard, p=2, dim=-1) + logits = (a.unsqueeze(1) * hard).sum(-1) / temperature + return criterion(logits, torch.zeros((bs), dtype=torch.long, device=logits.device)).mean() + + +def main(): + accelerator = Accelerator() + tokenizer = DummyTokenizer() + model = DummyModel(tokenizer=tokenizer, device="cpu") + model.set_device() + + ds = SyntheticDataset(length=32, seq_len=8, vocab=100) + loader = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=lambda b: collate_fn(b, max_seq_length=16, tokenizer=tokenizer)) + loader = accelerator.prepare(loader) + + optimizer = torch.optim.SGD(model.lm.parameters(), lr=0.01) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) + criterion = CrossEntropyLoss(reduction='none') + + accumulation_steps = 4 + total_micro = len(loader) + expected_opt_steps = total_micro // 
accumulation_steps + completed = 0 + local_accum = 0 + + for batch in tqdm(loader, disable=not accelerator.is_local_main_process): + out = model.forward(batch) + loss_h = hard_loss(out['query_passage_features'].squeeze(1), out['passage_passage_features'].squeeze(1), out['negative_passage_features'], criterion, accelerator) + loss_ib = inbatch_loss(out['query_passage_features'].squeeze(1), out['passage_passage_features'].squeeze(1), criterion, accelerator) + loss = (loss_h + loss_ib) / accumulation_steps + accelerator.backward(loss) + local_accum += 1 + if local_accum % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + completed += 1 + if completed >= expected_opt_steps: + break + + print(f"Optimization steps: {completed} (expected {expected_opt_steps})") + assert completed == expected_opt_steps, "Accumulation did not match expected steps" + print("Smoke test passed.") + +if __name__ == "__main__": + main() diff --git a/F2LLM/test_gradient_accumulation.py b/F2LLM/test_gradient_accumulation.py new file mode 100644 index 0000000..8e646f0 --- /dev/null +++ b/F2LLM/test_gradient_accumulation.py @@ -0,0 +1,31 @@ +import torch +from torch import nn + + +def run_accumulation_test(accumulation_steps=4, micro_batches=12): + torch.manual_seed(0) + model = nn.Linear(10, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0) + + steps = 0 + optimizer.zero_grad() + for i in range(micro_batches): + x = torch.randn(8, 10) + y = torch.randn(8, 1) + out = model(x) + loss = nn.functional.mse_loss(out, y) + (loss / accumulation_steps).backward() + if (i + 1) % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + steps += 1 + return steps + + +if __name__ == "__main__": + s = run_accumulation_test(accumulation_steps=4, micro_batches=12) + print(f"Optimization steps: {s} (expected 3)") + assert s == 3, f"Expected 3 optimization steps, got {s}" + print("Gradient accumulation test passed.") diff --git a/F2LLM/utils.py b/F2LLM/utils.py index b167d3c..68cd30b 100644 --- a/F2LLM/utils.py +++ b/F2LLM/utils.py @@ -137,6 +137,8 @@ def accelerate_train(args, criterion = CrossEntropyLoss(reduction='none') pbar = tqdm(range(args.train_steps), disable=not accelerator.is_local_main_process) completed_steps = 0 + accumulation_steps = max(1, getattr(args, 'gradient_accumulation_steps', 1)) + local_accum_counter = 0 loss_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} loss_hard_dict = {ds_name: torch.tensor(0.0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} count_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in RETRIEVAL_DATASETS} @@ -164,19 +166,26 @@ def accelerate_train(args, loss = 0.0 loss_total = loss + loss_hard + # scale loss for gradient accumulation + loss_total = loss_total / accumulation_steps # backward, optimizer, scheduler accelerator.backward(loss_total) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if optimizer.param_groups[0]['lr'] < args.min_lr: - for i in range(len(optimizer.param_groups)): - optimizer.param_groups[i]['lr'] = args.min_lr - - # log - completed_steps += 1 - if completed_steps % args.log_interval == 0: + local_accum_counter += 1 + stepped = False + if local_accum_counter % accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + stepped = True + if 
optimizer.param_groups[0]['lr'] < args.min_lr: + for i in range(len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = args.min_lr + + # log only on optimization steps + if stepped: + completed_steps += 1 + if completed_steps % args.log_interval == 0 and completed_steps > 0: pbar.update(args.log_interval) train_log_dict = {"lr": optimizer.param_groups[0]['lr']} @@ -202,13 +211,13 @@ def accelerate_train(args, count_hard_dict = {ds_name: torch.tensor(0, device=model.lm.device) for ds_name in train_dataloader.loader_dict.keys()} # validation - if completed_steps % args.validation_steps == 0: + if completed_steps % args.validation_steps == 0 and completed_steps > 0: model.lm.eval() validate(args, accelerator, model, valid_loader_dict, criterion, completed_steps, summary_writer) model.lm.train() # step checkpoint - if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: + if args.checkpointing_steps and completed_steps > 0 and completed_steps % args.checkpointing_steps == 0: output_dir = os.path.join(args.output_dir, f"step_{completed_steps}") save_checkpoint(args, accelerator, model, output_dir, lr_scheduler) diff --git a/README.md b/README.md index 0a2dea2..015aa16 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,29 @@

+### Gradient Accumulation
+
+To train with larger effective batch sizes on limited GPU memory, we added gradient accumulation.
+
+- New config key: `gradient_accumulation_steps` (default: 1)
+- Effective global batch size: `train_batch_size * gradient_accumulation_steps * num_processes` (see the worked example below)
+- `train_steps` now counts optimization steps (after accumulation). When left at the default `-1`, it is computed as `micro_batches_per_epoch * train_epochs // gradient_accumulation_steps`.
+
+Usage:
+
+1. Set in your config JSON:
+   - `"gradient_accumulation_steps": 8`
+2. Run training as usual with `F2LLM/run.py`.
+
+Quick tests (no real data required):
+
+```bash
+python F2LLM/test_gradient_accumulation.py
+python F2LLM/smoke_test_accumulation.py
+```
+
+The first verifies optimizer step counts; the second runs a small synthetic training pipeline on CPU with accumulation enabled.
+
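+A worked example of the arithmetic (placeholder numbers; the micro-batch count depends on your tokenized data and `train_batch_size`):
+
+```python
+train_batch_size = 2
+gradient_accumulation_steps = 8
+num_processes = 4                  # e.g. 4 GPUs under accelerate
+micro_batches_per_epoch = 1000     # placeholder value
+train_epochs = 1
+
+effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes       # 64
+optimization_steps = micro_batches_per_epoch * train_epochs // gradient_accumulation_steps  # 125
+print(effective_batch_size, optimization_steps)
+```
+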

Embedding-related repos from CodeFuse, including: From 8e6b078d9a789f24f16a9d4a1cbc1d6d65180ed3 Mon Sep 17 00:00:00 2001 From: Brandon Ban Date: Sun, 14 Dec 2025 15:28:04 +0800 Subject: [PATCH 2/2] Add demo configuration for gradient accumulation --- F2LLM/configs/demo_accumulation.json | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 F2LLM/configs/demo_accumulation.json diff --git a/F2LLM/configs/demo_accumulation.json b/F2LLM/configs/demo_accumulation.json new file mode 100644 index 0000000..acdd51f --- /dev/null +++ b/F2LLM/configs/demo_accumulation.json @@ -0,0 +1,21 @@ +{ + "model_path": "bert-base-uncased", + "experiment_id": "demo-ga", + "output_dir": "output", + "tb_dir": "output/tb", + "cache_dir": "cache", + "train_data_path": "training_data/data_tokenized", + "train_batch_size": 2, + "max_seq_length": 128, + "learning_rate": 1e-4, + "min_lr": 1e-6, + "weight_decay": 1e-2, + "warmup_steps": 10, + "num_hard_neg": 1, + "train_steps": -1, + "train_epochs": 1, + "log_interval": 2, + "checkpointing_steps": 0, + "validation_steps": 1000000, + "gradient_accumulation_steps": 8 +}
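
A quick, illustrative check of the mechanism behind this change (not part of the patch): with a mean-reduced loss and equally sized micro-batches, scaling each micro-batch loss by `1 / gradient_accumulation_steps` before `backward()` reproduces the gradient of a single pass over the full batch, which is why `utils.py` divides `loss_total` by `accumulation_steps`. The sketch below uses placeholder shapes and a plain `nn.Linear`; it is not F2LLM code.

```python
import torch
from torch import nn

torch.manual_seed(0)
k, b = 4, 8                        # accumulation steps, micro-batch size
x = torch.randn(k * b, 10)
y = torch.randn(k * b, 1)
model = nn.Linear(10, 1)

# Gradient from one pass over the full batch.
model.zero_grad()
nn.functional.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over k micro-batches, each loss scaled by 1/k.
model.zero_grad()
for i in range(k):
    xb, yb = x[i * b:(i + 1) * b], y[i * b:(i + 1) * b]
    (nn.functional.mse_loss(model(xb), yb) / k).backward()

assert torch.allclose(full_grad, model.weight.grad, atol=1e-5)
print("accumulated gradient matches full-batch gradient")
```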