diff --git a/README.md b/README.md
index b462662..1192fc1 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ python examples/data/shakespeare.py
 And finally, let's train a GPT:
 
 ```bash
-python examples/train-GPT.py
+python examples/train-gpt.py
 ```
 
 This runs on CPU and should get train loss: 1.65 and test loss: 1.80 after 2000 iterations.
diff --git a/examples/train-gpt.py b/examples/train-gpt.py
index a05587e..ea327f9 100644
--- a/examples/train-gpt.py
+++ b/examples/train-gpt.py
@@ -1,3 +1,4 @@
+import time
 import torch
 import numpy as np
 
@@ -11,14 +12,66 @@
 d_value = 32
 num_blocks = 4
 
+# Llama-7b-like values, excluding the vocabulary size.
+vocab_size = 256
+context = 1024
+num_heads = 32
+d_embed = 4096
+d_query = 128
+d_value = 128
+num_blocks = 4
+
+GPU_16BIT_FLOPS = {
+    "h100-sxm": 1.979e15 / 2,
+    "h100-pcie": 1.513e15 / 2,
+    "a100": 312e12,
+    "v100-sxm": 125e12,
+    "6000A": 364.25e12,
+    "4090": 165.2 * 10**12,
+    "3090": 71 * 10**12,
+    "t4": 65e12,
+}
+def xf_layer_fwd_flops(slen: int, bs: int=1, causal=True) -> int:
+    p_mlp = d_embed * 4 * d_embed * 2
+    f_mlp = p_mlp * 2 * slen
+
+    assert d_query == d_value, "Dq != Dv not implemented"
+    p_att = 4 * d_embed * d_embed
+    f_att = p_att * 2 * slen
+    f_sdpa = 4 * slen * slen * d_embed // (2 if causal else 1) # approximation
+
+    return (f_mlp + f_att + f_sdpa) * bs
+
+def gpt_train_flops(slen: int, bs: int, causal=True) -> int:
+    # lmhead layer:
+    flops = 6 * slen * bs * d_embed * vocab_size
+    # assume no activation checkpointing
+    flops += num_blocks * xf_layer_fwd_flops(slen, bs, causal) * 3
+    return flops
+
+class SpeedLogger:
+    def __init__(self, ideal_flops_per_sec: float):
+        self.tps = []
+        self.mfu = []
+        self.fps = ideal_flops_per_sec
+
+    def add(self, slen: int, bs: int, duration: float) -> tuple[float,float]:
+        flops = gpt_train_flops(slen, bs)
+        self.tps.append(slen*bs / duration)
+        self.mfu.append(flops / duration / self.fps)
+        return self.tps[-1], self.mfu[-1]
+
+    def ave(self):
+        return sum(self.tps) / len(self.tps), sum(self.mfu) / len(self.mfu)
+
 # training hparams
 init_lr = 0.5
 wd = 0.01
-batch_size = 12
+batch_size = 2 # 12
 steps = 2001
 eval_steps = 100
-log_interval = 200
+log_interval = 10 # 200
 
 # let's start by defining our GPT architecture
 # (we could instead just import GPT from modula.compound)
@@ -80,8 +133,9 @@ def __len__(self):
 
 # now let's start doing stuff
 
-if __name__ == "__main__":
+@torch.cuda.amp.autocast(dtype=torch.bfloat16)
+def train(device, ideal_flops_per_sec):
 
     # load the data
 
     trainset = SimpleLLMDataset(np.memmap("examples/data/shakespeare/train.bin", dtype=np.uint16, mode='r'), context)
@@ -96,12 +150,18 @@ def __len__(self):
     train_iterator = iter(train_loader)
     test_iterator = iter(test_loader)
 
-    getBatch = lambda train: next(train_iterator if train else test_iterator)
+    def getBatch(train: bool) -> list:
+        res = next(train_iterator if train else test_iterator)
+        return [t.to(device=device) for t in res]
 
     # load the model
 
    gpt = GPT(vocab_size, context, num_heads, d_embed, d_query, d_value, num_blocks)
-    weights = gpt.initialize(device="cpu")
+    weights = gpt.initialize(device=device)
+    gpt.forward = torch.compile(gpt.forward)
+    # gpt.normalize = torch.compile(gpt.normalize)
+    # gpt.regularize = torch.compile(gpt.regularize)
+    # init_lr_t = torch.tensor(init_lr, device=device)
 
     # initialize the Adam state
 
@@ -114,6 +174,8 @@ def __len__(self):
 
     # train the model
 
+    speed_logger = SpeedLogger(ideal_flops_per_sec)
+
     for step in range(steps):
 
         if step % log_interval == 0:
@@ -131,6 +193,7 @@ def __len__(self):
             test_loss /= eval_steps
             test_acc /= eval_steps
 
+        t0 = time.time()
         data, target = getBatch(train = True)
         output = gpt.forward(data, weights)
         output = output.view(-1, output.size(-1))
@@ -159,7 +222,26 @@ def __len__(self):
         gpt.regularize(weights, strength = init_lr * schedule * wd)
         weights.zero_grad()
 
-        if step % log_interval == 0:
-            print( "step:", step,
-                    "\t train loss:", "%.2f" % train_loss.item(),
-                    "\t test loss:", "%.2f" % test_loss.item() )
+        # avoid first compile && first recompile
+        if step > 1:
+            speed_logger.add(*data.shape, time.time() - t0)
+
+        if step > 1 and step % log_interval == 0:
+            tps, mfu = speed_logger.ave()
+            print(
+                "step:", step,
+                "\t train loss:", "%.2f" % train_loss.item(),
+                "\t test loss:", "%.2f" % test_loss.item(),
+                f"\t tokens/gpu/sec: {tps:.2f}",
+                f"\t MFU: {mfu*100:.2f}%",
+            )
+
+
+if __name__ == "__main__":
+    import argparse
+    ap = argparse.ArgumentParser()
+    ap.add_argument('--cuda', action='store_true')
+    args = ap.parse_args()
+
+    torch.set_float32_matmul_precision("medium")
+    train('cuda' if args.cuda else 'cpu', GPU_16BIT_FLOPS['3090'])
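
For readers who want to sanity-check the MFU numbers without running training, here is a minimal standalone sketch of the arithmetic that `SpeedLogger.add` performs, using the Llama-7b-like dimensions and the `"3090"` entry of `GPU_16BIT_FLOPS` from the diff. The sequence length, batch size, and 0.35 s step duration below are made-up example values, not measurements from this change.

```python
# Hypothetical inputs: slen/bs/dt are illustrative, not measured.
slen, bs, dt = 1024, 2, 0.35          # sequence length, batch size, step seconds
d_embed, vocab_size, num_blocks = 4096, 256, 4

# Forward FLOPs for one transformer layer, mirroring xf_layer_fwd_flops:
# attention projections (Q, K, V, O), the 4x MLP, and an approximate causal SDPA term.
per_layer_fwd = (4 * d_embed**2 * 2 * slen
                 + 8 * d_embed**2 * 2 * slen
                 + 4 * slen**2 * d_embed // 2)

# Training FLOPs, mirroring gpt_train_flops: lm head (fwd + bwd) plus
# num_blocks layers at roughly 3x the forward cost (no activation checkpointing).
train_flops = 6 * slen * bs * d_embed * vocab_size + num_blocks * per_layer_fwd * bs * 3

tokens_per_sec = slen * bs / dt
mfu = train_flops / dt / 71e12        # 71e12 = the "3090" bf16 peak from GPU_16BIT_FLOPS
print(f"tokens/sec: {tokens_per_sec:.0f}   MFU: {mfu * 100:.2f}%")
```

With the diff applied, the same arithmetic runs inside the training loop; `python examples/train-gpt.py --cuda` selects the CUDA path, while omitting the flag keeps the original CPU behaviour.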