77 changes: 62 additions & 15 deletions src/models/llama_model.cpp
@@ -697,6 +697,60 @@ IRunState& LLamaModel::get_run_state() const {
return *RunState;
}

std::size_t LLamaModel::num_block_tensors() const {
return 7;
}

void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config,
ETensorDType matrix_dtype, ETensorDType other_dtype) const
{
long C = config.HiddenSize;
long H = config.IntermediateSize;
long HS = config.head_size();

auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) {
tgt.Rank = cols != 0 ? 2 : 1;
tgt.DType = dtype;
tgt.Sizes[0] = rows;
tgt.Sizes[1] = cols;
};

long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS;
create(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C, matrix_dtype);
create(target.get_tensor(LLamaWeightID::ATTO_W), C, C, matrix_dtype);
create(target.get_tensor(LLamaWeightID::UP_W), 2 * H, C, matrix_dtype);
create(target.get_tensor(LLamaWeightID::DOWN_W), C, H, matrix_dtype);

create(target.get_tensor(LLamaWeightID::LN1_W), C, 0, other_dtype);
create(target.get_tensor(LLamaWeightID::LN2_W), C, 0, other_dtype);
create(target.get_tensor(LLamaWeightID::QKV_B), config.UseQKVBias ? attn_intermediate_size : 0, 0, other_dtype);
}

std::size_t LLamaModel::num_non_block_tensors() const {
return 3;
}

void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config,
ETensorDType matrix_dtype, ETensorDType other_dtype) const
{
long V = config.VocabSize;
long C = config.HiddenSize;

auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) {
tgt.Rank = cols != 0 ? 2 : 1;
tgt.DType = dtype;
tgt.Sizes[0] = rows;
tgt.Sizes[1] = cols;
};

create(target.get_tensor(LLamaWeightID::EMBEDDING), V, C, matrix_dtype);
create(target.get_tensor(LLamaWeightID::LNF_W), C, 0, other_dtype);
if(!config.TiedWordEmbeddings) {
create(target.get_tensor(LLamaWeightID::LM_HEAD), V, C, matrix_dtype);
}
}
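
A minimal sketch of how a caller might consume these shape descriptions, using stand-in definitions for the tensor type and dtype widths (assumptions for illustration; the real Tensor and GenericTensorContainer live elsewhere in the repository). It only shows that the Rank/Sizes/DType fields filled above are enough to size a per-block allocation.

#include <array>
#include <cstddef>

// Stand-in types for the sketch; not the repository's definitions.
enum class ETensorDType { FP32, BF16, FP8_E4M3 };

struct Tensor {
    int Rank = 0;
    ETensorDType DType = ETensorDType::FP32;
    std::array<long, 2> Sizes{0, 0};
    long nelem() const { return Rank == 2 ? Sizes[0] * Sizes[1] : Sizes[0]; }
};

// Assumed element widths, for the sketch only.
inline std::size_t dtype_bytes(ETensorDType d) {
    switch (d) {
        case ETensorDType::FP32: return 4;
        case ETensorDType::BF16: return 2;
        default:                 return 1;   // FP8 variants
    }
}

// Bytes needed for one transformer block, given the num_block_tensors() tensors
// whose shapes were filled by fill_block_shapes().
std::size_t block_state_bytes(const Tensor* tensors, std::size_t count) {
    std::size_t total = 0;
    for (std::size_t i = 0; i < count; ++i) {
        total += static_cast<std::size_t>(tensors[i].nelem()) * dtype_bytes(tensors[i].DType);
    }
    return total;
}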


void LLamaModel::_calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream) {
auto& rs = RunState;

@@ -756,15 +810,15 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_
learning_rate, beta_1, beta_2, t, epsilon, wd, grad_scale, scales, val.abs_max(), rng(), main_stream);
};

auto& m_scales = OptimizerState->scales_m();
auto& nb_scales = OptimizerState->non_block_m_scales();

using namespace LLamaWeightID;
run_update(Parameters->get_master_embeddings(), Grads->get_embeddings_shard(main_stream),
OptimizerState->non_block_m().get_tensor(EMBEDDING), OptimizerState->non_block_v().get_tensor(EMBEDDING),
m_scales.NonBlocks.Embeddings, weight_decay);
nb_scales.get_tensor(EMBEDDING), weight_decay);
comm.reduce_max(Parameters->get_master_embeddings().abs_max());
run_update(Parameters->get_master_lnf_w(), Grads->get_lnf_w_shard(main_stream),
OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), m_scales.NonBlocks.LNF_w, 0.f);
OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), nb_scales.get_tensor(LNF_W), 0.f);
comm.reduce_max(Parameters->get_master_lnf_w().abs_max());
CUDA_CHECK(cudaEventRecord(rs->OptEmbeddingsDone, main_stream));

@@ -806,7 +860,7 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_

if(!Config.TiedWordEmbeddings) {
run_update(Parameters->get_master_lmhead(), Grads->get_lmhead_shard(main_stream),
OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), m_scales.NonBlocks.LMHead, weight_decay);
OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), nb_scales.get_tensor(LM_HEAD), weight_decay);
comm.reduce_max(Parameters->get_master_lmhead().abs_max());
}
comm.wait_on_comms(main_stream);
@@ -834,7 +888,8 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato
acts = ::allocate_run_state(Config, options, B, T, stack, Allocator);
}

OptimizerState = std::make_unique<LLamaOptimizerStateManager>(Config, options, acts.MainStream, comm, *Allocator);
OptimizerState = std::make_unique<LLamaOptimizerStateManager>(Config, *this, options, comm);
OptimizerState->allocate_state(*this, acts.MainStream, options.offload_alloc(), *Allocator);

Parameters->begin_optimizer(stack, comm.stream());
OptimizerState->begin_optimizer(stack, comm.stream());
@@ -862,16 +917,8 @@ ITensorContainer& LLamaModel::weights() {
return *Parameters;
}

ITensorContainer& LLamaModel::opt_momentum() {
return OptimizerState->full_m();
}

ITensorContainer& LLamaModel::opt_momentum_scales() {
return OptimizerState->scales_m();
}

ITensorContainer& LLamaModel::opt_variance() {
return OptimizerState->full_v();
AdamWStateManager& LLamaModel::optimizer() {
return *OptimizerState;
}

std::vector<std::byte> LLamaModel::rng_state() const {
8 changes: 5 additions & 3 deletions src/models/llama_model.h
@@ -104,9 +104,7 @@ class LLamaModel : public IModel {
void calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip);

ITensorContainer& weights() override;
ITensorContainer& opt_momentum() override;
ITensorContainer& opt_momentum_scales() override;
ITensorContainer& opt_variance() override;
AdamWStateManager& optimizer() override;
std::vector<std::byte> rng_state() const override;
void set_rng_state(const std::vector<std::byte>& state) override;
std::string_view model_type() const override;
@@ -119,6 +117,10 @@

IRunState& get_run_state() const override;

std::size_t num_block_tensors() const override;
void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override;
std::size_t num_non_block_tensors() const override;
void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override;
protected:
void _calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream);
void _reduce_loss(LLamaRunState& acts, NCCLCommunicator& comm, int B, int T);
186 changes: 53 additions & 133 deletions src/models/llama_optimizer.cpp
@@ -4,156 +4,76 @@


#include "llama_optimizer.h"

#include <fmt/format.h>

#include "training/transformer_config.h"
#include "llama_model.h"
#include "utilities/comm.h"
#include "kernels/kernels.h"
#include "utilities/stack.h"
#include "utilities/lazy_allocator.h"

SimpleTensorContainer& LLamaOptimizerStateManager::get_block_m(int layer_idx, cudaStream_t stream) {
if(!mOffloadM || mUseZeroCopy) return mOptM.Blocks[layer_idx];
return get_block_from(layer_idx, stream, mOptMBuffer.at(layer_idx % 2));
}

SimpleTensorContainer& LLamaOptimizerStateManager::get_block_v(int layer_idx, cudaStream_t stream) {
if(!mOffloadV || mUseZeroCopy) return mOptV.Blocks[layer_idx];
return get_block_from(layer_idx, stream, mOptVBuffer.at(layer_idx % 2));
}

SimpleTensorContainer& LLamaOptimizerStateManager::non_block_m() {
return mOptM.NonBlocks;
}

SimpleTensorContainer& LLamaOptimizerStateManager::non_block_v() {
return mOptV.NonBlocks;
}

void zero_opt_non_block(sLLamaWeights& weights, cudaStream_t stream) {
// here's the first disadvantage of having individual buffers: We need to make a ton of memset calls
fill_zero(weights.NonBlocks.Embeddings, stream);
fill_zero(weights.NonBlocks.LNF_w, stream);
if(weights.NonBlocks.LMHead.Data != weights.NonBlocks.Embeddings.Data) {
fill_zero(weights.NonBlocks.LMHead, stream);
#include "utilities/safetensors.h"

struct OptStateWrapper : ITensorContainer {
void iterate_tensors(const std::function<void(std::string, const TensorShard&)>& callback) override;
std::vector<GenericTensorContainer>* Blocks;
GenericTensorContainer* NonBlock;
OptStateWrapper() = default;
OptStateWrapper(std::vector<GenericTensorContainer>* b, GenericTensorContainer* nb) : Blocks(b), NonBlock(nb) {};
};

void OptStateWrapper::iterate_tensors(const std::function<void(std::string, const TensorShard&)>& callback) {
callback("model.embed_tokens.weight", NonBlock->get_tensor(LLamaWeightID::EMBEDDING));
if(NonBlock->get_tensor(LLamaWeightID::LM_HEAD)) {
callback("lm_head.weight", NonBlock->get_tensor(LLamaWeightID::LM_HEAD));
}
}

sLLamaWeights allocate_scales(TransformerConfig config, int shard_idx, int num_shards, TensorAllocator& alloc) {
long C = config.HiddenSize;
long V = config.VocabSize;
long H = config.IntermediateSize;
long head_size = C / config.NumQueryHeads;
long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * head_size;

sLLamaWeights result;
result.Blocks.resize(config.NumLayers);
for(auto& block : result.Blocks) {
block.Attn_QKV_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_w", {div_exact(attn_intermediate_size * C, 128l)}, EAllocationType::ON_DEVICE);
block.Attn_Out_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "attproj_w", {div_exact(C * C, 128l)}, EAllocationType::ON_DEVICE);
block.MLP_Up_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_up_w", {div_exact(2 * H * C, 128l)}, EAllocationType::ON_DEVICE);
block.MLP_Down_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_down_w", {div_exact(C * H, 128l)}, EAllocationType::ON_DEVICE);

block.LN1_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln1_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE);
block.LN2_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln2_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE);
if(config.UseQKVBias) {
block.Attn_QKV_b = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_b", {div_exact(attn_intermediate_size, 128l)}, EAllocationType::ON_DEVICE);
} else {
block.Attn_QKV_b = Tensor{};
callback("model.norm.weight", NonBlock->get_tensor(LLamaWeightID::LNF_W));

for(int i = 0; i < Blocks->size(); i++) {
auto& layer = Blocks->at(i);
const Tensor& qkv_w = layer.get_tensor(LLamaWeightID::QKV_W);
const Tensor& up_proj = layer.get_tensor(LLamaWeightID::UP_W);
std::string prefix = "model.layers." + std::to_string(i);
callback(prefix + ".self_attn.qkv.weight", qkv_w);
if (layer.get_tensor(LLamaWeightID::QKV_B)) {
callback(prefix + ".self_attn.qkv.bias", layer.get_tensor(LLamaWeightID::QKV_B));
}

visit([](Tensor& t){
fill_constant(t, 1.f, t.nelem(), nullptr);
}, block);
}
result.NonBlocks.Embeddings = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "embeddings", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE);
result.NonBlocks.LNF_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards,"lnf_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE);
fill_constant(result.NonBlocks.Embeddings, 1.f, result.NonBlocks.Embeddings.nelem(), nullptr);
fill_constant(result.NonBlocks.LNF_w, 1.f, result.NonBlocks.LNF_w.nelem(), nullptr);
if(config.TiedWordEmbeddings) {
result.NonBlocks.LMHead = result.NonBlocks.Embeddings;
} else {
result.NonBlocks.LMHead = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "lmhead", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE);
fill_constant(result.NonBlocks.LMHead, 1.f, result.NonBlocks.LMHead.nelem(), nullptr);
}
return result;
}

std::vector<Tensor> allocate_weights_opt(sLLamaWeights& weights, const TransformerConfig& config, ETensorDType dtype, EAllocationType kind, int shard_idx, int num_shards, TensorAllocator& alloc) {
std::vector<Tensor> result;
weights.Blocks.resize(config.NumLayers);
LazyAllocator alloc_lazy;
for(auto& block : weights.Blocks) {
fill_matrix_shapes(block, config, dtype, shard_idx, num_shards);
fill_non_matrix_shapes(block, config, dtype, shard_idx, num_shards);
alloc_lazy.allocate(block);
result.push_back(alloc_lazy.commit(alloc, kind, "block_shard"));
callback(prefix + ".self_attn.o_proj.weight", layer.get_tensor(LLamaWeightID::ATTO_W));
callback(prefix + ".mlp.up.weight", up_proj);
callback(prefix + ".mlp.down_proj.weight", layer.get_tensor(LLamaWeightID::DOWN_W));
callback(prefix + ".input_layernorm.weight", layer.get_tensor(LLamaWeightID::LN1_W));
callback(prefix + ".post_attention_layernorm.weight", layer.get_tensor(LLamaWeightID::LN2_W));
}
weights.NonBlocks = allocate_non_block_shard(config, dtype, kind, shard_idx, num_shards, alloc);
return result;
}


LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc):
AdamWStateManager(cfg, options.OffloadOptM, options.OffloadOptV, options.UseZeroCopy, comm.rank(), comm.world_size())
LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm):
AdamWStateManager(cfg, model, options.OffloadOptM, options.OffloadOptV, options.OptMomentumType, options.OptVarianceType, options.UseZeroCopy, comm.rank(), comm.world_size())
{
mMType = options.OptMomentumType;
mVType = options.OptVarianceType;

{
auto ctx = alloc.with_context("Adam M");
EAllocationType alloc_type = options.OffloadOptM ? options.offload_alloc() : EAllocationType::ON_DEVICE;
mMBlockStorage = allocate_weights_opt(mOptM, cfg, mMType, alloc_type, comm.rank(), comm.world_size(), alloc);
for(auto& block : mMBlockStorage) {
fill_zero(block, stream);
}
zero_opt_non_block(mOptM, stream);

if(mMType == ETensorDType::FP8_E4M3) {
mOptMScales = allocate_scales(cfg, comm.rank(), comm.world_size(), alloc);
} else {
mOptMScales.Blocks.resize(cfg.NumLayers);
}
}

{
auto ctx = alloc.with_context("Adam V");
EAllocationType alloc_type = options.OffloadOptV ? options.offload_alloc() : EAllocationType::ON_DEVICE;
mVBlockStorage = allocate_weights_opt(mOptV, cfg, mVType, alloc_type, comm.rank(), comm.world_size(), alloc);
for(auto& block : mVBlockStorage) {
fill_zero(block, stream);
}
zero_opt_non_block(mOptV, stream);
}

if((options.OffloadOptM || options.OffloadOptV) && !mUseZeroCopy) {
mStatus[0] = sBufferStatus{-1, create_named_event("opt_fetch_0"), false, true};
mStatus[1] = sBufferStatus{-1, create_named_event("opt_fetch_1"), false, true};
}

if(mOffloadM && !mUseZeroCopy) {
fill_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld);
fill_non_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld);
fill_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld);
fill_non_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld);
}
}

if(mOffloadV && !mUseZeroCopy) {
fill_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld);
fill_non_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld);
fill_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld);
fill_non_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld);
void LLamaOptimizerStateManager::safe_to_checkpoint(const std::string& checkpoint_dir) {
OptStateWrapper m_state{&mBlocksM, &mNonBlockM};
OptStateWrapper v_state{&mBlocksV, &mNonBlockV};

write_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state);
write_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state);
if (mMType == ETensorDType::FP8_E4M3) {
OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales};
write_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales);
}
}

SimpleTensorContainer& LLamaOptimizerStateManager::get_block_scales_m(int layer_idx) {
return mOptMScales.Blocks.at(layer_idx);
}
void LLamaOptimizerStateManager::load_from_checkpoint(const std::string& checkpoint_dir) {
OptStateWrapper m_state{&mBlocksM, &mNonBlockM};
OptStateWrapper v_state{&mBlocksV, &mNonBlockV};

SimpleTensorContainer& LLamaOptimizerStateManager::get_m_buffer(int idx) {
return mOptMBuffer.at(idx);
}
// load optimizer shards
load_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state, false);
load_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state, false);

SimpleTensorContainer& LLamaOptimizerStateManager::get_v_buffer(int idx) {
return mOptVBuffer.at(idx);
if (mMType == ETensorDType::FP8_E4M3) {
OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales};
load_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales, false);
}
}
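
For reference, the shard file names produced by the fmt::format calls above can be reproduced in isolation; the rank and world-size values in the example are assumptions, not values from this PR.

#include <fmt/format.h>
#include <string>

// Mirrors the naming scheme used by safe_to_checkpoint()/load_from_checkpoint();
// "{:03}" zero-pads the shard indices to three digits.
std::string shard_name(const std::string& kind, int rank, int world) {
    return fmt::format("adam.{}.shard_{:03}_of_{:03}.safetensors", kind, rank, world);
}

// Example output (rank 2 of 8, assumed for illustration):
//   shard_name("m", 2, 8)        -> "adam.m.shard_002_of_008.safetensors"
//   shard_name("v", 2, 8)        -> "adam.v.shard_002_of_008.safetensors"
//   shard_name("m.scales", 2, 8) -> "adam.m.scales.shard_002_of_008.safetensors"  (FP8 momentum only)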
27 changes: 3 additions & 24 deletions src/models/llama_optimizer.h
@@ -8,34 +8,13 @@

#include "llama_weights.h"
#include "training/adamw_optimizer.h"
#include <array>

class LLamaOptimizerStateManager : public AdamWStateManager {
public:
LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc);
SimpleTensorContainer& non_block_m() override;
SimpleTensorContainer& non_block_v() override;
LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm);

ITensorContainer& full_m() { return mOptM; }
ITensorContainer& full_v() { return mOptV; }
sLLamaWeights& scales_m() { return mOptMScales; }

SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream) override;
SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream) override;
SimpleTensorContainer& get_block_scales_m(int layer_idx) override;
private:
// mOptM.Blocks[i] and mMBlockStorage[i] alias the same memory.
// mOptM provides convenient access to the individual tensors of a block, whereas
// mMBlockStorage has just one large, byte-typed buffer for bulk transfers.
sLLamaWeights mOptM;
sLLamaWeights mOptV;
sLLamaWeights mOptMScales;

std::array<sLLamaBlockWeights<TensorShard>, 2> mOptMBuffer;
std::array<sLLamaBlockWeights<TensorShard>, 2> mOptVBuffer;

SimpleTensorContainer& get_m_buffer(int idx) override;
SimpleTensorContainer& get_v_buffer(int idx) override;
void safe_to_checkpoint(const std::string& checkpoint_dir) override;
void load_from_checkpoint(const std::string& checkpoint_dir) override;
};

#endif //LLMQ_SRC_MODELS_LLAMA_OPTIMIZER_H
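
A sketch of what the consolidated interface looks like from a call site. The checkpoint driver, its arguments, and the include paths are assumptions for illustration; the optimizer() accessor and the two checkpoint hooks are the ones declared in this PR, assuming they are exposed on the AdamWStateManager base that the overrides here implement.

#include <string>
#include "llama_model.h"                  // assumed include path for IModel / LLamaModel
#include "training/adamw_optimizer.h"     // assumed include path for AdamWStateManager

// Call sites that previously reached through opt_momentum()/opt_momentum_scales()/opt_variance()
// now go through a single accessor and let the state manager own its serialization.
void checkpoint_optimizer(IModel& model, const std::string& checkpoint_dir, bool resume) {
    AdamWStateManager& opt = model.optimizer();
    if (resume) {
        opt.load_from_checkpoint(checkpoint_dir);   // reads the per-rank adam.m / adam.v shards
    } else {
        opt.safe_to_checkpoint(checkpoint_dir);     // writes them (plus FP8 m-scales when applicable)
    }
}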