diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index d57d350..9f92b67 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -697,6 +697,60 @@ IRunState& LLamaModel::get_run_state() const { return *RunState; } +std::size_t LLamaModel::num_block_tensors() const { + return 7; +} + +void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, + ETensorDType matrix_dtype, ETensorDType other_dtype) const +{ + long C = config.HiddenSize; + long H = config.IntermediateSize; + long HS = config.head_size(); + + auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) { + tgt.Rank = cols != 0 ? 2 : 1; + tgt.DType = dtype; + tgt.Sizes[0] = rows; + tgt.Sizes[1] = cols; + }; + + long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS; + create(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::ATTO_W), C, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::UP_W), 2 * H, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::DOWN_W), C, H, matrix_dtype); + + create(target.get_tensor(LLamaWeightID::LN1_W), C, 0, other_dtype); + create(target.get_tensor(LLamaWeightID::LN2_W), C, 0, other_dtype); + create(target.get_tensor(LLamaWeightID::QKV_B), config.UseQKVBias ? attn_intermediate_size : 0, 0, other_dtype); +} + +std::size_t LLamaModel::num_non_block_tensors() const { + return 3; +} + +void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, + ETensorDType matrix_dtype, ETensorDType other_dtype) const +{ + long V = config.VocabSize; + long C = config.HiddenSize; + + auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) { + tgt.Rank = cols != 0 ? 
2 : 1; + tgt.DType = dtype; + tgt.Sizes[0] = rows; + tgt.Sizes[1] = cols; + }; + + create(target.get_tensor(LLamaWeightID::EMBEDDING), V, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::LNF_W), C, 0, other_dtype); + if(!config.TiedWordEmbeddings) { + create(target.get_tensor(LLamaWeightID::LM_HEAD), V, C, matrix_dtype); + } +} + + void LLamaModel::_calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream) { auto& rs = RunState; @@ -756,15 +810,15 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_ learning_rate, beta_1, beta_2, t, epsilon, wd, grad_scale, scales, val.abs_max(), rng(), main_stream); }; - auto& m_scales = OptimizerState->scales_m(); + auto& nb_scales = OptimizerState->non_block_m_scales(); using namespace LLamaWeightID; run_update(Parameters->get_master_embeddings(), Grads->get_embeddings_shard(main_stream), OptimizerState->non_block_m().get_tensor(EMBEDDING), OptimizerState->non_block_v().get_tensor(EMBEDDING), - m_scales.NonBlocks.Embeddings, weight_decay); + nb_scales.get_tensor(EMBEDDING), weight_decay); comm.reduce_max(Parameters->get_master_embeddings().abs_max()); run_update(Parameters->get_master_lnf_w(), Grads->get_lnf_w_shard(main_stream), - OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), m_scales.NonBlocks.LNF_w, 0.f); + OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), nb_scales.get_tensor(LNF_W), 0.f); comm.reduce_max(Parameters->get_master_lnf_w().abs_max()); CUDA_CHECK(cudaEventRecord(rs->OptEmbeddingsDone, main_stream)); @@ -806,7 +860,7 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_ if(!Config.TiedWordEmbeddings) { run_update(Parameters->get_master_lmhead(), Grads->get_lmhead_shard(main_stream), - OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), m_scales.NonBlocks.LMHead, weight_decay); + OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), nb_scales.get_tensor(LM_HEAD), weight_decay); comm.reduce_max(Parameters->get_master_lmhead().abs_max()); } comm.wait_on_comms(main_stream); @@ -834,7 +888,8 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato acts = ::allocate_run_state(Config, options, B, T, stack, Allocator); } - OptimizerState = std::make_unique(Config, options, acts.MainStream, comm, *Allocator); + OptimizerState = std::make_unique(Config, *this, options, comm); + OptimizerState->allocate_state(*this, acts.MainStream, options.offload_alloc(), *Allocator); Parameters->begin_optimizer(stack, comm.stream()); OptimizerState->begin_optimizer(stack, comm.stream()); @@ -862,16 +917,8 @@ ITensorContainer& LLamaModel::weights() { return *Parameters; } -ITensorContainer& LLamaModel::opt_momentum() { - return OptimizerState->full_m(); -} - -ITensorContainer& LLamaModel::opt_momentum_scales() { - return OptimizerState->scales_m(); -} - -ITensorContainer& LLamaModel::opt_variance() { - return OptimizerState->full_v(); +AdamWStateManager& LLamaModel::optimizer() { + return *OptimizerState; } std::vector LLamaModel::rng_state() const { diff --git a/src/models/llama_model.h b/src/models/llama_model.h index f1c8275..44293b8 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -104,9 +104,7 @@ class LLamaModel : public IModel { void calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip); ITensorContainer& 
weights() override; - ITensorContainer& opt_momentum() override; - ITensorContainer& opt_momentum_scales() override; - ITensorContainer& opt_variance() override; + AdamWStateManager& optimizer() override; std::vector rng_state() const override; void set_rng_state(const std::vector& state) override; std::string_view model_type() const override; @@ -119,6 +117,10 @@ class LLamaModel : public IModel { IRunState& get_run_state() const override; + std::size_t num_block_tensors() const override; + void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override; + std::size_t num_non_block_tensors() const override; + void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override; protected: void _calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream); void _reduce_loss(LLamaRunState& acts, NCCLCommunicator& comm, int B, int T); diff --git a/src/models/llama_optimizer.cpp b/src/models/llama_optimizer.cpp index df7b352..bcad5e8 100644 --- a/src/models/llama_optimizer.cpp +++ b/src/models/llama_optimizer.cpp @@ -4,156 +4,76 @@ #include "llama_optimizer.h" + +#include + #include "training/transformer_config.h" #include "llama_model.h" #include "utilities/comm.h" #include "kernels/kernels.h" -#include "utilities/stack.h" #include "utilities/lazy_allocator.h" - -SimpleTensorContainer& LLamaOptimizerStateManager::get_block_m(int layer_idx, cudaStream_t stream) { - if(!mOffloadM || mUseZeroCopy) return mOptM.Blocks[layer_idx]; - return get_block_from(layer_idx, stream, mOptMBuffer.at(layer_idx % 2)); -} - -SimpleTensorContainer& LLamaOptimizerStateManager::get_block_v(int layer_idx, cudaStream_t stream) { - if(!mOffloadV || mUseZeroCopy) return mOptV.Blocks[layer_idx]; - return get_block_from(layer_idx, stream, mOptVBuffer.at(layer_idx % 2)); -} - -SimpleTensorContainer& LLamaOptimizerStateManager::non_block_m() { - return mOptM.NonBlocks; -} - -SimpleTensorContainer& LLamaOptimizerStateManager::non_block_v() { - return mOptV.NonBlocks; -} - -void zero_opt_non_block(sLLamaWeights& weights, cudaStream_t stream) { - // here's the first disadvantage of having individual buffers: We need to make a ton of memset calls - fill_zero(weights.NonBlocks.Embeddings, stream); - fill_zero(weights.NonBlocks.LNF_w, stream); - if(weights.NonBlocks.LMHead.Data != weights.NonBlocks.Embeddings.Data) { - fill_zero(weights.NonBlocks.LMHead, stream); +#include "utilities/safetensors.h" + +struct OptStateWrapper : ITensorContainer { + void iterate_tensors(const std::function& callback) override; + std::vector* Blocks; + GenericTensorContainer* NonBlock; + OptStateWrapper() = default; + OptStateWrapper(std::vector* b, GenericTensorContainer* nb) : Blocks(b), NonBlock(nb) {}; +}; + +void OptStateWrapper::iterate_tensors(const std::function& callback) { + callback("model.embed_tokens.weight", NonBlock->get_tensor(LLamaWeightID::EMBEDDING)); + if(NonBlock->get_tensor(LLamaWeightID::LM_HEAD)) { + callback("lm_head.weight", NonBlock->get_tensor(LLamaWeightID::LM_HEAD)); } -} - -sLLamaWeights allocate_scales(TransformerConfig config, int shard_idx, int num_shards, TensorAllocator& alloc) { - long C = config.HiddenSize; - long V = config.VocabSize; - long H = config.IntermediateSize; - long head_size = C / config.NumQueryHeads; - long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * 
head_size; - - sLLamaWeights result; - result.Blocks.resize(config.NumLayers); - for(auto& block : result.Blocks) { - block.Attn_QKV_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_w", {div_exact(attn_intermediate_size * C, 128l)}, EAllocationType::ON_DEVICE); - block.Attn_Out_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "attproj_w", {div_exact(C * C, 128l)}, EAllocationType::ON_DEVICE); - block.MLP_Up_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_up_w", {div_exact(2 * H * C, 128l)}, EAllocationType::ON_DEVICE); - block.MLP_Down_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_down_w", {div_exact(C * H, 128l)}, EAllocationType::ON_DEVICE); - - block.LN1_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln1_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - block.LN2_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln2_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - if(config.UseQKVBias) { - block.Attn_QKV_b = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_b", {div_exact(attn_intermediate_size, 128l)}, EAllocationType::ON_DEVICE); - } else { - block.Attn_QKV_b = Tensor{}; + callback("model.norm.weight", NonBlock->get_tensor(LLamaWeightID::LNF_W)); + + for(int i = 0; i < Blocks->size(); i++) { + auto& layer = Blocks->at(i); + const Tensor& qkv_w = layer.get_tensor(LLamaWeightID::QKV_W); + const Tensor& up_proj = layer.get_tensor(LLamaWeightID::UP_W); + std::string prefix = "model.layers." + std::to_string(i); + callback(prefix + ".self_attn.qkv.weight", qkv_w); + if (layer.get_tensor(LLamaWeightID::QKV_B)) { + callback(prefix + ".self_attn.qkv.bias", layer.get_tensor(LLamaWeightID::QKV_B)); } - visit([](Tensor& t){ - fill_constant(t, 1.f, t.nelem(), nullptr); - }, block); - } - result.NonBlocks.Embeddings = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "embeddings", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE); - result.NonBlocks.LNF_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards,"lnf_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - fill_constant(result.NonBlocks.Embeddings, 1.f, result.NonBlocks.Embeddings.nelem(), nullptr); - fill_constant(result.NonBlocks.LNF_w, 1.f, result.NonBlocks.LNF_w.nelem(), nullptr); - if(config.TiedWordEmbeddings) { - result.NonBlocks.LMHead = result.NonBlocks.Embeddings; - } else { - result.NonBlocks.LMHead = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "lmhead", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE); - fill_constant(result.NonBlocks.LMHead, 1.f, result.NonBlocks.LMHead.nelem(), nullptr); - } - return result; -} - -std::vector allocate_weights_opt(sLLamaWeights& weights, const TransformerConfig& config, ETensorDType dtype, EAllocationType kind, int shard_idx, int num_shards, TensorAllocator& alloc) { - std::vector result; - weights.Blocks.resize(config.NumLayers); - LazyAllocator alloc_lazy; - for(auto& block : weights.Blocks) { - fill_matrix_shapes(block, config, dtype, shard_idx, num_shards); - fill_non_matrix_shapes(block, config, dtype, shard_idx, num_shards); - alloc_lazy.allocate(block); - result.push_back(alloc_lazy.commit(alloc, kind, "block_shard")); + callback(prefix + ".self_attn.o_proj.weight", layer.get_tensor(LLamaWeightID::ATTO_W)); + callback(prefix + ".mlp.up.weight", up_proj); + callback(prefix + ".mlp.down_proj.weight", layer.get_tensor(LLamaWeightID::DOWN_W)); + 
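+        // Note: the optimizer-state shards keep llmq's fused layouts rather than per-projection
+        // tensors: "self_attn.qkv" packs Q/K/V into (num_q_heads + 2*num_kv_heads)*head_size rows,
+        // and "mlp.up" packs gate and up into a single 2*IntermediateSize x HiddenSize tensor
+        // (see fill_block_shapes), so these key names are specific to these shard files.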
callback(prefix + ".input_layernorm.weight", layer.get_tensor(LLamaWeightID::LN1_W)); + callback(prefix + ".post_attention_layernorm.weight", layer.get_tensor(LLamaWeightID::LN2_W)); } - weights.NonBlocks = allocate_non_block_shard(config, dtype, kind, shard_idx, num_shards, alloc); - return result; } - -LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc): - AdamWStateManager(cfg, options.OffloadOptM, options.OffloadOptV, options.UseZeroCopy, comm.rank(), comm.world_size()) +LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm): + AdamWStateManager(cfg, model, options.OffloadOptM, options.OffloadOptV, options.OptMomentumType, options.OptVarianceType, options.UseZeroCopy, comm.rank(), comm.world_size()) { - mMType = options.OptMomentumType; - mVType = options.OptVarianceType; - - { - auto ctx = alloc.with_context("Adam M"); - EAllocationType alloc_type = options.OffloadOptM ? options.offload_alloc() : EAllocationType::ON_DEVICE; - mMBlockStorage = allocate_weights_opt(mOptM, cfg, mMType, alloc_type, comm.rank(), comm.world_size(), alloc); - for(auto& block : mMBlockStorage) { - fill_zero(block, stream); - } - zero_opt_non_block(mOptM, stream); - - if(mMType == ETensorDType::FP8_E4M3) { - mOptMScales = allocate_scales(cfg, comm.rank(), comm.world_size(), alloc); - } else { - mOptMScales.Blocks.resize(cfg.NumLayers); - } - } - - { - auto ctx = alloc.with_context("Adam V"); - EAllocationType alloc_type = options.OffloadOptV ? options.offload_alloc() : EAllocationType::ON_DEVICE; - mVBlockStorage = allocate_weights_opt(mOptV, cfg, mVType, alloc_type, comm.rank(), comm.world_size(), alloc); - for(auto& block : mVBlockStorage) { - fill_zero(block, stream); - } - zero_opt_non_block(mOptV, stream); - } - - if((options.OffloadOptM || options.OffloadOptV) && !mUseZeroCopy) { - mStatus[0] = sBufferStatus{-1, create_named_event("opt_fetch_0"), false, true}; - mStatus[1] = sBufferStatus{-1, create_named_event("opt_fetch_1"), false, true}; - } - - if(mOffloadM && !mUseZeroCopy) { - fill_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld); - fill_non_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld); - fill_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld); - fill_non_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld); - } +} - if(mOffloadV && !mUseZeroCopy) { - fill_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld); - fill_non_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld); - fill_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld); - fill_non_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld); +void LLamaOptimizerStateManager::safe_to_checkpoint(const std::string& checkpoint_dir) { + OptStateWrapper m_state{&mBlocksM, &mNonBlockM}; + OptStateWrapper v_state{&mBlocksV, &mNonBlockV}; + write_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state); + write_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state); + if (mMType == ETensorDType::FP8_E4M3) { + OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales}; + write_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales); } } -SimpleTensorContainer& 
LLamaOptimizerStateManager::get_block_scales_m(int layer_idx) { - return mOptMScales.Blocks.at(layer_idx); -} +void LLamaOptimizerStateManager::load_from_checkpoint(const std::string& checkpoint_dir) { + OptStateWrapper m_state{&mBlocksM, &mNonBlockM}; + OptStateWrapper v_state{&mBlocksV, &mNonBlockV}; -SimpleTensorContainer& LLamaOptimizerStateManager::get_m_buffer(int idx) { - return mOptMBuffer.at(idx); -} + // load optimizer shards + load_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state, false); + load_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state, false); -SimpleTensorContainer& LLamaOptimizerStateManager::get_v_buffer(int idx) { - return mOptVBuffer.at(idx); + if (mMType == ETensorDType::FP8_E4M3) { + OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales}; + load_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales, false); + } } diff --git a/src/models/llama_optimizer.h b/src/models/llama_optimizer.h index b83045f..d618626 100644 --- a/src/models/llama_optimizer.h +++ b/src/models/llama_optimizer.h @@ -8,34 +8,13 @@ #include "llama_weights.h" #include "training/adamw_optimizer.h" -#include class LLamaOptimizerStateManager : public AdamWStateManager { public: - LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc); - SimpleTensorContainer& non_block_m() override; - SimpleTensorContainer& non_block_v() override; + LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm); - ITensorContainer& full_m() { return mOptM; } - ITensorContainer& full_v() { return mOptV; } - sLLamaWeights& scales_m() { return mOptMScales; } - - SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream) override; - SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream) override; - SimpleTensorContainer& get_block_scales_m(int layer_idx) override; -private: - // mOptM.Blocks[i] and mMBlockStorage[i] alias the same memory. - // mOptM provides convenient access to the individual tensors of a block, whereas - // mMBlockStorage has just one large, byte-typed buffer for bulk transfers. 
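The comment being removed here still describes the invariant the new base class relies on: each per-layer container of tensor views aliases one flat storage tensor, so offload copies can move a whole block with a single memcpy while the update kernels still address individual tensors. A condensed sketch of how `AdamWStateManager::allocate_state` (later in this diff) sets that up; `model`, `cfg`, `dtype`, `kind`, `alloc`, and `world` stand in for the surrounding members and arguments:

```cpp
// Per-tensor views and one flat backing tensor alias the same allocation.
GenericTensorContainer views = shard_container(model.create_block_container(cfg, dtype, dtype), world);
LazyAllocator lazy;
lazy.allocate(views);                                        // record shapes, assign offsets into one buffer
Tensor storage = lazy.commit(alloc, kind, "m_block_shard");  // single allocation; the views now point into it
// bulk offload copies (store_block/fetch_block) use `storage`;
// the optimizer kernels index individual tensors through `views`.
```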
- sLLamaWeights mOptM; - sLLamaWeights mOptV; - sLLamaWeights mOptMScales; - - std::array, 2> mOptMBuffer; - std::array, 2> mOptVBuffer; - - SimpleTensorContainer& get_m_buffer(int idx) override; - SimpleTensorContainer& get_v_buffer(int idx) override; + void safe_to_checkpoint(const std::string& checkpoint_dir) override; + void load_from_checkpoint(const std::string& checkpoint_dir) override; }; #endif //LLMQ_SRC_MODELS_LLAMA_OPTIMIZER_H diff --git a/src/training/adamw_optimizer.cpp b/src/training/adamw_optimizer.cpp index 0c1ac8d..d391dbe 100644 --- a/src/training/adamw_optimizer.cpp +++ b/src/training/adamw_optimizer.cpp @@ -3,11 +3,42 @@ // #include "adamw_optimizer.h" + +#include "model.h" +#include "kernels/kernels.h" +#include "utilities/allocator.h" #include "utilities/utils.h" #include "utilities/tensor.h" #include "utilities/stack.h" #include "utilities/lazy_allocator.h" +static GenericTensorContainer shard_container(GenericTensorContainer&& c, int world) { + visit([world](Tensor& t) { + if (!t.empty()) { throw std::logic_error("shard_container called with non-empty tensor"); } + t.Sizes[0] = div_exact(t.Sizes[0], static_cast(world)); + }, c); + return c; +} + +AdamWStateManager::AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, + ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world): + mConfig(cfg), mOffloadM(offload_m), mOffloadV(offload_v), mUseZeroCopy(zero_copy), mRank(rank), mWorld(world), mMType(type_m), mVType(type_v) { + + if(mOffloadM && !mUseZeroCopy) { + mMDeviceBuffer[0] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + mMDeviceBuffer[1] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + } + + if(mOffloadV && !mUseZeroCopy) { + mVDeviceBuffer[0] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + mVDeviceBuffer[1] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + } + + if((mOffloadM || mOffloadV) && !mUseZeroCopy) { + mStatus[0] = sBufferStatus{-1, create_named_event("opt_fetch_0"), false, true}; + mStatus[1] = sBufferStatus{-1, create_named_event("opt_fetch_1"), false, true}; + } +} void AdamWStateManager::begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream) { LazyAllocator alloc; @@ -18,16 +49,16 @@ void AdamWStateManager::begin_optimizer(DeviceMemoryStack& memory, cudaStream_t } if(mOffloadM && !mUseZeroCopy) { - alloc.allocate(get_m_buffer(0)); + alloc.allocate(mMDeviceBuffer.at(0)); mMBufferStorage[0] = alloc.commit(memory, "opt_m_a"); - alloc.allocate(get_m_buffer(1)); + alloc.allocate(mMDeviceBuffer.at(1)); mMBufferStorage[1] = alloc.commit(memory, "opt_m_b"); } if(mOffloadV && !mUseZeroCopy) { - alloc.allocate(get_v_buffer(0)); + alloc.allocate(mVDeviceBuffer.at(0)); mVBufferStorage[0] = alloc.commit(memory, "opt_v_a"); - alloc.allocate(get_v_buffer(1)); + alloc.allocate(mVDeviceBuffer.at(1)); mVBufferStorage[1] = alloc.commit(memory, "opt_v_b"); } } @@ -64,14 +95,14 @@ void AdamWStateManager::fetch_block(int layer_idx, cudaStream_t fetch_stream) { if(mOffloadM) { auto& buf = mMBufferStorage.at(buffer); - auto& ref = mMBlockStorage.at(layer_idx); + auto& ref = mStorageM.at(layer_idx); fetch(buf, ref); } if(mOffloadV) { auto& buf = mVBufferStorage.at(buffer); - auto& ref = mVBlockStorage.at(layer_idx); + auto& ref = mStorageV.at(layer_idx); fetch(buf, ref); } @@ -79,6 +110,33 @@ void AdamWStateManager::fetch_block(int layer_idx, cudaStream_t 
fetch_stream) { CUDA_CHECK(cudaEventRecord(stat.DoneEvent, fetch_stream)); } +SimpleTensorContainer& AdamWStateManager::get_block_m(int layer_idx, cudaStream_t stream) { + if(!mOffloadM || mUseZeroCopy) return mBlocksM.at(layer_idx); + return get_block_from(layer_idx, stream, mMDeviceBuffer.at(layer_idx % 2)); +} + +SimpleTensorContainer& AdamWStateManager::get_block_v(int layer_idx, cudaStream_t stream) { + if(!mOffloadV || mUseZeroCopy) return mBlocksV.at(layer_idx); + return get_block_from(layer_idx, stream, mVDeviceBuffer.at(layer_idx % 2)); +} + +SimpleTensorContainer& AdamWStateManager::get_block_scales_m(int layer_idx) { + return mBlocksMScales.at(layer_idx); +} + +SimpleTensorContainer& AdamWStateManager::non_block_m() { + return mNonBlockM; +} + +SimpleTensorContainer& AdamWStateManager::non_block_m_scales() { + return mNonBlockMScales; +} + +SimpleTensorContainer& AdamWStateManager::non_block_v() { + return mNonBlockV; +} + + void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStream_t put_stream) { if (mUseZeroCopy) return; @@ -91,11 +149,11 @@ void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStre } if(mOffloadM) { - CUDA_CHECK(cudaMemcpyAsync(mMBlockStorage.at(layer_idx).Data, mMBufferStorage.at(buffer).Data, mMBlockStorage.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); + CUDA_CHECK(cudaMemcpyAsync(mStorageM.at(layer_idx).Data, mMBufferStorage.at(buffer).Data, mStorageM.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); } if(mOffloadV) { - CUDA_CHECK(cudaMemcpyAsync(mVBlockStorage.at(layer_idx).Data, mVBufferStorage.at(buffer).Data, mVBlockStorage.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); + CUDA_CHECK(cudaMemcpyAsync(mStorageV.at(layer_idx).Data, mVBufferStorage.at(buffer).Data, mStorageV.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); } if(mOffloadM || mOffloadV) { @@ -107,6 +165,68 @@ void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStre } } +void AdamWStateManager::allocate_state(IModel& model, cudaStream_t stream, EAllocationType kind, TensorAllocator& alloc) { + { + auto ctx = alloc.with_context("Adam M"); + LazyAllocator alloc_lazy; + mBlocksM.resize(mConfig.NumLayers); + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksM[i] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + alloc_lazy.allocate(mBlocksM[i]); + mStorageM.push_back(alloc_lazy.commit(alloc, mOffloadM ? kind : EAllocationType::ON_DEVICE, "m_block_shard")); + } + mNonBlockM = shard_container(model.create_non_block_container(mConfig, mMType, mMType), mWorld); + alloc_lazy.allocate(mNonBlockM); + mStorageM.push_back(alloc_lazy.commit(alloc, mOffloadM ? kind : EAllocationType::ON_DEVICE, "m_nonblock_shard")); + + for (auto& t : mStorageM) { + fill_zero(t, stream); + } + + mBlocksMScales.resize(mConfig.NumLayers); + if(mMType == ETensorDType::FP8_E4M3) { + // we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights. 
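+            // i.e. the scale containers divide Sizes[0] by 128*world instead of world,
+            // giving one FP32 scale per 128 FP8_E4M3 momentum entries in this rank's shard.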
+ for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksMScales[i] = shard_container(model.create_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld); + alloc_lazy.allocate(mBlocksMScales[i]); + alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_block_scales"); + visit([stream](Tensor& t){ + fill_constant(t, 1.f, t.nelem(), stream); + }, mBlocksMScales[i]); + } + mNonBlockMScales = shard_container(model.create_non_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld); + alloc_lazy.allocate(mNonBlockMScales); + alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_nonblock_scales"); + visit([stream](Tensor& t){ + fill_constant(t, 1.f, t.nelem(), stream); + }, mNonBlockMScales); + } else { + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksMScales[i] = GenericTensorContainer(std::vector(model.num_block_tensors())); + } + mNonBlockMScales = GenericTensorContainer(std::vector(model.num_non_block_tensors())); + } + } + + { + auto ctx = alloc.with_context("Adam V"); + LazyAllocator alloc_lazy; + mBlocksV.resize(mConfig.NumLayers); + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksV[i] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + alloc_lazy.allocate(mBlocksV[i]); + mStorageV.push_back(alloc_lazy.commit(alloc, mOffloadV ? kind : EAllocationType::ON_DEVICE, "v_block_shard")); + } + mNonBlockV = shard_container(model.create_non_block_container(mConfig, mVType, mVType), mWorld); + alloc_lazy.allocate(mNonBlockV); + mStorageV.push_back(alloc_lazy.commit(alloc, mOffloadV ? kind : EAllocationType::ON_DEVICE, "v_nonblock_shard")); + + for (auto& t : mStorageV) { + fill_zero(t, stream); + } + } +} + SimpleTensorContainer& AdamWStateManager::get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer &buf) { int buffer = layer_idx % 2; auto& stat = mStatus.at(buffer); diff --git a/src/training/adamw_optimizer.h b/src/training/adamw_optimizer.h index c936fdb..9535c87 100644 --- a/src/training/adamw_optimizer.h +++ b/src/training/adamw_optimizer.h @@ -10,33 +10,37 @@ #include "utilities/tensor_container.h" #include "utilities/tensor.h" +enum class EAllocationType : int; +class IModel; +class TensorAllocator; typedef struct CUstream_st *cudaStream_t; class DeviceMemoryStack; class AdamWStateManager { public: - AdamWStateManager(TransformerConfig cfg, bool offload_m, bool offload_v, bool zero_copy, int rank, int world) : - mConfig(cfg), mOffloadM(offload_m), mOffloadV(offload_v), mUseZeroCopy(zero_copy), mRank(rank), mWorld(world) {} + AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world); virtual ~AdamWStateManager() = default; - virtual void begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream); - virtual void end_optimizer(DeviceMemoryStack& memory); + void begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream); + void end_optimizer(DeviceMemoryStack& memory); void fetch_block(int layer_idx, cudaStream_t fetch_stream); - virtual SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream) = 0; - virtual SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream) = 0; - virtual SimpleTensorContainer& get_block_scales_m(int layer_idx) = 0; + SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream); + SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream); + SimpleTensorContainer& 
get_block_scales_m(int layer_idx); void store_block(int layer_idx, cudaStream_t stream, cudaStream_t put_stream); - virtual SimpleTensorContainer& non_block_m() = 0; - virtual SimpleTensorContainer& non_block_v() = 0; + SimpleTensorContainer& non_block_m(); + SimpleTensorContainer& non_block_m_scales(); + SimpleTensorContainer& non_block_v(); -protected: - SimpleTensorContainer& get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer& buf); + void allocate_state(IModel& model, cudaStream_t stream, EAllocationType kind, TensorAllocator& alloc); - virtual SimpleTensorContainer& get_m_buffer(int idx) = 0; - virtual SimpleTensorContainer& get_v_buffer(int idx) = 0; + virtual void safe_to_checkpoint(const std::string& checkpoint_dir) = 0; + virtual void load_from_checkpoint(const std::string& checkpoint_dir) = 0; +protected: + SimpleTensorContainer& get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer& buf); TransformerConfig mConfig; bool mOffloadM; @@ -56,12 +60,20 @@ class AdamWStateManager { bool Done = true; }; - std::vector mMBlockStorage; - std::vector mVBlockStorage; + std::vector mStorageM; + std::vector mStorageV; + std::vector mBlocksM; + std::vector mBlocksV; + std::vector mBlocksMScales; + GenericTensorContainer mNonBlockM; + GenericTensorContainer mNonBlockMScales; + GenericTensorContainer mNonBlockV; std::array mMBufferStorage; std::array mVBufferStorage; std::array mStatus; + std::array mMDeviceBuffer; + std::array mVDeviceBuffer; }; #endif //LLMQ_ADAMW_OPTIMIZER_H diff --git a/src/training/checkpoint.cpp b/src/training/checkpoint.cpp index 2d740a5..ec37e7f 100644 --- a/src/training/checkpoint.cpp +++ b/src/training/checkpoint.cpp @@ -9,6 +9,7 @@ #include #include +#include "adamw_optimizer.h" #include "dataloader.h" #include "model.h" #include "utilities/comm.h" @@ -31,20 +32,7 @@ std::string save_checkpoint(std::string target, int step, IModel& model, const D // weights // TODO don't duplicate weights if they are unsharded write_safetensors(target + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.weights()); - - // sharded optimizer state - write_safetensors(target + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum()); - write_safetensors(target + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_variance()); - - bool has_scales = false; - model.opt_momentum_scales().iterate_tensors([&has_scales](const std::string& name, const TensorShard& tensor){ - if(tensor.Data != nullptr) { - has_scales = true; - } - }); - if(has_scales) { - write_safetensors(target + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum_scales()); - } + model.optimizer().safe_to_checkpoint(target); comm.barrier(); // only write checkpoint.json once we know all the shard files are saved @@ -126,21 +114,7 @@ void load_checkpoint(std::string source, int step, IModel& model, DataLoader* lo // weights load_safetensors(source + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.weights(), false); - - // load optimizer shards - load_safetensors(source + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum(), false); - load_safetensors(source + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_variance(), false); - - 
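With the optimizer now owning its own (de)serialization, the checkpoint code reduces to the delegating calls below (a condensed sketch of the save/load paths as changed in this diff; `model`, `comm`, `target`, and `source` stand in for the surrounding locals):

```cpp
// Save: each rank writes its own weight shard, then the optimizer writes its
// adam.m / adam.v (and adam.m.scales for FP8 momentum) shards for the same rank.
write_safetensors(target + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors",
                                       comm.rank(), comm.world_size()), model.weights());
model.optimizer().safe_to_checkpoint(target);
comm.barrier();   // checkpoint.json is only written once every shard is on disk

// Load: mirror of the above, followed by the model's restore hook.
load_safetensors(source + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors",
                                      comm.rank(), comm.world_size()), model.weights(), false);
model.optimizer().load_from_checkpoint(source);
model.on_restore_checkpoint(comm);
```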
bool has_scales = false; - model.opt_momentum_scales().iterate_tensors([&has_scales](const std::string& name, const TensorShard& tensor){ - if(tensor.Data != nullptr) { - has_scales = true; - } - }); - if(has_scales) { - load_safetensors(source + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum_scales(), false); - } - + model.optimizer().load_from_checkpoint(source); model.on_restore_checkpoint(comm); } diff --git a/src/training/model.cpp b/src/training/model.cpp index 795e1e3..e00b1d4 100644 --- a/src/training/model.cpp +++ b/src/training/model.cpp @@ -8,6 +8,7 @@ #include "transformer_config.h" #include "utilities/allocator.h" +#include "utilities/tensor_container.h" cudnnHandle_t create_cudnn_handle(); cublasLtHandle_t create_cublaslt_handle(); @@ -27,6 +28,23 @@ Tensor& IModel::get_target_buffer() { return get_run_state().Targets_CPU; } +GenericTensorContainer IModel::create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, + ETensorDType other_dtype) const +{ + std::vector tensors(num_block_tensors()); + GenericTensorContainer container(std::move(tensors)); + fill_block_shapes(container, config, matrix_dtype, other_dtype); + return container; +} + +GenericTensorContainer IModel::create_non_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, + ETensorDType other_dtype) const +{ + std::vector tensors(num_non_block_tensors()); + GenericTensorContainer container(std::move(tensors)); + fill_non_block_shapes(container, config, matrix_dtype, other_dtype); + return container; +} IRunState::IRunState(TransformerConfig config, long batch_size, long seq_len, std::shared_ptr alloc) : Config(config), B(batch_size), T(seq_len), Allocator(std::move(alloc)) { int did; diff --git a/src/training/model.h b/src/training/model.h index 6548fb9..2362495 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -14,9 +14,11 @@ #include "utilities/tensor.h" #include "training/transformer_config.h" +class AdamWStateManager; class ITensorContainer; class NCCLCommunicator; class TensorAllocator; +class GenericTensorContainer; class DataLoader; typedef struct cudnnContext* cudnnHandle_t; @@ -68,13 +70,7 @@ class IModel { virtual ITensorContainer& weights() = 0; //! (First order) momentum. Sharded. - virtual ITensorContainer& opt_momentum() = 0; - - //! (First order) momentum. Sharded. - virtual ITensorContainer& opt_momentum_scales() = 0; - - //! Second order moments. Sharded. - virtual ITensorContainer& opt_variance() = 0; + virtual AdamWStateManager& optimizer() = 0; //! Get the current RNG state virtual std::vector rng_state() const = 0; @@ -102,6 +98,17 @@ class IModel { //! Get a const reference to the model's RunState. 
virtual IRunState& get_run_state() const = 0; + + // generic model param utilities + virtual std::size_t num_block_tensors() const = 0; + virtual void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0; + + virtual std::size_t num_non_block_tensors() const = 0; + virtual void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0; + + GenericTensorContainer create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const; + GenericTensorContainer create_non_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const; + protected: ~IModel() = default; }; diff --git a/src/utilities/tensor.cpp b/src/utilities/tensor.cpp index 409b0a2..b4cf80c 100644 --- a/src/utilities/tensor.cpp +++ b/src/utilities/tensor.cpp @@ -142,3 +142,14 @@ void visit(const std::function& func, SimpleTensorContai } } } + +GenericTensorContainer::GenericTensorContainer(std::vector t): mTensors( std::move(t) ) { +} + +std::size_t GenericTensorContainer::num_tensors() const noexcept { + return mTensors.size(); +} + +const Tensor& GenericTensorContainer::get_tensor(std::size_t idx) const { + return mTensors.at(idx); +} diff --git a/src/utilities/tensor_container.h b/src/utilities/tensor_container.h index c29d02d..7d5924e 100644 --- a/src/utilities/tensor_container.h +++ b/src/utilities/tensor_container.h @@ -40,6 +40,24 @@ void visit(const std::function& func, SimpleTensorContainer& cont //! in both containers. void visit(const std::function& func, SimpleTensorContainer& a, SimpleTensorContainer& b); +//! \brief `SimpleTensorContainer` that stores all tensors in a vector +class GenericTensorContainer final : public SimpleTensorContainer { +public: + GenericTensorContainer() = default; + explicit GenericTensorContainer(std::vector t); + + //! Get the total number of tensors in this container. This count includes empty tensors. + std::size_t num_tensors() const noexcept override; + + //! Return a constant reference to the tensor at the given index. + const Tensor& get_tensor(std::size_t idx) const override; + + using SimpleTensorContainer::get_tensor; +private: + std::vector mTensors; +}; + + class ITensorContainer { public: virtual void iterate_tensors(const std::function& callback) = 0;
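For reference, a minimal sketch of how the new generic shape-description path is meant to be driven (the function, dtype choices, and loop are illustrative and not part of this diff; only the types and member functions introduced above are assumed):

```cpp
#include "training/model.h"              // IModel, TransformerConfig
#include "utilities/tensor_container.h"  // GenericTensorContainer

// Build an (unallocated) description of one transformer block's optimizer tensors.
// fill_block_shapes only sets Rank/DType/Sizes; as the allocate_state path shows,
// memory is attached later via a LazyAllocator commit.
void describe_block(const IModel& model, const TransformerConfig& cfg) {
    GenericTensorContainer block = model.create_block_container(
        cfg, /*matrix_dtype=*/ETensorDType::FP8_E4M3, /*other_dtype=*/ETensorDType::FP32);
    for (std::size_t i = 0; i < block.num_tensors(); ++i) {
        const Tensor& t = block.get_tensor(i);  // shapes/dtypes only; the QKV bias slot has zero rows
        (void)t;                                // when UseQKVBias is off
    }
}
```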