From 7bee7d795716e8c18a19b1bd26f5989de0aab07c Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Fri, 9 Jan 2026 11:56:28 +0100
Subject: [PATCH 1/4] add a generic TensorContainer implementation

---
 src/utilities/tensor_container.h | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/utilities/tensor_container.h b/src/utilities/tensor_container.h
index c29d02d..69c78d8 100644
--- a/src/utilities/tensor_container.h
+++ b/src/utilities/tensor_container.h
@@ -40,6 +40,23 @@ void visit(const std::function& func, SimpleTensorContainer& cont
 //! in both containers.
 void visit(const std::function& func, SimpleTensorContainer& a, SimpleTensorContainer& b);
 
+//! \brief `SimpleTensorContainer` that stores all tensors in a vector
+class GenericTensorContainer final : public SimpleTensorContainer {
+public:
+    GenericTensorContainer(std::vector t) : mTensors( std::move(t) ) { };
+
+    //! Get the total number of tensors in this container. This count includes empty tensors.
+    std::size_t num_tensors() const noexcept { return mTensors.size(); };
+
+    //! Return a constant reference to the tensor at the given index.
+    const Tensor& get_tensor(std::size_t idx) const { return mTensors.at(idx); }
+
+    using SimpleTensorContainer::get_tensor;
+private:
+    std::vector mTensors;
+};
+
+
 class ITensorContainer {
 public:
     virtual void iterate_tensors(const std::function& callback) = 0;

From ec2bcfac62dc391e0e96c3045fd4fc829c3b92b4 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Fri, 9 Jan 2026 14:15:19 +0100
Subject: [PATCH 2/4] move buffer allocation to generic optimizer

---
 src/models/llama_model.cpp       | 37 ++++++++++++++++++++++++++++++-
 src/models/llama_model.h         |  3 +++
 src/models/llama_optimizer.cpp   | 35 ++---------------------------
 src/models/llama_optimizer.h     |  8 +------
 src/training/adamw_optimizer.cpp | 38 ++++++++++++++++++++++++++++----
 src/training/adamw_optimizer.h   |  9 ++++----
 src/training/model.cpp           |  8 +++++++
 src/training/model.h             |  8 +++++++
 src/utilities/tensor_container.h |  1 +
 9 files changed, 97 insertions(+), 50 deletions(-)

diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp
index d57d350..bccbc64 100644
--- a/src/models/llama_model.cpp
+++ b/src/models/llama_model.cpp
@@ -697,6 +697,41 @@ IRunState& LLamaModel::get_run_state() const {
     return *RunState;
 }
 
+std::size_t LLamaModel::num_block_tensors() const {
+    return 7;
+}
+
+void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config,
+                                   ETensorDType matrix_dtype, ETensorDType other_dtype) const {
+    long C = config.HiddenSize;
+    long H = config.IntermediateSize;
+    long HS = config.head_size();
+
+    auto create_matrix_shard = [&](Tensor& tgt, long rows, long cols) {
+        tgt.Rank = 2;
+        tgt.DType = matrix_dtype;
+        tgt.Sizes[0] = rows;
+        tgt.Sizes[1] = cols;
+    };
+
+    auto create_vector_shard = [&](Tensor& tgt, long elems) {
+        tgt.Rank = 1;
+        tgt.DType = other_dtype;
+        tgt.Sizes[0] = elems;
+    };
+
+    long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS;
+    create_matrix_shard(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C);
+    create_matrix_shard(target.get_tensor(LLamaWeightID::ATTO_W), C, C);
+    create_matrix_shard(target.get_tensor(LLamaWeightID::UP_W), 2 * H, C);
+    create_matrix_shard(target.get_tensor(LLamaWeightID::DOWN_W), C, H);
+
+    create_vector_shard(target.get_tensor(LLamaWeightID::LN1_W), C);
+    create_vector_shard(target.get_tensor(LLamaWeightID::LN2_W), C);
create_vector_shard(target.get_tensor(LLamaWeightID::QKV_B), config.UseQKVBias ? attn_intermediate_size : 0); +} + + void LLamaModel::_calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream) { auto& rs = RunState; @@ -834,7 +869,7 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato acts = ::allocate_run_state(Config, options, B, T, stack, Allocator); } - OptimizerState = std::make_unique(Config, options, acts.MainStream, comm, *Allocator); + OptimizerState = std::make_unique(Config, *this, options, acts.MainStream, comm, *Allocator); Parameters->begin_optimizer(stack, comm.stream()); OptimizerState->begin_optimizer(stack, comm.stream()); diff --git a/src/models/llama_model.h b/src/models/llama_model.h index f1c8275..73c9b32 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -119,6 +119,9 @@ class LLamaModel : public IModel { IRunState& get_run_state() const override; + std::size_t num_block_tensors() const override; + void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override; + protected: void _calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream); void _reduce_loss(LLamaRunState& acts, NCCLCommunicator& comm, int B, int T); diff --git a/src/models/llama_optimizer.cpp b/src/models/llama_optimizer.cpp index df7b352..6aa16c6 100644 --- a/src/models/llama_optimizer.cpp +++ b/src/models/llama_optimizer.cpp @@ -93,12 +93,9 @@ std::vector allocate_weights_opt(sLLamaWeights& weights, const Transform } -LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc): - AdamWStateManager(cfg, options.OffloadOptM, options.OffloadOptV, options.UseZeroCopy, comm.rank(), comm.world_size()) +LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc): + AdamWStateManager(cfg, model, options.OffloadOptM, options.OffloadOptV, options.OptMomentumType, options.OptVarianceType, options.UseZeroCopy, comm.rank(), comm.world_size()) { - mMType = options.OptMomentumType; - mVType = options.OptVarianceType; - { auto ctx = alloc.with_context("Adam M"); EAllocationType alloc_type = options.OffloadOptM ? 
options.offload_alloc() : EAllocationType::ON_DEVICE; @@ -124,36 +121,8 @@ LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, LL } zero_opt_non_block(mOptV, stream); } - - if((options.OffloadOptM || options.OffloadOptV) && !mUseZeroCopy) { - mStatus[0] = sBufferStatus{-1, create_named_event("opt_fetch_0"), false, true}; - mStatus[1] = sBufferStatus{-1, create_named_event("opt_fetch_1"), false, true}; - } - - if(mOffloadM && !mUseZeroCopy) { - fill_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld); - fill_non_matrix_shapes(mOptMBuffer[0], mConfig, mMType, mRank, mWorld); - fill_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld); - fill_non_matrix_shapes(mOptMBuffer[1], mConfig, mMType, mRank, mWorld); - } - - if(mOffloadV && !mUseZeroCopy) { - fill_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld); - fill_non_matrix_shapes(mOptVBuffer[0], mConfig, mVType, mRank, mWorld); - fill_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld); - fill_non_matrix_shapes(mOptVBuffer[1], mConfig, mVType, mRank, mWorld); - - } } SimpleTensorContainer& LLamaOptimizerStateManager::get_block_scales_m(int layer_idx) { return mOptMScales.Blocks.at(layer_idx); } - -SimpleTensorContainer& LLamaOptimizerStateManager::get_m_buffer(int idx) { - return mOptMBuffer.at(idx); -} - -SimpleTensorContainer& LLamaOptimizerStateManager::get_v_buffer(int idx) { - return mOptVBuffer.at(idx); -} diff --git a/src/models/llama_optimizer.h b/src/models/llama_optimizer.h index b83045f..24ab065 100644 --- a/src/models/llama_optimizer.h +++ b/src/models/llama_optimizer.h @@ -12,7 +12,7 @@ class LLamaOptimizerStateManager : public AdamWStateManager { public: - LLamaOptimizerStateManager(TransformerConfig cfg, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc); + LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc); SimpleTensorContainer& non_block_m() override; SimpleTensorContainer& non_block_v() override; @@ -30,12 +30,6 @@ class LLamaOptimizerStateManager : public AdamWStateManager { sLLamaWeights mOptM; sLLamaWeights mOptV; sLLamaWeights mOptMScales; - - std::array, 2> mOptMBuffer; - std::array, 2> mOptVBuffer; - - SimpleTensorContainer& get_m_buffer(int idx) override; - SimpleTensorContainer& get_v_buffer(int idx) override; }; #endif //LLMQ_SRC_MODELS_LLAMA_OPTIMIZER_H diff --git a/src/training/adamw_optimizer.cpp b/src/training/adamw_optimizer.cpp index 0c1ac8d..0bf7855 100644 --- a/src/training/adamw_optimizer.cpp +++ b/src/training/adamw_optimizer.cpp @@ -3,11 +3,41 @@ // #include "adamw_optimizer.h" + +#include "model.h" #include "utilities/utils.h" #include "utilities/tensor.h" #include "utilities/stack.h" #include "utilities/lazy_allocator.h" +static GenericTensorContainer& shard_container(GenericTensorContainer&& c, int world) { + visit([world](Tensor& t) { + if (!t.empty()) { throw std::logic_error("shard_container called with non-empty tensor"); } + t.Sizes[0] = div_exact(t.Sizes[0], static_cast(world)); + }, c); + return c; +} + + +AdamWStateManager::AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, + ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world): + mConfig(cfg), mOffloadM(offload_m), mOffloadV(offload_v), mUseZeroCopy(zero_copy), mRank(rank), mWorld(world), mMType(type_m), mVType(type_v) { + + if(mOffloadM && !mUseZeroCopy) { + 
mOptMBuffer[0] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + mOptMBuffer[1] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + } + + if(mOffloadV && !mUseZeroCopy) { + mOptVBuffer[0] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + mOptVBuffer[1] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + } + + if((mOffloadM || mOffloadV) && !mUseZeroCopy) { + mStatus[0] = sBufferStatus{-1, create_named_event("opt_fetch_0"), false, true}; + mStatus[1] = sBufferStatus{-1, create_named_event("opt_fetch_1"), false, true}; + } +} void AdamWStateManager::begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream) { LazyAllocator alloc; @@ -18,16 +48,16 @@ void AdamWStateManager::begin_optimizer(DeviceMemoryStack& memory, cudaStream_t } if(mOffloadM && !mUseZeroCopy) { - alloc.allocate(get_m_buffer(0)); + alloc.allocate(mOptMBuffer.at(0)); mMBufferStorage[0] = alloc.commit(memory, "opt_m_a"); - alloc.allocate(get_m_buffer(1)); + alloc.allocate(mOptMBuffer.at(1)); mMBufferStorage[1] = alloc.commit(memory, "opt_m_b"); } if(mOffloadV && !mUseZeroCopy) { - alloc.allocate(get_v_buffer(0)); + alloc.allocate(mOptVBuffer.at(0)); mVBufferStorage[0] = alloc.commit(memory, "opt_v_a"); - alloc.allocate(get_v_buffer(1)); + alloc.allocate(mOptVBuffer.at(1)); mVBufferStorage[1] = alloc.commit(memory, "opt_v_b"); } } diff --git a/src/training/adamw_optimizer.h b/src/training/adamw_optimizer.h index c936fdb..3fa7247 100644 --- a/src/training/adamw_optimizer.h +++ b/src/training/adamw_optimizer.h @@ -10,14 +10,14 @@ #include "utilities/tensor_container.h" #include "utilities/tensor.h" +class IModel; typedef struct CUstream_st *cudaStream_t; class DeviceMemoryStack; class AdamWStateManager { public: - AdamWStateManager(TransformerConfig cfg, bool offload_m, bool offload_v, bool zero_copy, int rank, int world) : - mConfig(cfg), mOffloadM(offload_m), mOffloadV(offload_v), mUseZeroCopy(zero_copy), mRank(rank), mWorld(world) {} + AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world); virtual ~AdamWStateManager() = default; virtual void begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream); virtual void end_optimizer(DeviceMemoryStack& memory); @@ -34,9 +34,6 @@ class AdamWStateManager { protected: SimpleTensorContainer& get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer& buf); - virtual SimpleTensorContainer& get_m_buffer(int idx) = 0; - virtual SimpleTensorContainer& get_v_buffer(int idx) = 0; - TransformerConfig mConfig; bool mOffloadM; @@ -62,6 +59,8 @@ class AdamWStateManager { std::array mMBufferStorage; std::array mVBufferStorage; std::array mStatus; + std::array mOptMBuffer; + std::array mOptVBuffer; }; #endif //LLMQ_ADAMW_OPTIMIZER_H diff --git a/src/training/model.cpp b/src/training/model.cpp index 795e1e3..974f09d 100644 --- a/src/training/model.cpp +++ b/src/training/model.cpp @@ -8,6 +8,7 @@ #include "transformer_config.h" #include "utilities/allocator.h" +#include "utilities/tensor_container.h" cudnnHandle_t create_cudnn_handle(); cublasLtHandle_t create_cublaslt_handle(); @@ -27,6 +28,13 @@ Tensor& IModel::get_target_buffer() { return get_run_state().Targets_CPU; } +GenericTensorContainer IModel::create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, + ETensorDType other_dtype) const { + 
std::vector tensors(num_block_tensors()); + GenericTensorContainer container(std::move(tensors)); + fill_block_shapes(container, config, matrix_dtype, other_dtype); + return container; +} IRunState::IRunState(TransformerConfig config, long batch_size, long seq_len, std::shared_ptr alloc) : Config(config), B(batch_size), T(seq_len), Allocator(std::move(alloc)) { int did; diff --git a/src/training/model.h b/src/training/model.h index 6548fb9..2f83ed5 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -17,6 +17,7 @@ class ITensorContainer; class NCCLCommunicator; class TensorAllocator; +class GenericTensorContainer; class DataLoader; typedef struct cudnnContext* cudnnHandle_t; @@ -102,6 +103,13 @@ class IModel { //! Get a const reference to the model's RunState. virtual IRunState& get_run_state() const = 0; + + // generic model param utilities + virtual std::size_t num_block_tensors() const = 0; + virtual void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0; + + GenericTensorContainer create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const; + protected: ~IModel() = default; }; diff --git a/src/utilities/tensor_container.h b/src/utilities/tensor_container.h index 69c78d8..a650ab5 100644 --- a/src/utilities/tensor_container.h +++ b/src/utilities/tensor_container.h @@ -43,6 +43,7 @@ void visit(const std::function& func, SimpleTensorContai //! \brief `SimpleTensorContainer` that stores all tensors in a vector class GenericTensorContainer final : public SimpleTensorContainer { public: + GenericTensorContainer() = default; GenericTensorContainer(std::vector t) : mTensors( std::move(t) ) { }; //! Get the total number of tensors in this container. This count includes empty tensors. From 36bc12fcb6d02975f66db3849d469a7c46b0d4a3 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 9 Jan 2026 15:09:17 +0100 Subject: [PATCH 3/4] also handle non-block weights --- src/models/llama_model.cpp | 49 ++++++++++++++++++++++++++++---------- src/models/llama_model.h | 3 ++- src/training/model.cpp | 12 +++++++++- src/training/model.h | 4 ++++ 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index bccbc64..f20f2d0 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -702,14 +702,15 @@ std::size_t LLamaModel::num_block_tensors() const { } void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, - ETensorDType matrix_dtype, ETensorDType other_dtype) const { + ETensorDType matrix_dtype, ETensorDType other_dtype) const +{ long C = config.HiddenSize; long H = config.IntermediateSize; long HS = config.head_size(); - auto create_matrix_shard = [&](Tensor& tgt, long rows, long cols) { - tgt.Rank = 2; - tgt.DType = matrix_dtype; + auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) { + tgt.Rank = cols != 0 ? 
2 : 1; + tgt.DType = dtype; tgt.Sizes[0] = rows; tgt.Sizes[1] = cols; }; @@ -721,14 +722,38 @@ void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const Transfo }; long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS; - create_matrix_shard(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C); - create_matrix_shard(target.get_tensor(LLamaWeightID::ATTO_W), C, C); - create_matrix_shard(target.get_tensor(LLamaWeightID::UP_W), 2 * H, C); - create_matrix_shard(target.get_tensor(LLamaWeightID::DOWN_W), C, H); - - create_vector_shard(target.get_tensor(LLamaWeightID::LN1_W), C); - create_vector_shard(target.get_tensor(LLamaWeightID::LN2_W), C); - create_vector_shard(target.get_tensor(LLamaWeightID::QKV_B), config.UseQKVBias ? attn_intermediate_size : 0); + create(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::ATTO_W), C, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::UP_W), 2 * H, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::DOWN_W), C, H, matrix_dtype); + + create(target.get_tensor(LLamaWeightID::LN1_W), C, 0, other_dtype); + create(target.get_tensor(LLamaWeightID::LN2_W), C, 0, other_dtype); + create(target.get_tensor(LLamaWeightID::QKV_B), config.UseQKVBias ? attn_intermediate_size : 0, 0, other_dtype); +} + +std::size_t LLamaModel::num_non_block_tensors() const { + return 3; +} + +void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, + ETensorDType matrix_dtype, ETensorDType other_dtype) const +{ + long V = config.VocabSize; + long C = config.HiddenSize; + + auto create = [&](Tensor& tgt, long rows, long cols, ETensorDType dtype) { + tgt.Rank = cols != 0 ? 
2 : 1; + tgt.DType = dtype; + tgt.Sizes[0] = rows; + tgt.Sizes[1] = cols; + }; + + create(target.get_tensor(LLamaWeightID::EMBEDDING), V, C, matrix_dtype); + create(target.get_tensor(LLamaWeightID::LNF_W), C, 0, other_dtype); + if(!config.TiedWordEmbeddings) { + create(target.get_tensor(LLamaWeightID::LM_HEAD), V, C, matrix_dtype); + } } diff --git a/src/models/llama_model.h b/src/models/llama_model.h index 73c9b32..06205de 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -121,7 +121,8 @@ class LLamaModel : public IModel { std::size_t num_block_tensors() const override; void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override; - + std::size_t num_non_block_tensors() const override; + void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const override; protected: void _calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip, cudaStream_t stream); void _reduce_loss(LLamaRunState& acts, NCCLCommunicator& comm, int B, int T); diff --git a/src/training/model.cpp b/src/training/model.cpp index 974f09d..e00b1d4 100644 --- a/src/training/model.cpp +++ b/src/training/model.cpp @@ -29,13 +29,23 @@ Tensor& IModel::get_target_buffer() { } GenericTensorContainer IModel::create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, - ETensorDType other_dtype) const { + ETensorDType other_dtype) const +{ std::vector tensors(num_block_tensors()); GenericTensorContainer container(std::move(tensors)); fill_block_shapes(container, config, matrix_dtype, other_dtype); return container; } +GenericTensorContainer IModel::create_non_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, + ETensorDType other_dtype) const +{ + std::vector tensors(num_non_block_tensors()); + GenericTensorContainer container(std::move(tensors)); + fill_non_block_shapes(container, config, matrix_dtype, other_dtype); + return container; +} + IRunState::IRunState(TransformerConfig config, long batch_size, long seq_len, std::shared_ptr alloc) : Config(config), B(batch_size), T(seq_len), Allocator(std::move(alloc)) { int did; CUDA_CHECK(cudaGetDevice(&did)); diff --git a/src/training/model.h b/src/training/model.h index 2f83ed5..ed390ef 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -108,7 +108,11 @@ class IModel { virtual std::size_t num_block_tensors() const = 0; virtual void fill_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0; + virtual std::size_t num_non_block_tensors() const = 0; + virtual void fill_non_block_shapes(GenericTensorContainer& target, const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const = 0; + GenericTensorContainer create_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const; + GenericTensorContainer create_non_block_container(const TransformerConfig& config, ETensorDType matrix_dtype, ETensorDType other_dtype) const; protected: ~IModel() = default; From 99ed1e16567272b51c06bd26c7f4735c70b359f5 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 9 Jan 2026 17:31:35 +0100 Subject: [PATCH 4/4] more optimizer generalization --- src/models/llama_model.cpp | 29 ++---- src/models/llama_model.h | 4 +- src/models/llama_optimizer.cpp | 157 
+++++++++++-------------------- src/models/llama_optimizer.h | 21 +---- src/training/adamw_optimizer.cpp | 118 ++++++++++++++++++++--- src/training/adamw_optimizer.h | 37 +++++--- src/training/checkpoint.cpp | 32 +------ src/training/model.h | 9 +- src/utilities/tensor.cpp | 11 +++ src/utilities/tensor_container.h | 6 +- 10 files changed, 214 insertions(+), 210 deletions(-) diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index f20f2d0..9f92b67 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -715,12 +715,6 @@ void LLamaModel::fill_block_shapes(GenericTensorContainer& target, const Transfo tgt.Sizes[1] = cols; }; - auto create_vector_shard = [&](Tensor& tgt, long elems) { - tgt.Rank = 1; - tgt.DType = other_dtype; - tgt.Sizes[0] = elems; - }; - long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS; create(target.get_tensor(LLamaWeightID::QKV_W), attn_intermediate_size, C, matrix_dtype); create(target.get_tensor(LLamaWeightID::ATTO_W), C, C, matrix_dtype); @@ -816,15 +810,15 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_ learning_rate, beta_1, beta_2, t, epsilon, wd, grad_scale, scales, val.abs_max(), rng(), main_stream); }; - auto& m_scales = OptimizerState->scales_m(); + auto& nb_scales = OptimizerState->non_block_m_scales(); using namespace LLamaWeightID; run_update(Parameters->get_master_embeddings(), Grads->get_embeddings_shard(main_stream), OptimizerState->non_block_m().get_tensor(EMBEDDING), OptimizerState->non_block_v().get_tensor(EMBEDDING), - m_scales.NonBlocks.Embeddings, weight_decay); + nb_scales.get_tensor(EMBEDDING), weight_decay); comm.reduce_max(Parameters->get_master_embeddings().abs_max()); run_update(Parameters->get_master_lnf_w(), Grads->get_lnf_w_shard(main_stream), - OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), m_scales.NonBlocks.LNF_w, 0.f); + OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), nb_scales.get_tensor(LNF_W), 0.f); comm.reduce_max(Parameters->get_master_lnf_w().abs_max()); CUDA_CHECK(cudaEventRecord(rs->OptEmbeddingsDone, main_stream)); @@ -866,7 +860,7 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_ if(!Config.TiedWordEmbeddings) { run_update(Parameters->get_master_lmhead(), Grads->get_lmhead_shard(main_stream), - OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), m_scales.NonBlocks.LMHead, weight_decay); + OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), nb_scales.get_tensor(LM_HEAD), weight_decay); comm.reduce_max(Parameters->get_master_lmhead().abs_max()); } comm.wait_on_comms(main_stream); @@ -894,7 +888,8 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato acts = ::allocate_run_state(Config, options, B, T, stack, Allocator); } - OptimizerState = std::make_unique(Config, *this, options, acts.MainStream, comm, *Allocator); + OptimizerState = std::make_unique(Config, *this, options, comm); + OptimizerState->allocate_state(*this, acts.MainStream, options.offload_alloc(), *Allocator); Parameters->begin_optimizer(stack, comm.stream()); OptimizerState->begin_optimizer(stack, comm.stream()); @@ -922,16 +917,8 @@ ITensorContainer& LLamaModel::weights() { return *Parameters; } -ITensorContainer& LLamaModel::opt_momentum() { - return OptimizerState->full_m(); -} - 
-ITensorContainer& LLamaModel::opt_momentum_scales() { - return OptimizerState->scales_m(); -} - -ITensorContainer& LLamaModel::opt_variance() { - return OptimizerState->full_v(); +AdamWStateManager& LLamaModel::optimizer() { + return *OptimizerState; } std::vector LLamaModel::rng_state() const { diff --git a/src/models/llama_model.h b/src/models/llama_model.h index 06205de..44293b8 100644 --- a/src/models/llama_model.h +++ b/src/models/llama_model.h @@ -104,9 +104,7 @@ class LLamaModel : public IModel { void calculate_gradient_norm(NCCLCommunicator& comm, float grad_clip); ITensorContainer& weights() override; - ITensorContainer& opt_momentum() override; - ITensorContainer& opt_momentum_scales() override; - ITensorContainer& opt_variance() override; + AdamWStateManager& optimizer() override; std::vector rng_state() const override; void set_rng_state(const std::vector& state) override; std::string_view model_type() const override; diff --git a/src/models/llama_optimizer.cpp b/src/models/llama_optimizer.cpp index 6aa16c6..bcad5e8 100644 --- a/src/models/llama_optimizer.cpp +++ b/src/models/llama_optimizer.cpp @@ -4,125 +4,76 @@ #include "llama_optimizer.h" + +#include + #include "training/transformer_config.h" #include "llama_model.h" #include "utilities/comm.h" #include "kernels/kernels.h" -#include "utilities/stack.h" #include "utilities/lazy_allocator.h" - -SimpleTensorContainer& LLamaOptimizerStateManager::get_block_m(int layer_idx, cudaStream_t stream) { - if(!mOffloadM || mUseZeroCopy) return mOptM.Blocks[layer_idx]; - return get_block_from(layer_idx, stream, mOptMBuffer.at(layer_idx % 2)); -} - -SimpleTensorContainer& LLamaOptimizerStateManager::get_block_v(int layer_idx, cudaStream_t stream) { - if(!mOffloadV || mUseZeroCopy) return mOptV.Blocks[layer_idx]; - return get_block_from(layer_idx, stream, mOptVBuffer.at(layer_idx % 2)); -} - -SimpleTensorContainer& LLamaOptimizerStateManager::non_block_m() { - return mOptM.NonBlocks; -} - -SimpleTensorContainer& LLamaOptimizerStateManager::non_block_v() { - return mOptV.NonBlocks; -} - -void zero_opt_non_block(sLLamaWeights& weights, cudaStream_t stream) { - // here's the first disadvantage of having individual buffers: We need to make a ton of memset calls - fill_zero(weights.NonBlocks.Embeddings, stream); - fill_zero(weights.NonBlocks.LNF_w, stream); - if(weights.NonBlocks.LMHead.Data != weights.NonBlocks.Embeddings.Data) { - fill_zero(weights.NonBlocks.LMHead, stream); +#include "utilities/safetensors.h" + +struct OptStateWrapper : ITensorContainer { + void iterate_tensors(const std::function& callback) override; + std::vector* Blocks; + GenericTensorContainer* NonBlock; + OptStateWrapper() = default; + OptStateWrapper(std::vector* b, GenericTensorContainer* nb) : Blocks(b), NonBlock(nb) {}; +}; + +void OptStateWrapper::iterate_tensors(const std::function& callback) { + callback("model.embed_tokens.weight", NonBlock->get_tensor(LLamaWeightID::EMBEDDING)); + if(NonBlock->get_tensor(LLamaWeightID::LM_HEAD)) { + callback("lm_head.weight", NonBlock->get_tensor(LLamaWeightID::LM_HEAD)); } -} - -sLLamaWeights allocate_scales(TransformerConfig config, int shard_idx, int num_shards, TensorAllocator& alloc) { - long C = config.HiddenSize; - long V = config.VocabSize; - long H = config.IntermediateSize; - long head_size = C / config.NumQueryHeads; - long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * head_size; - - sLLamaWeights result; - result.Blocks.resize(config.NumLayers); - for(auto& block : result.Blocks) { 
- block.Attn_QKV_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_w", {div_exact(attn_intermediate_size * C, 128l)}, EAllocationType::ON_DEVICE); - block.Attn_Out_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "attproj_w", {div_exact(C * C, 128l)}, EAllocationType::ON_DEVICE); - block.MLP_Up_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_up_w", {div_exact(2 * H * C, 128l)}, EAllocationType::ON_DEVICE); - block.MLP_Down_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "mlp_down_w", {div_exact(C * H, 128l)}, EAllocationType::ON_DEVICE); - - block.LN1_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln1_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - block.LN2_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "ln2_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - if(config.UseQKVBias) { - block.Attn_QKV_b = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "att_qkv_b", {div_exact(attn_intermediate_size, 128l)}, EAllocationType::ON_DEVICE); - } else { - block.Attn_QKV_b = Tensor{}; + callback("model.norm.weight", NonBlock->get_tensor(LLamaWeightID::LNF_W)); + + for(int i = 0; i < Blocks->size(); i++) { + auto& layer = Blocks->at(i); + const Tensor& qkv_w = layer.get_tensor(LLamaWeightID::QKV_W); + const Tensor& up_proj = layer.get_tensor(LLamaWeightID::UP_W); + std::string prefix = "model.layers." + std::to_string(i); + callback(prefix + ".self_attn.qkv.weight", qkv_w); + if (layer.get_tensor(LLamaWeightID::QKV_B)) { + callback(prefix + ".self_attn.qkv.bias", layer.get_tensor(LLamaWeightID::QKV_B)); } - visit([](Tensor& t){ - fill_constant(t, 1.f, t.nelem(), nullptr); - }, block); - } - result.NonBlocks.Embeddings = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "embeddings", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE); - result.NonBlocks.LNF_w = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards,"lnf_w", {div_exact(C, 128l)}, EAllocationType::ON_DEVICE); - fill_constant(result.NonBlocks.Embeddings, 1.f, result.NonBlocks.Embeddings.nelem(), nullptr); - fill_constant(result.NonBlocks.LNF_w, 1.f, result.NonBlocks.LNF_w.nelem(), nullptr); - if(config.TiedWordEmbeddings) { - result.NonBlocks.LMHead = result.NonBlocks.Embeddings; - } else { - result.NonBlocks.LMHead = alloc.allocate_shard(ETensorDType::FP32, shard_idx, num_shards, "lmhead", {div_exact(V * C, 128l)}, EAllocationType::ON_DEVICE); - fill_constant(result.NonBlocks.LMHead, 1.f, result.NonBlocks.LMHead.nelem(), nullptr); + callback(prefix + ".self_attn.o_proj.weight", layer.get_tensor(LLamaWeightID::ATTO_W)); + callback(prefix + ".mlp.up.weight", up_proj); + callback(prefix + ".mlp.down_proj.weight", layer.get_tensor(LLamaWeightID::DOWN_W)); + callback(prefix + ".input_layernorm.weight", layer.get_tensor(LLamaWeightID::LN1_W)); + callback(prefix + ".post_attention_layernorm.weight", layer.get_tensor(LLamaWeightID::LN2_W)); } - return result; } -std::vector allocate_weights_opt(sLLamaWeights& weights, const TransformerConfig& config, ETensorDType dtype, EAllocationType kind, int shard_idx, int num_shards, TensorAllocator& alloc) { - std::vector result; - weights.Blocks.resize(config.NumLayers); - LazyAllocator alloc_lazy; - for(auto& block : weights.Blocks) { - fill_matrix_shapes(block, config, dtype, shard_idx, num_shards); - fill_non_matrix_shapes(block, config, dtype, shard_idx, num_shards); - alloc_lazy.allocate(block); - 
result.push_back(alloc_lazy.commit(alloc, kind, "block_shard")); - } - weights.NonBlocks = allocate_non_block_shard(config, dtype, kind, shard_idx, num_shards, alloc); - return result; -} - - -LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc): +LLamaOptimizerStateManager::LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm): AdamWStateManager(cfg, model, options.OffloadOptM, options.OffloadOptV, options.OptMomentumType, options.OptVarianceType, options.UseZeroCopy, comm.rank(), comm.world_size()) { - { - auto ctx = alloc.with_context("Adam M"); - EAllocationType alloc_type = options.OffloadOptM ? options.offload_alloc() : EAllocationType::ON_DEVICE; - mMBlockStorage = allocate_weights_opt(mOptM, cfg, mMType, alloc_type, comm.rank(), comm.world_size(), alloc); - for(auto& block : mMBlockStorage) { - fill_zero(block, stream); - } - zero_opt_non_block(mOptM, stream); +} - if(mMType == ETensorDType::FP8_E4M3) { - mOptMScales = allocate_scales(cfg, comm.rank(), comm.world_size(), alloc); - } else { - mOptMScales.Blocks.resize(cfg.NumLayers); - } - } +void LLamaOptimizerStateManager::safe_to_checkpoint(const std::string& checkpoint_dir) { + OptStateWrapper m_state{&mBlocksM, &mNonBlockM}; + OptStateWrapper v_state{&mBlocksV, &mNonBlockV}; - { - auto ctx = alloc.with_context("Adam V"); - EAllocationType alloc_type = options.OffloadOptV ? options.offload_alloc() : EAllocationType::ON_DEVICE; - mVBlockStorage = allocate_weights_opt(mOptV, cfg, mVType, alloc_type, comm.rank(), comm.world_size(), alloc); - for(auto& block : mVBlockStorage) { - fill_zero(block, stream); - } - zero_opt_non_block(mOptV, stream); + write_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state); + write_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state); + if (mMType == ETensorDType::FP8_E4M3) { + OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales}; + write_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales); } } -SimpleTensorContainer& LLamaOptimizerStateManager::get_block_scales_m(int layer_idx) { - return mOptMScales.Blocks.at(layer_idx); +void LLamaOptimizerStateManager::load_from_checkpoint(const std::string& checkpoint_dir) { + OptStateWrapper m_state{&mBlocksM, &mNonBlockM}; + OptStateWrapper v_state{&mBlocksV, &mNonBlockV}; + + // load optimizer shards + load_safetensors(checkpoint_dir + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_state, false); + load_safetensors(checkpoint_dir + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), v_state, false); + + if (mMType == ETensorDType::FP8_E4M3) { + OptStateWrapper m_scales{&mBlocksMScales, &mNonBlockMScales}; + load_safetensors(checkpoint_dir + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", mRank, mWorld), m_scales, false); + } } diff --git a/src/models/llama_optimizer.h b/src/models/llama_optimizer.h index 24ab065..d618626 100644 --- a/src/models/llama_optimizer.h +++ b/src/models/llama_optimizer.h @@ -8,28 +8,13 @@ #include "llama_weights.h" #include "training/adamw_optimizer.h" -#include class LLamaOptimizerStateManager : public AdamWStateManager { public: - LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, 
LLamaOptions options, cudaStream_t stream, NCCLCommunicator& comm, TensorAllocator& alloc); - SimpleTensorContainer& non_block_m() override; - SimpleTensorContainer& non_block_v() override; + LLamaOptimizerStateManager(TransformerConfig cfg, IModel& model, LLamaOptions options, NCCLCommunicator& comm); - ITensorContainer& full_m() { return mOptM; } - ITensorContainer& full_v() { return mOptV; } - sLLamaWeights& scales_m() { return mOptMScales; } - - SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream) override; - SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream) override; - SimpleTensorContainer& get_block_scales_m(int layer_idx) override; -private: - // mOptM.Blocks[i] and mMBlockStorage[i] alias the same memory. - // mOptM provides convenient access to the individual tensors of a block, whereas - // mMBlockStorage has just one large, byte-typed buffer for bulk transfers. - sLLamaWeights mOptM; - sLLamaWeights mOptV; - sLLamaWeights mOptMScales; + void safe_to_checkpoint(const std::string& checkpoint_dir) override; + void load_from_checkpoint(const std::string& checkpoint_dir) override; }; #endif //LLMQ_SRC_MODELS_LLAMA_OPTIMIZER_H diff --git a/src/training/adamw_optimizer.cpp b/src/training/adamw_optimizer.cpp index 0bf7855..d391dbe 100644 --- a/src/training/adamw_optimizer.cpp +++ b/src/training/adamw_optimizer.cpp @@ -5,12 +5,14 @@ #include "adamw_optimizer.h" #include "model.h" +#include "kernels/kernels.h" +#include "utilities/allocator.h" #include "utilities/utils.h" #include "utilities/tensor.h" #include "utilities/stack.h" #include "utilities/lazy_allocator.h" -static GenericTensorContainer& shard_container(GenericTensorContainer&& c, int world) { +static GenericTensorContainer shard_container(GenericTensorContainer&& c, int world) { visit([world](Tensor& t) { if (!t.empty()) { throw std::logic_error("shard_container called with non-empty tensor"); } t.Sizes[0] = div_exact(t.Sizes[0], static_cast(world)); @@ -18,19 +20,18 @@ static GenericTensorContainer& shard_container(GenericTensorContainer&& c, int w return c; } - AdamWStateManager::AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world): mConfig(cfg), mOffloadM(offload_m), mOffloadV(offload_v), mUseZeroCopy(zero_copy), mRank(rank), mWorld(world), mMType(type_m), mVType(type_v) { if(mOffloadM && !mUseZeroCopy) { - mOptMBuffer[0] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); - mOptMBuffer[1] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + mMDeviceBuffer[0] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + mMDeviceBuffer[1] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); } if(mOffloadV && !mUseZeroCopy) { - mOptVBuffer[0] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); - mOptVBuffer[1] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + mVDeviceBuffer[0] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + mVDeviceBuffer[1] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); } if((mOffloadM || mOffloadV) && !mUseZeroCopy) { @@ -48,16 +49,16 @@ void AdamWStateManager::begin_optimizer(DeviceMemoryStack& memory, cudaStream_t } if(mOffloadM && !mUseZeroCopy) { - alloc.allocate(mOptMBuffer.at(0)); + 
alloc.allocate(mMDeviceBuffer.at(0)); mMBufferStorage[0] = alloc.commit(memory, "opt_m_a"); - alloc.allocate(mOptMBuffer.at(1)); + alloc.allocate(mMDeviceBuffer.at(1)); mMBufferStorage[1] = alloc.commit(memory, "opt_m_b"); } if(mOffloadV && !mUseZeroCopy) { - alloc.allocate(mOptVBuffer.at(0)); + alloc.allocate(mVDeviceBuffer.at(0)); mVBufferStorage[0] = alloc.commit(memory, "opt_v_a"); - alloc.allocate(mOptVBuffer.at(1)); + alloc.allocate(mVDeviceBuffer.at(1)); mVBufferStorage[1] = alloc.commit(memory, "opt_v_b"); } } @@ -94,14 +95,14 @@ void AdamWStateManager::fetch_block(int layer_idx, cudaStream_t fetch_stream) { if(mOffloadM) { auto& buf = mMBufferStorage.at(buffer); - auto& ref = mMBlockStorage.at(layer_idx); + auto& ref = mStorageM.at(layer_idx); fetch(buf, ref); } if(mOffloadV) { auto& buf = mVBufferStorage.at(buffer); - auto& ref = mVBlockStorage.at(layer_idx); + auto& ref = mStorageV.at(layer_idx); fetch(buf, ref); } @@ -109,6 +110,33 @@ void AdamWStateManager::fetch_block(int layer_idx, cudaStream_t fetch_stream) { CUDA_CHECK(cudaEventRecord(stat.DoneEvent, fetch_stream)); } +SimpleTensorContainer& AdamWStateManager::get_block_m(int layer_idx, cudaStream_t stream) { + if(!mOffloadM || mUseZeroCopy) return mBlocksM.at(layer_idx); + return get_block_from(layer_idx, stream, mMDeviceBuffer.at(layer_idx % 2)); +} + +SimpleTensorContainer& AdamWStateManager::get_block_v(int layer_idx, cudaStream_t stream) { + if(!mOffloadV || mUseZeroCopy) return mBlocksV.at(layer_idx); + return get_block_from(layer_idx, stream, mVDeviceBuffer.at(layer_idx % 2)); +} + +SimpleTensorContainer& AdamWStateManager::get_block_scales_m(int layer_idx) { + return mBlocksMScales.at(layer_idx); +} + +SimpleTensorContainer& AdamWStateManager::non_block_m() { + return mNonBlockM; +} + +SimpleTensorContainer& AdamWStateManager::non_block_m_scales() { + return mNonBlockMScales; +} + +SimpleTensorContainer& AdamWStateManager::non_block_v() { + return mNonBlockV; +} + + void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStream_t put_stream) { if (mUseZeroCopy) return; @@ -121,11 +149,11 @@ void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStre } if(mOffloadM) { - CUDA_CHECK(cudaMemcpyAsync(mMBlockStorage.at(layer_idx).Data, mMBufferStorage.at(buffer).Data, mMBlockStorage.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); + CUDA_CHECK(cudaMemcpyAsync(mStorageM.at(layer_idx).Data, mMBufferStorage.at(buffer).Data, mStorageM.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); } if(mOffloadV) { - CUDA_CHECK(cudaMemcpyAsync(mVBlockStorage.at(layer_idx).Data, mVBufferStorage.at(buffer).Data, mVBlockStorage.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); + CUDA_CHECK(cudaMemcpyAsync(mStorageV.at(layer_idx).Data, mVBufferStorage.at(buffer).Data, mStorageV.at(layer_idx).bytes(), cudaMemcpyDeviceToHost, put_stream)); } if(mOffloadM || mOffloadV) { @@ -137,6 +165,68 @@ void AdamWStateManager::store_block(int layer_idx, cudaStream_t stream, cudaStre } } +void AdamWStateManager::allocate_state(IModel& model, cudaStream_t stream, EAllocationType kind, TensorAllocator& alloc) { + { + auto ctx = alloc.with_context("Adam M"); + LazyAllocator alloc_lazy; + mBlocksM.resize(mConfig.NumLayers); + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksM[i] = shard_container(model.create_block_container(mConfig, mMType, mMType), mWorld); + alloc_lazy.allocate(mBlocksM[i]); + mStorageM.push_back(alloc_lazy.commit(alloc, mOffloadM ? 
kind : EAllocationType::ON_DEVICE, "m_block_shard")); + } + mNonBlockM = shard_container(model.create_non_block_container(mConfig, mMType, mMType), mWorld); + alloc_lazy.allocate(mNonBlockM); + mStorageM.push_back(alloc_lazy.commit(alloc, mOffloadM ? kind : EAllocationType::ON_DEVICE, "m_nonblock_shard")); + + for (auto& t : mStorageM) { + fill_zero(t, stream); + } + + mBlocksMScales.resize(mConfig.NumLayers); + if(mMType == ETensorDType::FP8_E4M3) { + // we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights. + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksMScales[i] = shard_container(model.create_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld); + alloc_lazy.allocate(mBlocksMScales[i]); + alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_block_scales"); + visit([stream](Tensor& t){ + fill_constant(t, 1.f, t.nelem(), stream); + }, mBlocksMScales[i]); + } + mNonBlockMScales = shard_container(model.create_non_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld); + alloc_lazy.allocate(mNonBlockMScales); + alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_nonblock_scales"); + visit([stream](Tensor& t){ + fill_constant(t, 1.f, t.nelem(), stream); + }, mNonBlockMScales); + } else { + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksMScales[i] = GenericTensorContainer(std::vector(model.num_block_tensors())); + } + mNonBlockMScales = GenericTensorContainer(std::vector(model.num_non_block_tensors())); + } + } + + { + auto ctx = alloc.with_context("Adam V"); + LazyAllocator alloc_lazy; + mBlocksV.resize(mConfig.NumLayers); + for (int i = 0; i < mConfig.NumLayers; ++i) { + mBlocksV[i] = shard_container(model.create_block_container(mConfig, mVType, mVType), mWorld); + alloc_lazy.allocate(mBlocksV[i]); + mStorageV.push_back(alloc_lazy.commit(alloc, mOffloadV ? kind : EAllocationType::ON_DEVICE, "v_block_shard")); + } + mNonBlockV = shard_container(model.create_non_block_container(mConfig, mVType, mVType), mWorld); + alloc_lazy.allocate(mNonBlockV); + mStorageV.push_back(alloc_lazy.commit(alloc, mOffloadV ? 
kind : EAllocationType::ON_DEVICE, "v_nonblock_shard")); + + for (auto& t : mStorageV) { + fill_zero(t, stream); + } + } +} + SimpleTensorContainer& AdamWStateManager::get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer &buf) { int buffer = layer_idx % 2; auto& stat = mStatus.at(buffer); diff --git a/src/training/adamw_optimizer.h b/src/training/adamw_optimizer.h index 3fa7247..9535c87 100644 --- a/src/training/adamw_optimizer.h +++ b/src/training/adamw_optimizer.h @@ -10,7 +10,9 @@ #include "utilities/tensor_container.h" #include "utilities/tensor.h" +enum class EAllocationType : int; class IModel; +class TensorAllocator; typedef struct CUstream_st *cudaStream_t; class DeviceMemoryStack; @@ -19,21 +21,26 @@ class AdamWStateManager { public: AdamWStateManager(TransformerConfig cfg, IModel& model, bool offload_m, bool offload_v, ETensorDType type_m, ETensorDType type_v, bool zero_copy, int rank, int world); virtual ~AdamWStateManager() = default; - virtual void begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream); - virtual void end_optimizer(DeviceMemoryStack& memory); + void begin_optimizer(DeviceMemoryStack& memory, cudaStream_t main_stream); + void end_optimizer(DeviceMemoryStack& memory); void fetch_block(int layer_idx, cudaStream_t fetch_stream); - virtual SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream) = 0; - virtual SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream) = 0; - virtual SimpleTensorContainer& get_block_scales_m(int layer_idx) = 0; + SimpleTensorContainer& get_block_m(int layer_idx, cudaStream_t stream); + SimpleTensorContainer& get_block_v(int layer_idx, cudaStream_t stream); + SimpleTensorContainer& get_block_scales_m(int layer_idx); void store_block(int layer_idx, cudaStream_t stream, cudaStream_t put_stream); - virtual SimpleTensorContainer& non_block_m() = 0; - virtual SimpleTensorContainer& non_block_v() = 0; + SimpleTensorContainer& non_block_m(); + SimpleTensorContainer& non_block_m_scales(); + SimpleTensorContainer& non_block_v(); + + void allocate_state(IModel& model, cudaStream_t stream, EAllocationType kind, TensorAllocator& alloc); + + virtual void safe_to_checkpoint(const std::string& checkpoint_dir) = 0; + virtual void load_from_checkpoint(const std::string& checkpoint_dir) = 0; protected: SimpleTensorContainer& get_block_from(int layer_idx, cudaStream_t stream, SimpleTensorContainer& buf); - TransformerConfig mConfig; bool mOffloadM; @@ -53,14 +60,20 @@ class AdamWStateManager { bool Done = true; }; - std::vector mMBlockStorage; - std::vector mVBlockStorage; + std::vector mStorageM; + std::vector mStorageV; + std::vector mBlocksM; + std::vector mBlocksV; + std::vector mBlocksMScales; + GenericTensorContainer mNonBlockM; + GenericTensorContainer mNonBlockMScales; + GenericTensorContainer mNonBlockV; std::array mMBufferStorage; std::array mVBufferStorage; std::array mStatus; - std::array mOptMBuffer; - std::array mOptVBuffer; + std::array mMDeviceBuffer; + std::array mVDeviceBuffer; }; #endif //LLMQ_ADAMW_OPTIMIZER_H diff --git a/src/training/checkpoint.cpp b/src/training/checkpoint.cpp index 2d740a5..ec37e7f 100644 --- a/src/training/checkpoint.cpp +++ b/src/training/checkpoint.cpp @@ -9,6 +9,7 @@ #include #include +#include "adamw_optimizer.h" #include "dataloader.h" #include "model.h" #include "utilities/comm.h" @@ -31,20 +32,7 @@ std::string save_checkpoint(std::string target, int step, IModel& model, const D // weights // TODO don't duplicate weights if they are unsharded 
write_safetensors(target + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.weights()); - - // sharded optimizer state - write_safetensors(target + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum()); - write_safetensors(target + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_variance()); - - bool has_scales = false; - model.opt_momentum_scales().iterate_tensors([&has_scales](const std::string& name, const TensorShard& tensor){ - if(tensor.Data != nullptr) { - has_scales = true; - } - }); - if(has_scales) { - write_safetensors(target + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum_scales()); - } + model.optimizer().safe_to_checkpoint(target); comm.barrier(); // only write checkpoint.json once we know all the shard files are saved @@ -126,21 +114,7 @@ void load_checkpoint(std::string source, int step, IModel& model, DataLoader* lo // weights load_safetensors(source + fmt::format("/weights.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.weights(), false); - - // load optimizer shards - load_safetensors(source + fmt::format("/adam.m.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum(), false); - load_safetensors(source + fmt::format("/adam.v.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_variance(), false); - - bool has_scales = false; - model.opt_momentum_scales().iterate_tensors([&has_scales](const std::string& name, const TensorShard& tensor){ - if(tensor.Data != nullptr) { - has_scales = true; - } - }); - if(has_scales) { - load_safetensors(source + fmt::format("/adam.m.scales.shard_{:03}_of_{:03}.safetensors", comm.rank(), comm.world_size()), model.opt_momentum_scales(), false); - } - + model.optimizer().load_from_checkpoint(source); model.on_restore_checkpoint(comm); } diff --git a/src/training/model.h b/src/training/model.h index ed390ef..2362495 100644 --- a/src/training/model.h +++ b/src/training/model.h @@ -14,6 +14,7 @@ #include "utilities/tensor.h" #include "training/transformer_config.h" +class AdamWStateManager; class ITensorContainer; class NCCLCommunicator; class TensorAllocator; @@ -69,13 +70,7 @@ class IModel { virtual ITensorContainer& weights() = 0; //! (First order) momentum. Sharded. - virtual ITensorContainer& opt_momentum() = 0; - - //! (First order) momentum. Sharded. - virtual ITensorContainer& opt_momentum_scales() = 0; - - //! Second order moments. Sharded. - virtual ITensorContainer& opt_variance() = 0; + virtual AdamWStateManager& optimizer() = 0; //! 
Get the current RNG state virtual std::vector rng_state() const = 0; diff --git a/src/utilities/tensor.cpp b/src/utilities/tensor.cpp index 409b0a2..b4cf80c 100644 --- a/src/utilities/tensor.cpp +++ b/src/utilities/tensor.cpp @@ -142,3 +142,14 @@ void visit(const std::function& func, SimpleTensorContai } } } + +GenericTensorContainer::GenericTensorContainer(std::vector t): mTensors( std::move(t) ) { +} + +std::size_t GenericTensorContainer::num_tensors() const noexcept { + return mTensors.size(); +} + +const Tensor& GenericTensorContainer::get_tensor(std::size_t idx) const { + return mTensors.at(idx); +} diff --git a/src/utilities/tensor_container.h b/src/utilities/tensor_container.h index a650ab5..7d5924e 100644 --- a/src/utilities/tensor_container.h +++ b/src/utilities/tensor_container.h @@ -44,13 +44,13 @@ void visit(const std::function& func, SimpleTensorContai class GenericTensorContainer final : public SimpleTensorContainer { public: GenericTensorContainer() = default; - GenericTensorContainer(std::vector t) : mTensors( std::move(t) ) { }; + explicit GenericTensorContainer(std::vector t); //! Get the total number of tensors in this container. This count includes empty tensors. - std::size_t num_tensors() const noexcept { return mTensors.size(); }; + std::size_t num_tensors() const noexcept override; //! Return a constant reference to the tensor at the given index. - const Tensor& get_tensor(std::size_t idx) const { return mTensors.at(idx); } + const Tensor& get_tensor(std::size_t idx) const override; using SimpleTensorContainer::get_tensor; private:
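
For reference when reviewing the series as a whole: patches 2-4 converge on a flow where the model only describes per-block tensor shapes (fill_block_shapes / create_block_container), the optimizer shards the leading dimension across ranks (shard_container), and only then commits storage through the LazyAllocator. The sketch below restates that flow in isolation. It uses simplified stand-ins for Tensor, GenericTensorContainer and div_exact, and made-up shapes and indices; it illustrates the intended call sequence under those assumptions and is not code from the repository.

// Simplified stand-ins for the real types in src/utilities/ (illustration only).
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

struct Tensor {
    int Rank = 0;
    long Sizes[2] = {0, 0};
    long nelem() const { return Rank == 2 ? Sizes[0] * Sizes[1] : Sizes[0]; }
};

// Stand-in for GenericTensorContainer: an indexable vector of shape-only tensors.
struct GenericTensorContainer {
    std::vector<Tensor> Tensors;
    Tensor& get_tensor(std::size_t idx) { return Tensors.at(idx); }
};

// Stand-in for div_exact: refuse shard sizes that do not divide evenly.
long div_exact(long a, long b) {
    if (a % b != 0) throw std::logic_error("dimension not divisible by world size");
    return a / b;
}

// Analogue of IModel::fill_block_shapes(): record ranks and sizes for one
// transformer block; no memory is touched at this point.
GenericTensorContainer describe_block(long C, long H) {
    GenericTensorContainer block{std::vector<Tensor>(3)};
    block.get_tensor(0) = Tensor{2, {4 * C, C}}; // a fused attention-style matrix
    block.get_tensor(1) = Tensor{2, {C, H}};     // an MLP-style matrix
    block.get_tensor(2) = Tensor{1, {C, 0}};     // a layernorm-style vector
    return block;
}

// Analogue of shard_container(): split the leading dimension across ranks.
GenericTensorContainer shard(GenericTensorContainer block, int world) {
    for (Tensor& t : block.Tensors) {
        t.Sizes[0] = div_exact(t.Sizes[0], world);
    }
    return block;
}

int main() {
    const long C = 1024, H = 4096;
    const int world = 8;

    // Describe shapes once, shard them, and only then decide where the storage
    // lives (on device, or in host memory when the optimizer state is offloaded).
    GenericTensorContainer local = shard(describe_block(C, H), world);

    long elems = 0;
    for (const Tensor& t : local.Tensors) elems += t.nelem();
    std::cout << "per-rank elements for one block's optimizer state: " << elems << '\n';
}

Keeping the shape description separate from allocation is what lets one description serve the momentum state, the variance state, the FP8 momentum scales (sharded by 128 * world so there is one scale per 128 weights), and the two double-buffered device staging buffers, differing only in dtype and shard factor.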