From 38367fa030cefd1126124ea695a39db3d12bf98e Mon Sep 17 00:00:00 2001 From: noli Date: Tue, 13 Aug 2024 09:40:57 +0000 Subject: [PATCH 1/7] adds placeholder op --- ark/api/executor.cpp | 75 +++++++++++++++++++------------- ark/include/ark/executor.hpp | 17 +++++--- ark/include/ark/model.hpp | 34 ++++++++++++++- ark/include/ark/tensor.hpp | 11 +++++ ark/model/model_op.cpp | 2 + ark/model_buffer_manager.hpp | 13 ++++-- ark/ops/ops_placeholder.cpp | 57 ++++++++++++++++++++++++ ark/ops/ops_placeholder.hpp | 23 ++++++++++ ark/ops/ops_placeholder_test.cpp | 51 ++++++++++++++++++++++ 9 files changed, 241 insertions(+), 42 deletions(-) create mode 100644 ark/ops/ops_placeholder.cpp create mode 100644 ark/ops/ops_placeholder.hpp create mode 100644 ark/ops/ops_placeholder_test.cpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 162aaa1f..7823c324 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -26,6 +26,7 @@ #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" #include "model_buffer_manager.hpp" +#include "unordered_map" #include "utils/utils_net.hpp" #if defined(ARK_CUDA) @@ -143,7 +144,10 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl() : plan_json_(), device_id_(-1) {}; + Impl() + : plan_json_(), + device_id_(-1), + buffer_manager_(ModelBufferManager::get_instance()) {}; ~Impl(); int device_id() const { return device_id_; } @@ -160,8 +164,10 @@ class Executor::Impl { void compile(const std::string &plan, int device_id, const std::string &name); - void launch(Stream stream, bool loop_mode); - void run(int iter); + void launch(Stream stream, bool loop_mode, + const std::unordered_map &external_tensors); + void run(int iter, + const std::unordered_map &external_tensors); void wait(int64_t max_spin_count); float stop(int64_t max_spin_count); void barrier(); @@ -203,6 +209,7 @@ class Executor::Impl { bool is_recording_ = false; float elapsed_msec_ = -1; + ModelBufferManager &buffer_manager_; std::vector external_buffers_; std::vector external_args_; std::map buffer_id_to_name_; @@ -408,45 +415,40 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (auto &kv : buffer_id_to_info) { auto &buf_info = kv.second; int r = buf_info->buffer->rank(); + const size_t buf_id = buf_info->buffer->id(); if (r != rank_ && r != -1) { // this is a remote buffer for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + [tag_info.second] = buf_id; } for (const auto &tag_info : buf_info->buffer->recv_tags()) { remote_rank_to_recv_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + [tag_info.second] = buf_id; } continue; } - if (buf_info->buffer->is_external()) { + if (buffer_manager_.is_external(buf_id)) { if (buf_info->buffer->device_id() != device_id_) { ERR(InvalidUsageError, "PyTorch tensor and model execution are on different GPUs"); } - external_buffers_.push_back(buf_info->buffer->external_data()); + external_buffers_.push_back(buffer_manager_.get_buffer(buf_id)); const auto [it, inserted] = buffer_id_to_name_.try_emplace( - buf_info->buffer->id(), - "extern_buf_" + std::to_string(buf_info->buffer->id())); + buf_id, "extern_buf_" + std::to_string(buf_id)); external_args_.push_back(it->second); continue; } // if we are adding a plan and come across a buffer from a previous // plan, we utilize the buffer offset from the previous plan - if 
(buffer_id_to_offset_.find(buf_info->buffer->id()) != - buffer_id_to_offset_.end()) { - external_buffers_.push_back( - buffer_id_to_addr_[buf_info->buffer->id()]); - const std::string name = - "extern_buf_" + std::to_string(buf_info->buffer->id()); + if (buffer_id_to_offset_.find(buf_id) != buffer_id_to_offset_.end()) { + external_buffers_.push_back(buffer_id_to_addr_[buf_id]); + const std::string name = "extern_buf_" + std::to_string(buf_id); external_args_.push_back(name); - buffer_id_to_name_[buf_info->buffer->id()] = name; + buffer_id_to_name_[buf_id] = name; continue; } else { - buffer_id_to_offset[buf_info->buffer->id()] = offset; + buffer_id_to_offset[buf_id] = offset; for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tags_and_offsets[tag_info.first] .first.push_back(tag_info.second); @@ -536,8 +538,9 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2); for (int i = 0; i < len; ++i) { - if (!buffer_id_to_info[send_tag_to_buffer_id[tags[i]]] - ->buffer->is_external()) { + const size_t buf_id = + buffer_id_to_info[send_tag_to_buffer_id[tags[i]]]->buffer->id(); + if (!buffer_manager_.is_external(buf_id)) { buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -556,8 +559,9 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5); for (int i = 0; i < len; ++i) { - if (!buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]] - ->buffer->is_external()) { + const size_t buf_id = + buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]]->buffer->id(); + if (!buffer_manager_.is_external(buf_id)) { buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -703,7 +707,9 @@ void Executor::Impl::compile(const std::string &plan, int device_id, kernel_->compile(); } -void Executor::Impl::launch(Stream stream, bool loop_mode) { +void Executor::Impl::launch( + Stream stream, bool loop_mode, + const std::unordered_map &external_tensors) { if ((kernel_ == nullptr) || !kernel_->is_compiled()) { ERR(InvalidUsageError, "Need to compile first before launch."); } @@ -796,7 +802,8 @@ void Executor::Impl::launch(Stream stream, bool loop_mode) { is_launched_ = true; } -void Executor::Impl::run(int iter) { +void Executor::Impl::run( + int iter, const std::unordered_map &external_tensors) { if (iter <= 0) return; if (loop_mode_) { while (atomicLoadRelaxed(flag_->ref()) > 0) { @@ -888,7 +895,7 @@ void *Executor::Impl::tensor_address(const Tensor &tensor) const { void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (tensor.ref()->buffer()->is_external()) { + if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { ERR(InvalidUsageError, "Reading data from a tensor preallocated by PyTorch is not " "supported. Use PyTorch's native methods."); @@ -944,7 +951,7 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (tensor.ref()->buffer()->is_external()) { + if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { ERR(InvalidUsageError, "Writing data to a tensor preallocated by PyTorch is not " "supported. 
Use PyTorch's native methods."); @@ -1019,11 +1026,16 @@ void Executor::compile(const std::string &plan, int device_id, impl_->compile(plan, device_id, name); } -void Executor::launch(Stream stream, bool loop_mode) { - impl_->launch(stream, loop_mode); +void Executor::launch( + Stream stream, bool loop_mode, + const std::unordered_map &external_tensors) { + impl_->launch(stream, loop_mode, external_tensors); } -void Executor::run(int iter) { impl_->run(iter); } +void Executor::run(int iter, + const std::unordered_map &external_tensors) { + impl_->run(iter, external_tensors); +} void Executor::wait(int64_t max_spin_count) { impl_->wait(max_spin_count); } @@ -1071,7 +1083,8 @@ DefaultExecutor::DefaultExecutor( } void DefaultExecutor::launch() { - Executor::launch(reinterpret_cast(impl_->stream_raw_), impl_->loop_mode_); + Executor::launch(reinterpret_cast(impl_->stream_raw_), + impl_->loop_mode_); } } // namespace ark diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 8e6577cd..8e5e5c85 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace ark { @@ -45,10 +46,13 @@ class Executor { const std::string &name = "executor"); /// Launch the executor. This must be called after `compile()`. - void launch(Stream stream = nullptr, bool loop_mode = true); + void launch( + Stream stream = nullptr, bool loop_mode = true, + const std::unordered_map &external_tensors = {}); /// Run the executor for `iter` iterations. - void run(int iter); + void run(int iter, + const std::unordered_map &external_tensors = {}); /// Wait for the previous run to finish. void wait(int64_t max_spin_count = -1); @@ -99,10 +103,11 @@ class Model; class DefaultExecutor : public Executor { public: - DefaultExecutor( - const Model &model, int device_id = -1, Stream stream = nullptr, - const std::vector &config_rules = {}, - const std::string &name = "DefaultExecutor", bool loop_mode = true); + DefaultExecutor(const Model &model, int device_id = -1, + Stream stream = nullptr, + const std::vector &config_rules = {}, + const std::string &name = "DefaultExecutor", + bool loop_mode = true); /// Launch the default executor. void launch(); diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 3c4f22e2..08b8fe63 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -76,6 +76,39 @@ class Model : public ModelGraph { const Dims &padded_shape = {}, int rank = -1, const std::string &name = ""); + /// + /// Returns a tensor object associated with an external buffer. + /// + /// @param shape Shape of the tensor, where the data of interest is. + /// @param dtype Type of the tensor data. + /// @param strides Strides of each dimension of the tensor, which may be + /// different from the shape. @p strides can be considered as the actual + /// shape of the underlying data buffer. + /// @param offsets Offsets of the tensor. The data of interest starts at + /// @p offsets and ends at @p offsets + @p padded_shape. + /// @param padded_shape Padded shape of the tensor. Padding is used to + /// reserve extra space for the tensor when computation requires it. + /// Data on the padded region is allowed to be accessed by computation, + /// but it is not considered as the data of interest. The padded region is + /// initialized to zero only once when the Executor is launched. 
The padded + /// shape should be greater than or equal to the @p shape, and the + /// @p strides should be greater than or equal to the padded shape. If the + /// @p strides are not provided, they are set to the padded shape. If the + /// padded shape is not provided, it is set to the @p shape. + /// @param rank Rank of the tensor. -1 means the rank of this model. + /// @param name Name of the tensor. + /// @param external_data Pointer to an external data buffer. If provided, + /// this buffer is registered with the ModelBufferManager and associated + /// with the tensor. + /// @return Pointer to a tensor object that references the external buffer. + /// + /// + Tensor placeholder(const Dims &shape, const DataType &data_type, + const Dims &strides = {}, const Dims &offsets = {}, + const Dims &padded_shape = {}, int rank = -1, + const std::string &name = "", + void *external_data = nullptr); + Tensor refer(Tensor input, const Dims &shape = {}, const Dims &strides = {}, const Dims &offsets = {}, const Dims &padded_shape = {}, const std::string &name = ""); @@ -254,7 +287,6 @@ class Model : public ModelGraph { Tensor local_all_reduce(Tensor input, int gpu_id, int gpu_num, const std::string &name = ""); - }; } // namespace ark diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 5e463f99..816738c0 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -54,6 +54,8 @@ class Tensor { const DataType &data_type() const; Dims torch_strides() const; + + friend struct std::hash; }; const Tensor NullTensor; @@ -62,4 +64,13 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor); } // namespace ark +namespace std { +template <> +struct hash { + size_t operator()(const ark::Tensor &t) const { + return hash()(t.id()); + } +}; +} // namespace std + #endif // ARK_TENSOR_HPP diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 5db8576e..8f222b75 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -16,6 +16,7 @@ #include "ops/ops_math.hpp" #include "ops/ops_matmul.hpp" #include "ops/ops_noop.hpp" +#include "ops/ops_placeholder.hpp" #include "ops/ops_reduce.hpp" #include "ops/ops_refer.hpp" #include "ops/ops_reshape.hpp" @@ -78,6 +79,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Sqrt); MODEL_OP_TYPE_REGISTER(Sub); MODEL_OP_TYPE_REGISTER(Tensor); + MODEL_OP_TYPE_REGISTER(Placeholder); MODEL_OP_TYPE_REGISTER(Transpose); MODEL_OP_TYPE_REGISTER(SendPacket); MODEL_OP_TYPE_REGISTER(RecvPacket); diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp index 4baaec7f..3e82b05f 100644 --- a/ark/model_buffer_manager.hpp +++ b/ark/model_buffer_manager.hpp @@ -8,7 +8,8 @@ #include namespace ark { -// Manages externally allocated buffers not in the ARK memory space. +// Manages externally allocated buffers (buffers corresponding to Tensors that +// are the output of a `placeholder` operation) outside of ARK's memory space. 
class ModelBufferManager { public: static ModelBufferManager& get_instance() { @@ -16,11 +17,11 @@ class ModelBufferManager { return instance; } - void register_buffer(size_t id, void* data, size_t size) { + void register_buffer(const size_t id, void* const data, const size_t size) { buffers_[id] = std::make_tuple(data, size); } - void* get_buffer(size_t id) { + void* get_buffer(const size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<0>(it->second); @@ -28,7 +29,7 @@ class ModelBufferManager { return nullptr; } - size_t get_buffer_size(size_t id) { + size_t get_buffer_size(const size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<1>(it->second); @@ -36,6 +37,10 @@ class ModelBufferManager { return 0; } + bool is_external(const size_t id) const { + return buffers_.find(id) != buffers_.end(); + } + const std::unordered_map>& get_buffers() const { return buffers_; diff --git a/ark/ops/ops_placeholder.cpp b/ark/ops/ops_placeholder.cpp new file mode 100644 index 00000000..fbac7390 --- /dev/null +++ b/ark/ops/ops_placeholder.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_placeholder.hpp" + +#include "logging.hpp" +#include "model_buffer_manager.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpPlaceholder::ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, + void *external_data) + : ModelOp("Placeholder", true) { + if (!buffer) { + buffer = std::make_shared(); + } + const std::vector &shape_vec = shape.vector(); + DataType dtype = ModelDataType(data_type); + + size_t external_data_size = + std::accumulate(shape_vec.begin(), shape_vec.end(), 1, + std::multiplies()) * + dtype.bytes(); + + ModelBufferManager::get_instance().register_buffer( + buffer->id(), external_data, external_data_size); + + ModelTensorRef tensor = std::make_shared( + data_type, buffer, shape, strides, offsets, padded_shape); + + result_tensors_.emplace_back(tensor); + + verify(); +} + +Tensor Model::placeholder(const Dims &shape, const DataType &data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, int rank, + const std::string &name, void *external_data) { + if (rank != -1) { + if (rank == this->rank()) { + rank = -1; + } else if (rank < 0 || rank >= this->world_size()) { + ERR(ModelError, "Invalid rank %d", rank); + } + } + return impl_ + ->create_op( + name, std::make_shared(rank), shape, data_type.ref(), + strides, offsets, padded_shape, external_data) + ->result_tensors()[0]; +} +} // namespace ark \ No newline at end of file diff --git a/ark/ops/ops_placeholder.hpp b/ark/ops/ops_placeholder.hpp new file mode 100644 index 00000000..7fb53f98 --- /dev/null +++ b/ark/ops/ops_placeholder.hpp @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#ifndef ARK_OPS_PLACEHOLDER_HPP_ +#define ARK_OPS_PLACEHOLDER_HPP_ + +#include "ark/model.hpp" +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpPlaceholder : public ModelOp { + public: + ModelOpPlaceholder() = default; + ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, const Dims &strides, + const Dims &offsets, const Dims &padded_shape, + void *external_data = nullptr); +}; + +} // namespace ark + +#endif // ARK_OPS_PLACEHOLDER_HPP_ \ No newline at end of file diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp new file mode 100644 index 00000000..37c04777 --- /dev/null +++ b/ark/ops/ops_placeholder_test.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include + +#include "ark/executor.hpp" +#include "gpu/gpu.hpp" +#include "logging.hpp" +#include "model/model_node.hpp" +#include "model/model_op.hpp" +#include "ops_test_common.hpp" + +ark::unittest::State test_ops_placeholder_value_contiguous() { + ark::Model model; + ark::Dims shape{10, 1}; + + // Allocate GPU memory for the external buffer + float *d_ext_buffer = nullptr; + ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)); + + // Initialize GPU Memory + std::vector h_ext_buffer(shape.nelems()); + std::iota(h_ext_buffer.begin(), h_ext_buffer.end(), 1.0f); + ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), + shape.nelems() * sizeof(float), ark::gpuMemcpyHostToDevice); + + // Associate the initialzied device buffer with a tensor produced from a + // placeholder operation + auto tns = + model.placeholder(shape, ark::FP32, {}, {}, {}, -1, "", d_ext_buffer); + + // Copy tensor data from GPU to CPU + std::vector res(shape.nelems(), 0.0f); + ark::gpuMemcpy(res.data(), d_ext_buffer, shape.nelems() * sizeof(float), + ark::gpuMemcpyDeviceToHost); + + for (auto i = 0; i < shape.nelems(); ++i) { + UNITTEST_EQ(res[i], i + 1); + } + + cudaFree(d_ext_buffer); + + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_ops_placeholder_value_contiguous); + return ark::unittest::SUCCESS; +} \ No newline at end of file From 920807f2a22afc28cc80b3904cdfe343753a5cfa Mon Sep 17 00:00:00 2001 From: noli Date: Tue, 13 Aug 2024 10:28:46 +0000 Subject: [PATCH 2/7] fix test --- ark/ops/ops_placeholder_test.cpp | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp index 37c04777..59f5e2dc 100644 --- a/ark/ops/ops_placeholder_test.cpp +++ b/ark/ops/ops_placeholder_test.cpp @@ -1,13 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
-#include -#include - #include "ark/executor.hpp" #include "gpu/gpu.hpp" #include "logging.hpp" -#include "model/model_node.hpp" #include "model/model_op.hpp" #include "ops_test_common.hpp" @@ -25,18 +21,25 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), shape.nelems() * sizeof(float), ark::gpuMemcpyHostToDevice); - // Associate the initialzied device buffer with a tensor produced from a + // Associate the initialized device buffer with a tensor produced from a // placeholder operation - auto tns = + ark::Tensor tns = model.placeholder(shape, ark::FP32, {}, {}, {}, -1, "", d_ext_buffer); + ark::Tensor res = model.add(tns, 1.0); + + ark::DefaultExecutor exe(model); + + exe.launch(); + exe.run(1); + exe.stop(); + // Copy tensor data from GPU to CPU - std::vector res(shape.nelems(), 0.0f); - ark::gpuMemcpy(res.data(), d_ext_buffer, shape.nelems() * sizeof(float), - ark::gpuMemcpyDeviceToHost); + std::vector h_res(shape.nelems(), 0.0f); + exe.tensor_read(res, h_res); for (auto i = 0; i < shape.nelems(); ++i) { - UNITTEST_EQ(res[i], i + 1); + UNITTEST_EQ(h_res[i], i + 2); } cudaFree(d_ext_buffer); From 2d51327052d54f553776a49f92da703e65efbee5 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 15 Aug 2024 09:09:27 +0000 Subject: [PATCH 3/7] minor --- ark/cpu_timer.cpp | 16 ---------------- ark/cpu_timer.h | 4 ---- 2 files changed, 20 deletions(-) diff --git a/ark/cpu_timer.cpp b/ark/cpu_timer.cpp index c740de5f..129ba7bd 100644 --- a/ark/cpu_timer.cpp +++ b/ark/cpu_timer.cpp @@ -16,20 +16,4 @@ double cpu_timer(void) { return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -// Sleep in second. -int cpu_timer_sleep(double sec) { - struct timespec tspec; - tspec.tv_sec = (time_t)sec; - tspec.tv_nsec = (long)((sec - tspec.tv_sec) * 1.0e9); - return nanosleep(&tspec, 0); -} - -// Sleep in nanosecond. -int cpu_ntimer_sleep(long nsec) { - struct timespec tspec; - tspec.tv_sec = 0; - tspec.tv_nsec = nsec; - return nanosleep(&tspec, 0); -} - } // namespace ark diff --git a/ark/cpu_timer.h b/ark/cpu_timer.h index 52bf63d9..eaac9406 100644 --- a/ark/cpu_timer.h +++ b/ark/cpu_timer.h @@ -8,10 +8,6 @@ namespace ark { // Measure current time in second. double cpu_timer(void); -// Sleep in second. -int cpu_timer_sleep(double sec); -// Sleep in nanosecond. -int cpu_ntimer_sleep(long nsec); } // namespace ark From 7d62f0f8241bfc654237c2d9a405e4218f3128ed Mon Sep 17 00:00:00 2001 From: Noli Gerawork Date: Thu, 15 Aug 2024 05:11:13 -0400 Subject: [PATCH 4/7] Add Placeholder Operator (#238) - Separates externally allocated buffers from `ModelBuffer` by having `ModelBufferManager` manage them instead. - Adds the `placeholder` operation. `placeholder` is a virtual operation that produces a `Tensor` with the added feature of providing a data pointer (which can be null to support delayed binding) to an external buffer. 
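
A minimal usage sketch (it mirrors the `ops_placeholder_test.cpp` unit test added in this
change; the external device buffer is assumed to be allocated by the caller, shown here
with `ark::gpuMalloc` for illustration):

    ark::Model model;
    ark::Dims shape{10, 1};

    // Device buffer allocated outside of ARK's memory space.
    float *d_ext_buffer = nullptr;
    ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float));

    // Wrap the external buffer in a tensor via the placeholder op and use it
    // like any other tensor in the model.
    ark::Tensor tns =
        model.placeholder(shape, ark::FP32, {}, {}, {}, -1, "", d_ext_buffer);
    ark::Tensor out = model.add(tns, 1.0);

    ark::DefaultExecutor exe(model);
    exe.launch();
    exe.run(1);
    exe.stop();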
--- ark/api/executor.cpp | 75 +++++++++++++++++++------------- ark/include/ark/executor.hpp | 17 +++++--- ark/include/ark/model.hpp | 34 ++++++++++++++- ark/include/ark/tensor.hpp | 11 +++++ ark/model/model_op.cpp | 2 + ark/model_buffer_manager.hpp | 13 ++++-- ark/ops/ops_placeholder.cpp | 57 ++++++++++++++++++++++++ ark/ops/ops_placeholder.hpp | 23 ++++++++++ ark/ops/ops_placeholder_test.cpp | 54 +++++++++++++++++++++++ 9 files changed, 244 insertions(+), 42 deletions(-) create mode 100644 ark/ops/ops_placeholder.cpp create mode 100644 ark/ops/ops_placeholder.hpp create mode 100644 ark/ops/ops_placeholder_test.cpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 162aaa1f..7823c324 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -26,6 +26,7 @@ #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" #include "model_buffer_manager.hpp" +#include "unordered_map" #include "utils/utils_net.hpp" #if defined(ARK_CUDA) @@ -143,7 +144,10 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl() : plan_json_(), device_id_(-1) {}; + Impl() + : plan_json_(), + device_id_(-1), + buffer_manager_(ModelBufferManager::get_instance()) {}; ~Impl(); int device_id() const { return device_id_; } @@ -160,8 +164,10 @@ class Executor::Impl { void compile(const std::string &plan, int device_id, const std::string &name); - void launch(Stream stream, bool loop_mode); - void run(int iter); + void launch(Stream stream, bool loop_mode, + const std::unordered_map &external_tensors); + void run(int iter, + const std::unordered_map &external_tensors); void wait(int64_t max_spin_count); float stop(int64_t max_spin_count); void barrier(); @@ -203,6 +209,7 @@ class Executor::Impl { bool is_recording_ = false; float elapsed_msec_ = -1; + ModelBufferManager &buffer_manager_; std::vector external_buffers_; std::vector external_args_; std::map buffer_id_to_name_; @@ -408,45 +415,40 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (auto &kv : buffer_id_to_info) { auto &buf_info = kv.second; int r = buf_info->buffer->rank(); + const size_t buf_id = buf_info->buffer->id(); if (r != rank_ && r != -1) { // this is a remote buffer for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + [tag_info.second] = buf_id; } for (const auto &tag_info : buf_info->buffer->recv_tags()) { remote_rank_to_recv_tag_to_buffer_id[buf_info->buffer->rank()] - [tag_info.second] = - buf_info->buffer->id(); + [tag_info.second] = buf_id; } continue; } - if (buf_info->buffer->is_external()) { + if (buffer_manager_.is_external(buf_id)) { if (buf_info->buffer->device_id() != device_id_) { ERR(InvalidUsageError, "PyTorch tensor and model execution are on different GPUs"); } - external_buffers_.push_back(buf_info->buffer->external_data()); + external_buffers_.push_back(buffer_manager_.get_buffer(buf_id)); const auto [it, inserted] = buffer_id_to_name_.try_emplace( - buf_info->buffer->id(), - "extern_buf_" + std::to_string(buf_info->buffer->id())); + buf_id, "extern_buf_" + std::to_string(buf_id)); external_args_.push_back(it->second); continue; } // if we are adding a plan and come across a buffer from a previous // plan, we utilize the buffer offset from the previous plan - if (buffer_id_to_offset_.find(buf_info->buffer->id()) != - buffer_id_to_offset_.end()) { - external_buffers_.push_back( - buffer_id_to_addr_[buf_info->buffer->id()]); - 
const std::string name = - "extern_buf_" + std::to_string(buf_info->buffer->id()); + if (buffer_id_to_offset_.find(buf_id) != buffer_id_to_offset_.end()) { + external_buffers_.push_back(buffer_id_to_addr_[buf_id]); + const std::string name = "extern_buf_" + std::to_string(buf_id); external_args_.push_back(name); - buffer_id_to_name_[buf_info->buffer->id()] = name; + buffer_id_to_name_[buf_id] = name; continue; } else { - buffer_id_to_offset[buf_info->buffer->id()] = offset; + buffer_id_to_offset[buf_id] = offset; for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tags_and_offsets[tag_info.first] .first.push_back(tag_info.second); @@ -536,8 +538,9 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2); for (int i = 0; i < len; ++i) { - if (!buffer_id_to_info[send_tag_to_buffer_id[tags[i]]] - ->buffer->is_external()) { + const size_t buf_id = + buffer_id_to_info[send_tag_to_buffer_id[tags[i]]]->buffer->id(); + if (!buffer_manager_.is_external(buf_id)) { buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -556,8 +559,9 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5); for (int i = 0; i < len; ++i) { - if (!buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]] - ->buffer->is_external()) { + const size_t buf_id = + buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]]->buffer->id(); + if (!buffer_manager_.is_external(buf_id)) { buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -703,7 +707,9 @@ void Executor::Impl::compile(const std::string &plan, int device_id, kernel_->compile(); } -void Executor::Impl::launch(Stream stream, bool loop_mode) { +void Executor::Impl::launch( + Stream stream, bool loop_mode, + const std::unordered_map &external_tensors) { if ((kernel_ == nullptr) || !kernel_->is_compiled()) { ERR(InvalidUsageError, "Need to compile first before launch."); } @@ -796,7 +802,8 @@ void Executor::Impl::launch(Stream stream, bool loop_mode) { is_launched_ = true; } -void Executor::Impl::run(int iter) { +void Executor::Impl::run( + int iter, const std::unordered_map &external_tensors) { if (iter <= 0) return; if (loop_mode_) { while (atomicLoadRelaxed(flag_->ref()) > 0) { @@ -888,7 +895,7 @@ void *Executor::Impl::tensor_address(const Tensor &tensor) const { void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (tensor.ref()->buffer()->is_external()) { + if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { ERR(InvalidUsageError, "Reading data from a tensor preallocated by PyTorch is not " "supported. Use PyTorch's native methods."); @@ -944,7 +951,7 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (tensor.ref()->buffer()->is_external()) { + if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { ERR(InvalidUsageError, "Writing data to a tensor preallocated by PyTorch is not " "supported. 
Use PyTorch's native methods."); @@ -1019,11 +1026,16 @@ void Executor::compile(const std::string &plan, int device_id, impl_->compile(plan, device_id, name); } -void Executor::launch(Stream stream, bool loop_mode) { - impl_->launch(stream, loop_mode); +void Executor::launch( + Stream stream, bool loop_mode, + const std::unordered_map &external_tensors) { + impl_->launch(stream, loop_mode, external_tensors); } -void Executor::run(int iter) { impl_->run(iter); } +void Executor::run(int iter, + const std::unordered_map &external_tensors) { + impl_->run(iter, external_tensors); +} void Executor::wait(int64_t max_spin_count) { impl_->wait(max_spin_count); } @@ -1071,7 +1083,8 @@ DefaultExecutor::DefaultExecutor( } void DefaultExecutor::launch() { - Executor::launch(reinterpret_cast(impl_->stream_raw_), impl_->loop_mode_); + Executor::launch(reinterpret_cast(impl_->stream_raw_), + impl_->loop_mode_); } } // namespace ark diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 8e6577cd..8e5e5c85 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace ark { @@ -45,10 +46,13 @@ class Executor { const std::string &name = "executor"); /// Launch the executor. This must be called after `compile()`. - void launch(Stream stream = nullptr, bool loop_mode = true); + void launch( + Stream stream = nullptr, bool loop_mode = true, + const std::unordered_map &external_tensors = {}); /// Run the executor for `iter` iterations. - void run(int iter); + void run(int iter, + const std::unordered_map &external_tensors = {}); /// Wait for the previous run to finish. void wait(int64_t max_spin_count = -1); @@ -99,10 +103,11 @@ class Model; class DefaultExecutor : public Executor { public: - DefaultExecutor( - const Model &model, int device_id = -1, Stream stream = nullptr, - const std::vector &config_rules = {}, - const std::string &name = "DefaultExecutor", bool loop_mode = true); + DefaultExecutor(const Model &model, int device_id = -1, + Stream stream = nullptr, + const std::vector &config_rules = {}, + const std::string &name = "DefaultExecutor", + bool loop_mode = true); /// Launch the default executor. void launch(); diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 3c4f22e2..08b8fe63 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -76,6 +76,39 @@ class Model : public ModelGraph { const Dims &padded_shape = {}, int rank = -1, const std::string &name = ""); + /// + /// Returns a tensor object associated with an external buffer. + /// + /// @param shape Shape of the tensor, where the data of interest is. + /// @param dtype Type of the tensor data. + /// @param strides Strides of each dimension of the tensor, which may be + /// different from the shape. @p strides can be considered as the actual + /// shape of the underlying data buffer. + /// @param offsets Offsets of the tensor. The data of interest starts at + /// @p offsets and ends at @p offsets + @p padded_shape. + /// @param padded_shape Padded shape of the tensor. Padding is used to + /// reserve extra space for the tensor when computation requires it. + /// Data on the padded region is allowed to be accessed by computation, + /// but it is not considered as the data of interest. The padded region is + /// initialized to zero only once when the Executor is launched. 
The padded + /// shape should be greater than or equal to the @p shape, and the + /// @p strides should be greater than or equal to the padded shape. If the + /// @p strides are not provided, they are set to the padded shape. If the + /// padded shape is not provided, it is set to the @p shape. + /// @param rank Rank of the tensor. -1 means the rank of this model. + /// @param name Name of the tensor. + /// @param external_data Pointer to an external data buffer. If provided, + /// this buffer is registered with the ModelBufferManager and associated + /// with the tensor. + /// @return Pointer to a tensor object that references the external buffer. + /// + /// + Tensor placeholder(const Dims &shape, const DataType &data_type, + const Dims &strides = {}, const Dims &offsets = {}, + const Dims &padded_shape = {}, int rank = -1, + const std::string &name = "", + void *external_data = nullptr); + Tensor refer(Tensor input, const Dims &shape = {}, const Dims &strides = {}, const Dims &offsets = {}, const Dims &padded_shape = {}, const std::string &name = ""); @@ -254,7 +287,6 @@ class Model : public ModelGraph { Tensor local_all_reduce(Tensor input, int gpu_id, int gpu_num, const std::string &name = ""); - }; } // namespace ark diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 5e463f99..816738c0 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -54,6 +54,8 @@ class Tensor { const DataType &data_type() const; Dims torch_strides() const; + + friend struct std::hash; }; const Tensor NullTensor; @@ -62,4 +64,13 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor); } // namespace ark +namespace std { +template <> +struct hash { + size_t operator()(const ark::Tensor &t) const { + return hash()(t.id()); + } +}; +} // namespace std + #endif // ARK_TENSOR_HPP diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 5db8576e..8f222b75 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -16,6 +16,7 @@ #include "ops/ops_math.hpp" #include "ops/ops_matmul.hpp" #include "ops/ops_noop.hpp" +#include "ops/ops_placeholder.hpp" #include "ops/ops_reduce.hpp" #include "ops/ops_refer.hpp" #include "ops/ops_reshape.hpp" @@ -78,6 +79,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Sqrt); MODEL_OP_TYPE_REGISTER(Sub); MODEL_OP_TYPE_REGISTER(Tensor); + MODEL_OP_TYPE_REGISTER(Placeholder); MODEL_OP_TYPE_REGISTER(Transpose); MODEL_OP_TYPE_REGISTER(SendPacket); MODEL_OP_TYPE_REGISTER(RecvPacket); diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp index 4baaec7f..3e82b05f 100644 --- a/ark/model_buffer_manager.hpp +++ b/ark/model_buffer_manager.hpp @@ -8,7 +8,8 @@ #include namespace ark { -// Manages externally allocated buffers not in the ARK memory space. +// Manages externally allocated buffers (buffers corresponding to Tensors that +// are the output of a `placeholder` operation) outside of ARK's memory space. 
class ModelBufferManager { public: static ModelBufferManager& get_instance() { @@ -16,11 +17,11 @@ class ModelBufferManager { return instance; } - void register_buffer(size_t id, void* data, size_t size) { + void register_buffer(const size_t id, void* const data, const size_t size) { buffers_[id] = std::make_tuple(data, size); } - void* get_buffer(size_t id) { + void* get_buffer(const size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<0>(it->second); @@ -28,7 +29,7 @@ class ModelBufferManager { return nullptr; } - size_t get_buffer_size(size_t id) { + size_t get_buffer_size(const size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<1>(it->second); @@ -36,6 +37,10 @@ class ModelBufferManager { return 0; } + bool is_external(const size_t id) const { + return buffers_.find(id) != buffers_.end(); + } + const std::unordered_map>& get_buffers() const { return buffers_; diff --git a/ark/ops/ops_placeholder.cpp b/ark/ops/ops_placeholder.cpp new file mode 100644 index 00000000..fbac7390 --- /dev/null +++ b/ark/ops/ops_placeholder.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_placeholder.hpp" + +#include "logging.hpp" +#include "model_buffer_manager.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpPlaceholder::ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, + void *external_data) + : ModelOp("Placeholder", true) { + if (!buffer) { + buffer = std::make_shared(); + } + const std::vector &shape_vec = shape.vector(); + DataType dtype = ModelDataType(data_type); + + size_t external_data_size = + std::accumulate(shape_vec.begin(), shape_vec.end(), 1, + std::multiplies()) * + dtype.bytes(); + + ModelBufferManager::get_instance().register_buffer( + buffer->id(), external_data, external_data_size); + + ModelTensorRef tensor = std::make_shared( + data_type, buffer, shape, strides, offsets, padded_shape); + + result_tensors_.emplace_back(tensor); + + verify(); +} + +Tensor Model::placeholder(const Dims &shape, const DataType &data_type, + const Dims &strides, const Dims &offsets, + const Dims &padded_shape, int rank, + const std::string &name, void *external_data) { + if (rank != -1) { + if (rank == this->rank()) { + rank = -1; + } else if (rank < 0 || rank >= this->world_size()) { + ERR(ModelError, "Invalid rank %d", rank); + } + } + return impl_ + ->create_op( + name, std::make_shared(rank), shape, data_type.ref(), + strides, offsets, padded_shape, external_data) + ->result_tensors()[0]; +} +} // namespace ark \ No newline at end of file diff --git a/ark/ops/ops_placeholder.hpp b/ark/ops/ops_placeholder.hpp new file mode 100644 index 00000000..7fb53f98 --- /dev/null +++ b/ark/ops/ops_placeholder.hpp @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#ifndef ARK_OPS_PLACEHOLDER_HPP_ +#define ARK_OPS_PLACEHOLDER_HPP_ + +#include "ark/model.hpp" +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpPlaceholder : public ModelOp { + public: + ModelOpPlaceholder() = default; + ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, + ModelDataType data_type, const Dims &strides, + const Dims &offsets, const Dims &padded_shape, + void *external_data = nullptr); +}; + +} // namespace ark + +#endif // ARK_OPS_PLACEHOLDER_HPP_ \ No newline at end of file diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp new file mode 100644 index 00000000..59f5e2dc --- /dev/null +++ b/ark/ops/ops_placeholder_test.cpp @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/executor.hpp" +#include "gpu/gpu.hpp" +#include "logging.hpp" +#include "model/model_op.hpp" +#include "ops_test_common.hpp" + +ark::unittest::State test_ops_placeholder_value_contiguous() { + ark::Model model; + ark::Dims shape{10, 1}; + + // Allocate GPU memory for the external buffer + float *d_ext_buffer = nullptr; + ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)); + + // Initialize GPU Memory + std::vector h_ext_buffer(shape.nelems()); + std::iota(h_ext_buffer.begin(), h_ext_buffer.end(), 1.0f); + ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), + shape.nelems() * sizeof(float), ark::gpuMemcpyHostToDevice); + + // Associate the initialized device buffer with a tensor produced from a + // placeholder operation + ark::Tensor tns = + model.placeholder(shape, ark::FP32, {}, {}, {}, -1, "", d_ext_buffer); + + ark::Tensor res = model.add(tns, 1.0); + + ark::DefaultExecutor exe(model); + + exe.launch(); + exe.run(1); + exe.stop(); + + // Copy tensor data from GPU to CPU + std::vector h_res(shape.nelems(), 0.0f); + exe.tensor_read(res, h_res); + + for (auto i = 0; i < shape.nelems(); ++i) { + UNITTEST_EQ(h_res[i], i + 2); + } + + cudaFree(d_ext_buffer); + + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_ops_placeholder_value_contiguous); + return ark::unittest::SUCCESS; +} \ No newline at end of file From 192a3d34c24065cfceb2b0fbdd36d57a30f68867 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 15 Aug 2024 09:47:13 +0000 Subject: [PATCH 5/7] minor updates --- ark/api/executor.cpp | 61 +++++++++++++++----------------- ark/include/ark/executor.hpp | 13 +++---- ark/include/ark/tensor.hpp | 2 +- ark/ops/ops_placeholder_test.cpp | 13 ++++--- python/executor_py.cpp | 19 +++++++--- 5 files changed, 60 insertions(+), 48 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 7823c324..47a7a751 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -144,10 +144,7 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl() - : plan_json_(), - device_id_(-1), - buffer_manager_(ModelBufferManager::get_instance()) {}; + Impl() : plan_json_(), device_id_(-1){}; ~Impl(); int device_id() const { return device_id_; } @@ -164,10 +161,12 @@ class Executor::Impl { void compile(const std::string &plan, int device_id, const std::string &name); - void launch(Stream stream, bool loop_mode, - const std::unordered_map &external_tensors); - void run(int iter, - const std::unordered_map &external_tensors); + void launch( + Stream stream, bool loop_mode, + const std::unordered_map &placeholder_data); + void run( + int iter, + const std::unordered_map &placeholder_data); void wait(int64_t 
max_spin_count); float stop(int64_t max_spin_count); void barrier(); @@ -209,7 +208,6 @@ class Executor::Impl { bool is_recording_ = false; float elapsed_msec_ = -1; - ModelBufferManager &buffer_manager_; std::vector external_buffers_; std::vector external_args_; std::map buffer_id_to_name_; @@ -410,6 +408,8 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::map> remote_rank_to_send_tag_to_buffer_id; std::map> remote_rank_to_recv_tag_to_buffer_id; + auto &buffer_manager = ModelBufferManager::get_instance(); + // TODO: improve memory planning size_t offset = 0; for (auto &kv : buffer_id_to_info) { @@ -428,12 +428,12 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { } continue; } - if (buffer_manager_.is_external(buf_id)) { + if (buffer_manager.is_external(buf_id)) { if (buf_info->buffer->device_id() != device_id_) { ERR(InvalidUsageError, "PyTorch tensor and model execution are on different GPUs"); } - external_buffers_.push_back(buffer_manager_.get_buffer(buf_id)); + external_buffers_.push_back(buffer_manager.get_buffer(buf_id)); const auto [it, inserted] = buffer_id_to_name_.try_emplace( buf_id, "extern_buf_" + std::to_string(buf_id)); external_args_.push_back(it->second); @@ -540,7 +540,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (int i = 0; i < len; ++i) { const size_t buf_id = buffer_id_to_info[send_tag_to_buffer_id[tags[i]]]->buffer->id(); - if (!buffer_manager_.is_external(buf_id)) { + if (!buffer_manager.is_external(buf_id)) { buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -561,7 +561,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (int i = 0; i < len; ++i) { const size_t buf_id = buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]]->buffer->id(); - if (!buffer_manager_.is_external(buf_id)) { + if (!buffer_manager.is_external(buf_id)) { buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -709,7 +709,7 @@ void Executor::Impl::compile(const std::string &plan, int device_id, void Executor::Impl::launch( Stream stream, bool loop_mode, - const std::unordered_map &external_tensors) { + const std::unordered_map &placeholder_data) { if ((kernel_ == nullptr) || !kernel_->is_compiled()) { ERR(InvalidUsageError, "Need to compile first before launch."); } @@ -803,7 +803,8 @@ void Executor::Impl::launch( } void Executor::Impl::run( - int iter, const std::unordered_map &external_tensors) { + int iter, + const std::unordered_map &placeholder_data) { if (iter <= 0) return; if (loop_mode_) { while (atomicLoadRelaxed(flag_->ref()) > 0) { @@ -883,6 +884,10 @@ void Executor::Impl::barrier() { void *Executor::Impl::tensor_address(const Tensor &tensor) const { size_t buffer_id = tensor.ref()->buffer()->id(); + auto &buffer_manager = ModelBufferManager::get_instance(); + if (buffer_manager.is_external(buffer_id)) { + return buffer_manager.get_buffer(buffer_id); + } if (buffer_id_to_addr_.find(buffer_id) == buffer_id_to_addr_.end()) { ERR(InvalidUsageError, "Tensor has an unknown buffer ID ", buffer_id, ". This is likely caused by accessing a tensor that is optimized " @@ -895,11 +900,6 @@ void *Executor::Impl::tensor_address(const Tensor &tensor) const { void Executor::Impl::tensor_read(const Tensor &tensor, void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { - ERR(InvalidUsageError, - "Reading data from a tensor preallocated by PyTorch is not " - "supported. 
Use PyTorch's native methods."); - } std::shared_ptr copy_stream; gpuStream copy_stream_raw; if (stream) { @@ -951,11 +951,6 @@ void Executor::Impl::tensor_write(const Tensor &tensor, const void *data, size_t bytes, Stream stream, bool is_d2d) const { GLOG(gpuSetDevice(device_id_)); - if (buffer_manager_.is_external(tensor.ref()->buffer()->id())) { - ERR(InvalidUsageError, - "Writing data to a tensor preallocated by PyTorch is not " - "supported. Use PyTorch's native methods."); - } std::shared_ptr copy_stream; gpuStream copy_stream_raw; if (stream) { @@ -1028,13 +1023,14 @@ void Executor::compile(const std::string &plan, int device_id, void Executor::launch( Stream stream, bool loop_mode, - const std::unordered_map &external_tensors) { - impl_->launch(stream, loop_mode, external_tensors); + const std::unordered_map &placeholder_data) { + impl_->launch(stream, loop_mode, placeholder_data); } -void Executor::run(int iter, - const std::unordered_map &external_tensors) { - impl_->run(iter, external_tensors); +void Executor::run( + int iter, + const std::unordered_map &placeholder_data) { + impl_->run(iter, placeholder_data); } void Executor::wait(int64_t max_spin_count) { impl_->wait(max_spin_count); } @@ -1082,9 +1078,10 @@ DefaultExecutor::DefaultExecutor( impl_->loop_mode_ = loop_mode; } -void DefaultExecutor::launch() { +void DefaultExecutor::launch( + const std::unordered_map &placeholder_data) { Executor::launch(reinterpret_cast(impl_->stream_raw_), - impl_->loop_mode_); + impl_->loop_mode_, placeholder_data); } } // namespace ark diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 8e5e5c85..e71e087d 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -46,13 +46,13 @@ class Executor { const std::string &name = "executor"); /// Launch the executor. This must be called after `compile()`. - void launch( - Stream stream = nullptr, bool loop_mode = true, - const std::unordered_map &external_tensors = {}); + void launch(Stream stream = nullptr, bool loop_mode = true, + const std::unordered_map + &placeholder_data = {}); /// Run the executor for `iter` iterations. - void run(int iter, - const std::unordered_map &external_tensors = {}); + void run(int iter, const std::unordered_map + &placeholder_data = {}); /// Wait for the previous run to finish. void wait(int64_t max_spin_count = -1); @@ -110,7 +110,8 @@ class DefaultExecutor : public Executor { bool loop_mode = true); /// Launch the default executor. 
- void launch(); + void launch(const std::unordered_map + &placeholder_data = {}); }; } // namespace ark diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 816738c0..72ff9ff5 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -66,7 +66,7 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor); namespace std { template <> -struct hash { +struct hash { size_t operator()(const ark::Tensor &t) const { return hash()(t.id()); } diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp index 59f5e2dc..903d8759 100644 --- a/ark/ops/ops_placeholder_test.cpp +++ b/ark/ops/ops_placeholder_test.cpp @@ -13,13 +13,16 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { // Allocate GPU memory for the external buffer float *d_ext_buffer = nullptr; - ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)); + UNITTEST_EQ(ark::gpuMalloc(&d_ext_buffer, shape.nelems() * sizeof(float)), + ark::gpuSuccess); // Initialize GPU Memory std::vector h_ext_buffer(shape.nelems()); std::iota(h_ext_buffer.begin(), h_ext_buffer.end(), 1.0f); - ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), - shape.nelems() * sizeof(float), ark::gpuMemcpyHostToDevice); + UNITTEST_EQ(ark::gpuMemcpy(d_ext_buffer, h_ext_buffer.data(), + shape.nelems() * sizeof(float), + ark::gpuMemcpyHostToDevice), + ark::gpuSuccess); // Associate the initialized device buffer with a tensor produced from a // placeholder operation @@ -34,6 +37,8 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { exe.run(1); exe.stop(); + UNITTEST_EQ(exe.tensor_address(tns), d_ext_buffer); + // Copy tensor data from GPU to CPU std::vector h_res(shape.nelems(), 0.0f); exe.tensor_read(res, h_res); @@ -42,7 +47,7 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { UNITTEST_EQ(h_res[i], i + 2); } - cudaFree(d_ext_buffer); + UNITTEST_EQ(ark::gpuFree(d_ext_buffer), ark::gpuSuccess); return ark::unittest::SUCCESS; } diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 5b4e7959..dd53af51 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -171,11 +171,20 @@ void register_executor(py::module &m) { .def("name", &ark::Executor::name) .def("compile", &ark::Executor::compile, py::arg("device_id"), py::arg("plan"), py::arg("name") = "executor") - .def("launch", [](ark::Executor *self, uintptr_t stream, bool loop_mode) { - self->launch(reinterpret_cast(stream), loop_mode); - }, - py::arg("stream") = 0, py::arg("loop_mode") = true) - .def("run", &ark::Executor::run, py::arg("iter")) + .def( + "launch", + [](ark::Executor *self, uintptr_t stream, bool loop_mode, + const std::unordered_map + &placeholder_data) { + self->launch(reinterpret_cast(stream), loop_mode, + placeholder_data); + }, + py::arg("stream") = 0, py::arg("loop_mode") = true, + py::arg("placeholder_data") = + std::unordered_map()) + .def("run", &ark::Executor::run, py::arg("iter"), + py::arg("placeholder_data") = + std::unordered_map()) .def("wait", &ark::Executor::wait, py::arg("max_spin_count") = -1) .def("stop", &ark::Executor::stop, py::arg("max_spin_count") = -1) .def("barrier", &ark::Executor::barrier) From 977ce9e848cc514eacc0b0314d096e3e8a8259e7 Mon Sep 17 00:00:00 2001 From: noli Date: Thu, 15 Aug 2024 11:04:45 +0000 Subject: [PATCH 6/7] wip adds python binding and delayed buffer binding --- ark/api/executor.cpp | 75 +++++++++++++------------ ark/api/tensor.cpp | 12 ---- ark/include/ark/executor.hpp | 10 ++-- ark/include/ark/tensor.hpp | 6 +- 
ark/model/model_buffer.cpp | 50 +---------------- ark/model/model_buffer.hpp | 14 ----- ark/model_buffer_manager.hpp | 31 +++++++++-- ark/ops/ops_placeholder_test.cpp | 4 +- python/ark/module.py | 12 ++-- python/ark/ops.py | 61 ++++++++++++++++++++- python/ark/runtime.py | 19 ++++++- python/ark/tensor.py | 52 +++++++----------- python/executor_py.cpp | 22 ++++++-- python/model_py.cpp | 79 +++++++++++++++++++++++++++ python/tensor_py.cpp | 71 ------------------------ python/unittest/test_conversion.py | 88 +++++++++++++++++++++++++----- 16 files changed, 351 insertions(+), 255 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 7823c324..d5be65a9 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -163,11 +163,10 @@ class Executor::Impl { const std::string &name() const { return name_; } void compile(const std::string &plan, int device_id, - const std::string &name); - void launch(Stream stream, bool loop_mode, - const std::unordered_map &external_tensors); - void run(int iter, - const std::unordered_map &external_tensors); + const std::string &name, + const std::unordered_map &external_tensors); + void launch(Stream stream, bool loop_mode); + void run(int iter); void wait(int64_t max_spin_count); float stop(int64_t max_spin_count); void barrier(); @@ -330,8 +329,8 @@ std::map Executor::Impl::init_buffer_addrs( if (!buffer_id_to_addr_.empty()) { buffer_id_to_addr = buffer_id_to_addr_; } - for (const auto &kv : buffer_id_to_offset) { - buffer_id_to_addr[kv.first] = buffer->ref(kv.second); + for (const auto &[id, offset] : buffer_id_to_offset) { + buffer_id_to_addr[id] = buffer->ref(offset); } return buffer_id_to_addr; } @@ -428,15 +427,13 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { } continue; } - if (buffer_manager_.is_external(buf_id)) { - if (buf_info->buffer->device_id() != device_id_) { - ERR(InvalidUsageError, - "PyTorch tensor and model execution are on different GPUs"); - } - external_buffers_.push_back(buffer_manager_.get_buffer(buf_id)); - const auto [it, inserted] = buffer_id_to_name_.try_emplace( - buf_id, "extern_buf_" + std::to_string(buf_id)); - external_args_.push_back(it->second); + if (buffer_manager_.is_external(buf_id) && + !buffer_manager_.is_staged(buf_id)) { + external_buffers_.push_back( + buffer_manager_.get_buffer_addr(buf_id)); + const std::string name = "extern_buf_" + std::to_string(buf_id); + external_args_.push_back(name); + buffer_id_to_name_[buf_id] = name; continue; } // if we are adding a plan and come across a buffer from a previous @@ -692,8 +689,9 @@ void Executor::Impl::init_channels(const std::set &remote_ranks) { } } -void Executor::Impl::compile(const std::string &plan, int device_id, - const std::string &name) { +void Executor::Impl::compile( + const std::string &plan, int device_id, const std::string &name, + const std::unordered_map &external_tensors) { if (is_launched_) { ERR(InvalidUsageError, "Need to stop before re-compiling."); return; @@ -704,12 +702,26 @@ void Executor::Impl::compile(const std::string &plan, int device_id, } catch (const ::nlohmann::json::parse_error &e) { ERR(InvalidUsageError, "Failed to parse the plan JSON: ", e.what()); } + for (auto &[tns, addr] : external_tensors) { + const size_t buf_id = tns.ref()->buffer()->id(); + if (buffer_manager_.is_staged(buf_id)) { + buffer_manager_.set_buffer_address(buf_id, addr); + external_buffers_.push_back(addr); + const std::string name = "extern_buf_" + std::to_string(buf_id); + external_args_.push_back(name); + 
buffer_id_to_name_[buf_id] = name; + } else { + ERR(InvalidUsageError, + "Cannot set the buffer address for tensor with buffer:", buf_id, + " the address is already bound. " + "Address setting is only allowed for delayed binding of " + "uninitialized buffers."); + } + } kernel_->compile(); } -void Executor::Impl::launch( - Stream stream, bool loop_mode, - const std::unordered_map &external_tensors) { +void Executor::Impl::launch(Stream stream, bool loop_mode) { if ((kernel_ == nullptr) || !kernel_->is_compiled()) { ERR(InvalidUsageError, "Need to compile first before launch."); } @@ -802,8 +814,7 @@ void Executor::Impl::launch( is_launched_ = true; } -void Executor::Impl::run( - int iter, const std::unordered_map &external_tensors) { +void Executor::Impl::run(int iter) { if (iter <= 0) return; if (loop_mode_) { while (atomicLoadRelaxed(flag_->ref()) > 0) { @@ -1021,22 +1032,18 @@ std::string Executor::plan() const { return impl_->plan(); } const std::string &Executor::name() const { return impl_->name(); } -void Executor::compile(const std::string &plan, int device_id, - const std::string &name) { - impl_->compile(plan, device_id, name); -} - -void Executor::launch( - Stream stream, bool loop_mode, +void Executor::compile( + const std::string &plan, int device_id, const std::string &name, const std::unordered_map &external_tensors) { - impl_->launch(stream, loop_mode, external_tensors); + impl_->compile(plan, device_id, name, external_tensors); } -void Executor::run(int iter, - const std::unordered_map &external_tensors) { - impl_->run(iter, external_tensors); +void Executor::launch(Stream stream, bool loop_mode) { + impl_->launch(stream, loop_mode); } +void Executor::run(int iter) { impl_->run(iter); } + void Executor::wait(int64_t max_spin_count) { impl_->wait(max_spin_count); } float Executor::stop(int64_t max_spin_count) { diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 084ce638..fc44b4a5 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -9,18 +9,6 @@ namespace ark { -Tensor::Tensor(void* data_ptr, int32_t device_id, - const std::vector& shape, const DataType& dtype) { - size_t external_data_size = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies()) * - dtype.bytes(); - auto buffer = - std::make_shared(data_ptr, external_data_size, device_id); - auto tensor = std::make_shared( - dtype.ref(), buffer, Dims(shape), Dims(shape), Dims(), Dims()); - ref_ = tensor; -} - size_t Tensor::id() const { if (ref_) { return ref_->id(); diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 8e5e5c85..6b4235ae 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -43,16 +43,14 @@ class Executor { /// Compile the model. This must be called before `launch()`. void compile(const std::string &plan, int device_id, - const std::string &name = "executor"); + const std::string &name = "executor", + const std::unordered_map &external_tensors = {}); /// Launch the executor. This must be called after `compile()`. - void launch( - Stream stream = nullptr, bool loop_mode = true, - const std::unordered_map &external_tensors = {}); + void launch(Stream stream = nullptr, bool loop_mode = true); /// Run the executor for `iter` iterations. - void run(int iter, - const std::unordered_map &external_tensors = {}); + void run(int iter); /// Wait for the previous run to finish. 
void wait(int64_t max_spin_count = -1); diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 816738c0..05dbb11f 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,8 +31,6 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; - Tensor(void *data_ptr, int32_t device_id, const std::vector &shape, - const DataType &dtype); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } @@ -67,8 +65,8 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor); namespace std { template <> struct hash { - size_t operator()(const ark::Tensor &t) const { - return hash()(t.id()); + size_t operator()(const ark::Tensor &t) const noexcept { + return t.id(); } }; } // namespace std diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index 5ce255ce..9f494b7a 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -24,23 +24,6 @@ ModelBuffer::ModelBuffer(size_t id, int rank, } } -ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id) - : rank_(-1), - external_data_(data), - external_data_size_(size), - device_id_(device_id), - is_external_(true) { - id_ = curr_id++; -} - -ModelBuffer::ModelBuffer(size_t id, void *data, size_t size, int32_t device_id) - : id_(id), - rank_(-1), - external_data_(data), - external_data_size_(size), - device_id_(device_id), - is_external_(true) {} - void ModelBuffer::tag_send(int remote_rank, int tag) { send_tags_.insert(TagInfo{remote_rank, tag}); } @@ -63,14 +46,6 @@ Json ModelBuffer::serialize() const { } j["SendTags"] = send_tags; j["RecvTags"] = recv_tags; - j["IsExternal"] = is_external_; - if (is_external_) { - ModelBufferManager::get_instance().register_buffer(id_, external_data_, - external_data_size_); - j["ExternalDataSize"] = external_data_size_; - j["DeviceId"] = device_id_; - } - // external_data_ptr_ is not included in JSON return j; } @@ -82,30 +57,7 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { } else if (!serialized.contains("SendTags")) { ERR(ModelError, "ModelBuffer deserialization failed: missing SendTags"); } else if (!serialized.contains("RecvTags")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing RecvTags"); - } else if (!serialized.contains("IsExternal")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing IsExternal"); - } - if (serialized["IsExternal"]) { - if (!serialized.contains("ExternalDataSize")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing ExternalDataSize"); - } else if (!serialized.contains("DeviceId")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing DeviceId"); - } - void *data_ptr = - ModelBufferManager::get_instance().get_buffer(serialized["Id"]); - if (!data_ptr) { - ERR(ModelError, - "ModelBuffer deserialization failed: external buffer not found " - "in BufferManager"); - } - return std::make_shared(serialized["Id"], data_ptr, - serialized["ExternalDataSize"], - serialized["DeviceId"]); + ERR(ModelError, "ModelBuffer deserialization failed: missing RecvTags"); } return std::make_shared(serialized["Id"], serialized["Rank"], serialized["SendTags"], diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp index e7f1045b..342b08bb 100644 --- a/ark/model/model_buffer.hpp +++ b/ark/model/model_buffer.hpp @@ -22,10 +22,6 @@ class ModelBuffer { 
ModelBuffer(size_t id, int rank, const std::vector &send_tags, const std::vector &recv_tags); - // externally managed buffer - ModelBuffer(void *data, size_t size, int32_t device_id); - ModelBuffer(size_t id, void *data, size_t size, int32_t device_id); - size_t id() const { return id_; } int rank() const { return rank_; } @@ -48,22 +44,12 @@ class ModelBuffer { static std::shared_ptr deserialize(const Json &serialized); - // external buffer management - size_t external_data_size() const { return external_data_size_; } - void *external_data() const { return external_data_; } - int32_t device_id() const { return device_id_; } - bool is_external() const { return is_external_; } - private: static size_t curr_id; size_t id_; int rank_; std::set send_tags_; std::set recv_tags_; - void *external_data_ = nullptr; - size_t external_data_size_ = 0; - int32_t device_id_; - bool is_external_ = false; }; } // namespace ark diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp index 3e82b05f..ab8d8df9 100644 --- a/ark/model_buffer_manager.hpp +++ b/ark/model_buffer_manager.hpp @@ -7,6 +7,8 @@ #include #include +#include "logging.hpp" + namespace ark { // Manages externally allocated buffers (buffers corresponding to Tensors that // are the output of a `placeholder` operation) outside of ARK's memory space. @@ -17,19 +19,35 @@ class ModelBufferManager { return instance; } - void register_buffer(const size_t id, void* const data, const size_t size) { + void register_buffer(size_t id, void* const data, size_t size) { buffers_[id] = std::make_tuple(data, size); } - void* get_buffer(const size_t id) const { + void* get_buffer_addr(size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<0>(it->second); } + ERR(InvalidUsageError, "Tensor with buffer ID: ", id, + " , is not registered in the ModelBufferManager. Be sure to " + "register the tensor as an external tensor first (pass the tensor " + "into a placeholder operation)."); return nullptr; } - size_t get_buffer_size(const size_t id) const { + void set_buffer_address(size_t id, void* const new_address) { + void* curr_addr = get_buffer_addr(id); + if (curr_addr != nullptr) { + ERR(InvalidUsageError, + "Cannot set the buffer address for tensor with buffer: ", id, + " , the address is already bound. 
" + "Address setting is only allowed for delayed binding of " + "uninitialized buffers."); + } + std::get<0>(buffers_[id]) = new_address; + } + + size_t get_buffer_size(size_t id) const { auto it = buffers_.find(id); if (it != buffers_.end()) { return std::get<1>(it->second); @@ -37,10 +55,15 @@ class ModelBufferManager { return 0; } - bool is_external(const size_t id) const { + bool is_external(size_t id) const { return buffers_.find(id) != buffers_.end(); } + bool is_staged(size_t id) const { + const void* curr_addr = get_buffer_addr(id); + return curr_addr == nullptr; + } + const std::unordered_map>& get_buffers() const { return buffers_; diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp index 59f5e2dc..7610ee61 100644 --- a/ark/ops/ops_placeholder_test.cpp +++ b/ark/ops/ops_placeholder_test.cpp @@ -7,7 +7,7 @@ #include "model/model_op.hpp" #include "ops_test_common.hpp" -ark::unittest::State test_ops_placeholder_value_contiguous() { +ark::unittest::State test_ops_placeholder() { ark::Model model; ark::Dims shape{10, 1}; @@ -42,7 +42,7 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { UNITTEST_EQ(h_res[i], i + 2); } - cudaFree(d_ext_buffer); + ark::gpuFree(d_ext_buffer); return ark::unittest::SUCCESS; } diff --git a/python/ark/module.py b/python/ark/module.py index 49d2ddf0..4809ea43 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -4,13 +4,13 @@ import logging import numpy as np from typing import Any, Dict, Union -from .tensor import Tensor, Parameter +from .tensor import Parameter from .runtime import Runtime -from .init import init from .model import Model try: import torch + from .ops import placeholder _no_torch = False except ImportError: @@ -43,7 +43,7 @@ def __setattr__(self, __name: str, __value: Any) -> None: elif isinstance(__value, Parameter): self.register_parameter(__name, __value) elif not _no_torch and isinstance(__value, torch.nn.Parameter): - __value = Parameter(__value) + __value = Parameter(placeholder(torch_tensor=__value), True) self.register_parameter(__name, __value) super().__setattr__(__name, __value) @@ -151,14 +151,14 @@ def forward(ctx, ark_module, *args, **kwargs): input_requires_grad = 0 for arg in args: if isinstance(arg, torch.Tensor): - input_args.append(Tensor.from_torch(arg)) + input_args.append(placeholder(torch_tensor=arg)) if arg.requires_grad: input_requires_grad += 1 else: input_args.append(arg) for k, v in kwargs.items(): if isinstance(v, torch.Tensor): - input_kwargs[k] = Tensor.from_torch(v) + input_kwargs[k] = placeholder(torch_tensor=v) if v.requires_grad: input_requires_grad += 1 else: @@ -180,7 +180,7 @@ def backward(ctx, *grad_outputs): PyTorch parameters. """ Model.reset() - ark_grad_outputs = [Tensor.from_torch(grad) for grad in grad_outputs] + ark_grad_outputs = [placeholder(torch_tensor=grad) for grad in grad_outputs] grads = ctx.ark_module.backward(*ark_grad_outputs) grad_inputs, grad_weights = ( grads[: ctx.num_inp_grad], diff --git a/python/ark/ops.py b/python/ark/ops.py index f8b75a70..be145eb1 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -1,12 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List, Iterable, Union +from typing import List, Iterable, Union, Optional from .tensor import Dims, Tensor, Parameter, NullTensor from .data_type import DataType, fp32 from .model import Model +try: + import torch + + _no_torch = False +except ImportError: + from . 
import torch_mock as torch
+
+    _no_torch = True
+
 
 
 def _is_list_or_tuple(obj):
     return isinstance(obj, list) or isinstance(obj, tuple)
@@ -48,6 +57,54 @@ def _tensor(
     )
 
 
+def placeholder(
+    shape: Optional[Iterable[int]] = None,
+    dtype: Optional[DataType] = None,
+    torch_tensor: Optional[torch.Tensor] = None,
+    strides: Iterable[int] = [],
+    offsets: Iterable[int] = [],
+    padded_shape: Iterable[int] = [],
+    rank: int = -1,
+    name: str = "",
+) -> Tensor:
+    if torch_tensor is not None:
+        if any(
+            (arg is not None and arg != [])
+            for arg in [shape, dtype, strides, offsets, padded_shape]
+        ):
+            raise ValueError(
+                "shape, dtype, strides, offsets, and padded_shape should not "
+                "be provided as they are inferred from the torch tensor."
+            )
+        dl_tensor = torch.utils.dlpack.to_dlpack(torch_tensor)
+        return Tensor(Model.get_model().placeholder(
+            external_tensor=dl_tensor,
+            rank=rank,
+            name=name,
+        ))
+    if not _is_list_or_tuple(shape):
+        raise ValueError("shape should be a list or tuple of integers")
+    if not _is_list_or_tuple(strides):
+        raise ValueError("strides should be a list or tuple of integers")
+    if not _is_list_or_tuple(offsets):
+        raise ValueError("offsets should be a list or tuple of integers")
+    if not _is_list_or_tuple(padded_shape):
+        raise ValueError("padded_shape should be a list or tuple of integers")
+    # only support tensors with up to 4 dimensions
+    if any(len(arg) > 4 for arg in (shape, strides, offsets, padded_shape)):
+        raise ValueError("Only support tensors with up to 4 dimensions")
+    return Tensor(Model.get_model().placeholder(
+        Dims(shape),
+        dtype.ctype(),
+        Dims(strides),
+        Dims(offsets),
+        Dims(padded_shape),
+        rank,
+        name,
+        None,
+    ))
+
+
 def add(
     input: Union[Tensor, float],
     other: Union[Tensor, float],
@@ -630,6 +688,7 @@ def all_reduce(
 
 __all__ = [
     "tensor",
+    "placeholder",
     "parameter",
     "reshape",
     "identity",
diff --git a/python/ark/runtime.py b/python/ark/runtime.py
index 1490cdeb..2cbed861 100644
--- a/python/ark/runtime.py
+++ b/python/ark/runtime.py
@@ -6,6 +6,15 @@
 from _ark_core import _Executor
 from .planner import Planner, Plan
 
+from typing import Dict
+try:
+    import torch
+
+    _no_torch = False
+except ImportError:
+    from . import torch_mock as torch
+
+    _no_torch = True
 
 
 class _RuntimeState:
@@ -73,6 +82,7 @@ def launch(
         device_id: int = 0,
         stream: int = 0,
         loop_mode: bool = True,
+        tensor_mappings: Dict = {}
     ):
         """
         Create an executor and schedule the ARK model.
The scheduler will generate @@ -87,6 +97,12 @@ def launch( if self.launched(): # Stop the current running model self.stop() + + for ark_tensor in tensor_mappings: + torch_tensor = tensor_mappings[ark_tensor] + if not isinstance(torch_tensor, torch.Tensor): + raise ValueError("Must bind PyTorch tensor") + tensor_mappings[ark_tensor] = torch_tensor.data_ptr() # Recompile if the previous launch was not compiled with the same info # or if this is the first launch @@ -94,8 +110,7 @@ def launch( plan_str != self.executor.plan() or device_id != self.executor.device_id() ): - self.executor.compile(plan_str, device_id) - + self.executor.compile(plan_str, device_id, tensor_mappings) self.executor.launch(stream, loop_mode) self.state = Runtime.State.LaunchedNotRunning diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 9211f7d9..348962c4 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -7,7 +7,6 @@ from _ark_core import _Dims, _Tensor, _NullTensor from .data_type import DataType from .runtime import Runtime -from .model import Model try: import torch @@ -45,6 +44,15 @@ def __init__( self._tensor = _tensor self.initializer: Initializer = initializer self.requires_grad = requires_grad + + def __hash__(self): + return self._tensor.id() + + def __eq__(self, other): + if not isinstance(other, Tensor): + return False + return self._tensor.id() == other._tensor.id() + def shape(self) -> List[int]: """ @@ -132,13 +140,6 @@ def to_dlpack(self): ) return rt.executor.tensor_to_dlpack(self._tensor) - @staticmethod - def from_dlpack(ext_tensor) -> "Tensor": - """ - Copies the tensor from a DLPack tensor to the device. - """ - return Tensor(_Tensor(ext_tensor)) - def to_torch(self) -> torch.Tensor: """ Returns a torch tensor that shares the same memory with the device tensor. @@ -151,22 +152,6 @@ def to_torch(self) -> torch.Tensor: torch_view.__ark_buffer__ = dl_capsule return torch_view - @staticmethod - def from_torch(tensor: torch.Tensor) -> "Tensor": - """ - Returns an ARK tensor that shares the same memory with the torch tensor. - """ - if _no_torch: - raise ImportError("torch is not available") - elif not tensor.is_contiguous(): - raise ValueError("Torch tensor must be contiguous.") - elif tensor.device.type == "cpu": - raise ValueError("Torch tensor must be on a device.") - ark_tensor = Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) - # Share ownership of the memory with the torch tensor - ark_tensor.__torch_buffer__ = tensor - return ark_tensor - def copy( self, data: Union[np.ndarray, torch.Tensor], stream: int = 0 ) -> "Tensor": @@ -216,33 +201,36 @@ def initialize(self) -> "Tensor": return self -class Parameter(Tensor, torch.nn.Parameter): +class Parameter(Tensor): """ A tensor as a parameter. """ def __init__( self, - tensor: Union[_Tensor, "torch.nn.Parameter"], + tensor: _Tensor, + from_torch: bool, ): """ Initializes a new instance of the Parameter class. + Args: + _tensor (_ark_core._Tensor): The underlying _Tensor object. 
+ from_torch: Indicates if the Parameter is tied to a torch.nn.Paramter """ - if not _no_torch and isinstance(tensor, torch.nn.Parameter): - ark_tensor = Tensor.from_torch(tensor) - core_tensor = ark_tensor._tensor + if not _no_torch and from_torch: + _tensor = tensor._tensor self.torch_param = tensor self.staged_tensor = None Tensor.__init__( self, - core_tensor, + _tensor, requires_grad=tensor.requires_grad, ) elif isinstance(tensor, _Tensor): - core_tensor = tensor + _tensor = tensor self.torch_param = None self.staged_tensor = None - Tensor.__init__(self, core_tensor, requires_grad=False) + Tensor.__init__(self, _tensor, requires_grad=False) else: raise TypeError( "tensor must be an ARK tensor or a torch.nn.Parameter" diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 5b4e7959..08fc9488 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -8,6 +8,7 @@ #include #include +#include #include "gpu/gpu_memory.hpp" #include "logging.hpp" @@ -134,7 +135,8 @@ DLTensor SharedTensor::dl_tensor() const { } // namespace ark -static py::capsule tensor_to_dlpack(ark::Executor &self, const ark::Tensor &tensor) { +static py::capsule tensor_to_dlpack(ark::Executor &self, + const ark::Tensor &tensor) { auto shared_tensor = new ark::SharedTensor(self, tensor); DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); dl_managed_tensor->dl_tensor = shared_tensor->dl_tensor(); @@ -146,8 +148,9 @@ static py::capsule tensor_to_dlpack(ark::Executor &self, const ark::Tensor &tens } }; const char *capsule_name = "dltensor"; - PyObject *dl_capsule = PyCapsule_New(static_cast(dl_managed_tensor), - capsule_name, [](PyObject *capsule) { + PyObject *dl_capsule = PyCapsule_New( + static_cast(dl_managed_tensor), capsule_name, + [](PyObject *capsule) { const char *name = PyCapsule_GetName(capsule); auto *dl_managed_tensor = static_cast( PyCapsule_GetPointer(capsule, name)); @@ -169,8 +172,17 @@ void register_executor(py::module &m) { }) .def("plan", &ark::Executor::plan) .def("name", &ark::Executor::name) - .def("compile", &ark::Executor::compile, py::arg("device_id"), - py::arg("plan"), py::arg("name") = "executor") + .def("compile", + [](ark::Executor *self, int device_id, std::string &plan, const std::string &name, + const std::unordered_map &external_tensors) { + std::unordered_map tensor_map; + for (const auto &[tensor, ptr] : external_tensors) { + tensor_map[tensor] = reinterpret_cast(ptr); + } + self->compile(plan, device_id, name, tensor_map); + }, + py::arg("device_id"), py::arg("plan"), py::arg("name") = "executor", + py::arg("external_tensors") = std::unordered_map()) .def("launch", [](ark::Executor *self, uintptr_t stream, bool loop_mode) { self->launch(reinterpret_cast(stream), loop_mode); }, diff --git a/python/model_py.cpp b/python/model_py.cpp index c224a3d5..d1150e48 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
+#include #include #include #include @@ -8,8 +9,65 @@ #include #include +#include "logging.hpp" + namespace py = pybind11; +struct DLTensorMetadata { + void *data_ptr; + int32_t device_id; + DLDeviceType device_type; + int32_t ndim; + DLDataType dtype; + std::vector shape; + std::vector strides; + uint64_t byte_offset; +}; + +static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor *dl_tensor) { + DLTensorMetadata metadata; + metadata.data_ptr = dl_tensor->dl_tensor.data; + metadata.device_id = dl_tensor->dl_tensor.device.device_id; + metadata.device_type = dl_tensor->dl_tensor.device.device_type; + metadata.ndim = dl_tensor->dl_tensor.ndim; + metadata.dtype = dl_tensor->dl_tensor.dtype; + metadata.shape.assign( + dl_tensor->dl_tensor.shape, + dl_tensor->dl_tensor.shape + dl_tensor->dl_tensor.ndim); + if (dl_tensor->dl_tensor.strides != nullptr) { + metadata.strides.assign( + dl_tensor->dl_tensor.strides, + dl_tensor->dl_tensor.strides + dl_tensor->dl_tensor.ndim); + } + metadata.byte_offset = dl_tensor->dl_tensor.byte_offset; + return metadata; +} + +static ark::DataType from_dl_dtype(const DLDataType &dl_dtype) { + if (dl_dtype.lanes != 1) { + ERR(ark::UnsupportedError, "unsupported data type"); + } + ark::DataType ark_dtype; + if (dl_dtype.code == kDLFloat && dl_dtype.bits == 32) { + ark_dtype = ark::FP32; + } else if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { + ark_dtype = ark::FP16; + } else if (dl_dtype.code == kDLBfloat && dl_dtype.bits == 16) { + ark_dtype = ark::BF16; + } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 32) { + ark_dtype = ark::INT32; + } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 32) { + ark_dtype = ark::UINT32; + } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 8) { + ark_dtype = ark::INT8; + } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 8) { + ark_dtype = ark::UINT8; + } else { + ERR(ark::UnsupportedError, "unsupported data type"); + } + return ark_dtype; +} + void register_model(py::module &m) { py::class_(m, "_Model") .def(py::init(), py::arg("rank"), py::arg("world_size")) @@ -112,6 +170,27 @@ void register_model(py::module &m) { py::arg("shape"), py::arg("data_type"), py::arg("strides"), py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), py::arg("name")) + .def("placeholder", + py::overload_cast(&ark::Model::placeholder), + py::arg("shape"), py::arg("data_type"), py::arg("strides"), + py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), + py::arg("name"), py::arg("external_data")) + .def( + "placeholder", + [](ark::Model &self, py::capsule input, int rank, + const std::string &name) { + DLManagedTensor *dl_tensor = + static_cast(input.get_pointer()); + DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); + ark::DataType ark_dtype = from_dl_dtype(metadata.dtype); + ark::Dims shape(metadata.shape); + return self.placeholder(shape, ark_dtype, {}, {}, {}, rank, + name, metadata.data_ptr); + }, + py::arg("external_tensor"), py::arg("rank"), py::arg("name")) .def("transpose", &ark::Model::transpose, py::arg("input"), py::arg("permutation"), py::arg("output"), py::arg("name")) .def("all_reduce", &ark::Model::all_reduce, py::arg("input"), diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index 5abb35c6..5c28563d 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -1,87 +1,16 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
-#include #include #include #include #include -#include "logging.hpp" - namespace py = pybind11; -struct DLTensorMetadata { - void* data_ptr; - int32_t device_id; - DLDeviceType device_type; - int32_t ndim; - DLDataType dtype; - std::vector shape; - std::vector strides; - uint64_t byte_offset; -}; - -static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { - DLTensorMetadata metadata; - metadata.data_ptr = dl_tensor->dl_tensor.data; - metadata.device_id = dl_tensor->dl_tensor.device.device_id; - metadata.device_type = dl_tensor->dl_tensor.device.device_type; - metadata.ndim = dl_tensor->dl_tensor.ndim; - metadata.dtype = dl_tensor->dl_tensor.dtype; - metadata.shape.assign( - dl_tensor->dl_tensor.shape, - dl_tensor->dl_tensor.shape + dl_tensor->dl_tensor.ndim); - if (dl_tensor->dl_tensor.strides != nullptr) { - metadata.strides.assign( - dl_tensor->dl_tensor.strides, - dl_tensor->dl_tensor.strides + dl_tensor->dl_tensor.ndim); - } - metadata.byte_offset = dl_tensor->dl_tensor.byte_offset; - return metadata; -} - -static ark::DataType from_dl_dtype(const DLDataType &dl_dtype) { - if (dl_dtype.lanes != 1) { - ERR(ark::UnsupportedError, "unsupported data type"); - } - ark::DataType ark_dtype; - if (dl_dtype.code == kDLFloat && dl_dtype.bits == 32) { - ark_dtype = ark::FP32; - } else if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { - ark_dtype = ark::FP16; - } else if (dl_dtype.code == kDLBfloat && dl_dtype.bits == 16) { - ark_dtype = ark::BF16; - } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 32) { - ark_dtype = ark::INT32; - } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 32) { - ark_dtype = ark::UINT32; - } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 8) { - ark_dtype = ark::INT8; - } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 8) { - ark_dtype = ark::UINT8; - } else { - ERR(ark::UnsupportedError, "unsupported data type"); - } - return ark_dtype; -} - void register_tensor(py::module& m) { py::class_(m, "_Tensor") - .def(py::init([](py::capsule capsule) { - DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; - if (!dl_tensor) { - ERR(ark::InvalidUsageError, - "Capsule does not contain a DLManagedTensor"); - } - DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); - int32_t device_id = metadata.device_id; - void* data_ptr = metadata.data_ptr; - auto shape = metadata.shape; - - return ark::Tensor(data_ptr, device_id, shape, from_dl_dtype(metadata.dtype)); - })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape) .def("strides", &ark::Tensor::strides) diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 833b8866..83fb77b3 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -37,9 +37,9 @@ def test_values_fixed_dims(num_dims: int, size: int, dtype: ark.DataType): input_tensor.from_numpy(input_tensor_host) other_tensor.from_numpy(other_tensor_host) - input_view = input_tensor.get_torch_view() - other_view = other_tensor.get_torch_view() - output_view = output_tensor.get_torch_view() + input_view = input_tensor.to_torch() + other_view = other_tensor.to_torch() + output_view = output_tensor.to_torch() runtime.run() @@ -50,7 +50,7 @@ def test_values_fixed_dims(num_dims: int, size: int, dtype: ark.DataType): output_tensor_host = output_tensor.to_numpy() runtime.stop() - runtime.delete_all_runtimes() + runtime.reset() assert np.allclose(input_tensor_host, input_view_numpy) assert np.allclose(other_tensor_host, other_view_numpy) @@ -83,9 
+83,9 @@ def test_ark_to_torch_aliasing(dtype: ark.DataType):
     input_tensor.from_numpy(input_tensor_host)
     other_tensor.from_numpy(other_tensor_host)
 
-    input_view = input_tensor.get_torch_view()
-    other_view = other_tensor.get_torch_view()
-    output_view = output_tensor.get_torch_view()
+    input_view = input_tensor.to_torch()
+    other_view = other_tensor.to_torch()
+    output_view = output_tensor.to_torch()
     # make changes to the views
     input_view[1, 1] = 20
     other_view[0, 0] = 30
@@ -105,7 +105,7 @@ def test_ark_to_torch_aliasing(dtype: ark.DataType):
     runtime.stop()
     runtime.reset()
 
-
+@pytest.mark.skip()
 def test_conversion_torch():
     if _no_torch:
         pytest.skip("PyTorch not available")
@@ -149,8 +149,8 @@ def test_bin_op(dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims):
     input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
     other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
     expected_output = torch_op(input_tensor, other_tensor).cpu().numpy()
-    input_ark_view = ark.Tensor.from_torch(input_tensor)
-    other_ark_view = ark.Tensor.from_torch(other_tensor)
+    input_ark_view = ark.placeholder(torch_tensor=input_tensor)
+    other_ark_view = ark.placeholder(torch_tensor=other_tensor)
     output = ark_op(input_ark_view, other_ark_view)
     runtime = ark.Runtime()
     runtime.launch()
@@ -170,7 +170,7 @@ def test_unary_op(dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims):
     ark.init()
     input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
     expected_output = torch_op(input_tensor).cpu().numpy()
-    input_ark_view = ark.Tensor.from_torch(input_tensor)
+    input_ark_view = ark.placeholder(torch_tensor=input_tensor)
     output = ark_op(input_ark_view)
     runtime = ark.Runtime()
     runtime.launch()
@@ -189,8 +189,8 @@ def test_torch_to_ark_aliasing(dtype, tensor_dims):
     input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
     other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
 
-    input_ark_view = ark.Tensor.from_torch(input_tensor)
-    other_ark_view = ark.Tensor.from_torch(other_tensor)
+    input_ark_view = ark.placeholder(torch_tensor=input_tensor)
+    other_ark_view = ark.placeholder(torch_tensor=other_tensor)
     output = ark.add(input_ark_view, other_ark_view)
 
     # Perform in place operations
@@ -205,3 +205,65 @@ def test_torch_to_ark_aliasing(dtype, tensor_dims):
     runtime.stop()
     runtime.reset()
     assert np.allclose(output_host, expected_output)
+
+
+# Staged View Tests
+
+
+@pytest.mark.parametrize(
+    "dtype, ark_op, torch_op, tensor_dims",
+    [(torch.float16, ark.add, torch.add, (2, 3))],
+)
+def test_bin_op_staged(
+    dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims
+):
+    ark.init()
+    input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0")
+    expected_output = torch_op(input_tensor, other_tensor).cpu().numpy()
+    input_ark_view = ark.placeholder(
+        shape=tensor_dims, dtype=ark.DataType.from_torch(dtype)
+    )
+    other_ark_view = ark.placeholder(
+        shape=tensor_dims, dtype=ark.DataType.from_torch(dtype)
+    )
+    output = ark_op(input_ark_view, other_ark_view)
+    runtime = ark.Runtime()
+    tensor_mapping = {
+        input_ark_view: input_tensor,
+        other_ark_view: other_tensor,
+    }
+    runtime.launch(tensor_mappings=tensor_mapping)
+    runtime.run()
+    output_host = output.to_numpy()
+    runtime.stop()
+    runtime.reset()
+    assert np.allclose(output_host, expected_output)
+
+test_bin_op_staged(torch.float16, ark.add, torch.add, (2, 3))
+
+
+@pytest.mark.parametrize(
+    "dtype, ark_op, torch_op,
tensor_dims", + [(torch.float16, ark.exp, torch.exp, (3, 3))], +) +def test_unary_op_staged( + dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims +): + ark.init() + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch_op(input_tensor).cpu().numpy() + input_ark_view = ark.placeholder( + shape=tensor_dims, dtype=ark.DataType.from_torch(dtype) + ) + output = ark_op(input_ark_view) + runtime = ark.Runtime() + tensor_mapping = {input_ark_view: input_tensor} + runtime.launch() + runtime.run(tensor_mappings=tensor_mapping) + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) + +test_unary_op_staged(torch.float16, ark.exp, torch.exp, (3, 3)) From d0a18361ef0db3f7447d83d4110f690e163957ac Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 15 Aug 2024 11:46:21 +0000 Subject: [PATCH 7/7] rename & remove unneeded code & python interface --- ark/api/executor.cpp | 35 ++++++++-------- ark/api/tensor.cpp | 12 ------ ark/codegen.cpp | 2 +- ark/codegen.hpp | 1 - ark/external_buffer_registry.cpp | 32 +++++++++++++++ ark/external_buffer_registry.hpp | 31 ++++++++++++++ ark/gpu/gpu.hpp | 4 ++ ark/include/ark/model.hpp | 8 ++-- ark/include/ark/tensor.hpp | 4 -- ark/model/model_buffer.cpp | 62 +++++----------------------- ark/model/model_buffer.hpp | 22 +++------- ark/model_buffer_manager.hpp | 62 ---------------------------- ark/ops/ops_placeholder.cpp | 28 +++++-------- ark/ops/ops_placeholder.hpp | 2 +- ark/ops/ops_placeholder_test.cpp | 2 +- python/ark/ops.py | 62 +++++++++++----------------- python/ark/tensor.py | 70 ++++++++++++++++++++++++++++++-- python/model_py.cpp | 24 +++++++---- python/tensor_py.cpp | 25 ++++++------ 19 files changed, 237 insertions(+), 251 deletions(-) create mode 100644 ark/external_buffer_registry.cpp create mode 100644 ark/external_buffer_registry.hpp delete mode 100644 ark/model_buffer_manager.hpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 47a7a751..06f31e67 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -15,6 +15,7 @@ #include "ark/planner.hpp" #include "codegen.hpp" #include "env.h" +#include "external_buffer_registry.hpp" #include "file_io.h" #include "gpu/gpu.hpp" #include "gpu/gpu_event.hpp" @@ -25,8 +26,6 @@ #include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" -#include "model_buffer_manager.hpp" -#include "unordered_map" #include "utils/utils_net.hpp" #if defined(ARK_CUDA) @@ -408,7 +407,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { std::map> remote_rank_to_send_tag_to_buffer_id; std::map> remote_rank_to_recv_tag_to_buffer_id; - auto &buffer_manager = ModelBufferManager::get_instance(); + auto &ext_buf_reg = ExternalBufferRegistry::get_instance(); // TODO: improve memory planning size_t offset = 0; @@ -428,12 +427,16 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { } continue; } - if (buffer_manager.is_external(buf_id)) { - if (buf_info->buffer->device_id() != device_id_) { + void *ext_data = ext_buf_reg.get(buf_id); + if (ext_data) { + gpuPointerAttributes attr; + GLOG(gpuPointerGetAttributes(&attr, ext_data)); + if (attr.device != device_id_) { ERR(InvalidUsageError, - "PyTorch tensor and model execution are on different GPUs"); + "External data provided is on a different GPU: ", + attr.device, " vs ", device_id_); } - external_buffers_.push_back(buffer_manager.get_buffer(buf_id)); + external_buffers_.push_back(ext_data); 
const auto [it, inserted] = buffer_id_to_name_.try_emplace( buf_id, "extern_buf_" + std::to_string(buf_id)); external_args_.push_back(it->second); @@ -540,7 +543,8 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (int i = 0; i < len; ++i) { const size_t buf_id = buffer_id_to_info[send_tag_to_buffer_id[tags[i]]]->buffer->id(); - if (!buffer_manager.is_external(buf_id)) { + void *buf_data = ext_buf_reg.get(buf_id); + if (buf_data == nullptr) { buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -561,7 +565,8 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { for (int i = 0; i < len; ++i) { const size_t buf_id = buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]]->buffer->id(); - if (!buffer_manager.is_external(buf_id)) { + void *buf_data = ext_buf_reg.get(buf_id); + if (buf_data == nullptr) { buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; } @@ -884,9 +889,10 @@ void Executor::Impl::barrier() { void *Executor::Impl::tensor_address(const Tensor &tensor) const { size_t buffer_id = tensor.ref()->buffer()->id(); - auto &buffer_manager = ModelBufferManager::get_instance(); - if (buffer_manager.is_external(buffer_id)) { - return buffer_manager.get_buffer(buffer_id); + auto &ext_buf_reg = ExternalBufferRegistry::get_instance(); + void *ext_data = ext_buf_reg.get(buffer_id); + if (ext_data) { + return ext_data; } if (buffer_id_to_addr_.find(buffer_id) == buffer_id_to_addr_.end()) { ERR(InvalidUsageError, "Tensor has an unknown buffer ID ", buffer_id, @@ -1041,10 +1047,7 @@ float Executor::stop(int64_t max_spin_count) { void Executor::barrier() { impl_->barrier(); } -void Executor::destroy() { - ModelBufferManager::get_instance().clear_buffers(); - impl_.reset(nullptr); -} +void Executor::destroy() { impl_.reset(nullptr); } bool Executor::destroyed() const { return impl_.get() == nullptr; } diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 084ce638..fc44b4a5 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -9,18 +9,6 @@ namespace ark { -Tensor::Tensor(void* data_ptr, int32_t device_id, - const std::vector& shape, const DataType& dtype) { - size_t external_data_size = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies()) * - dtype.bytes(); - auto buffer = - std::make_shared(data_ptr, external_data_size, device_id); - auto tensor = std::make_shared( - dtype.ref(), buffer, Dims(shape), Dims(shape), Dims(), Dims()); - ref_ = tensor; -} - size_t Tensor::id() const { if (ref_) { return ref_->id(); diff --git a/ark/codegen.cpp b/ark/codegen.cpp index 2bd36d67..4a1c1ed8 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -7,13 +7,13 @@ #include "ark/data_type.hpp" #include "env.h" +#include "external_buffer_registry.hpp" #include "file_io.h" #include "logging.hpp" #include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_op.hpp" #include "model/model_tensor.hpp" -#include "model_buffer_manager.hpp" #include "range.hpp" #include "utils/utils_math.hpp" diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 8a4eed27..89d89080 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -9,7 +9,6 @@ #include #include "model/model_json.hpp" -#include "model_buffer_manager.hpp" namespace ark { diff --git a/ark/external_buffer_registry.cpp b/ark/external_buffer_registry.cpp new file mode 100644 index 00000000..450dd332 --- /dev/null +++ b/ark/external_buffer_registry.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "external_buffer_registry.hpp" + +#include "logging.hpp" + +namespace ark { + +ExternalBufferRegistry &ExternalBufferRegistry::get_instance() { + static ExternalBufferRegistry instance; + return instance; +} + +void ExternalBufferRegistry::set(const size_t id, void *data) { + if (data == nullptr) { + ERR(InternalError, "data is nullptr."); + } + buffers_[id] = data; +} + +void *ExternalBufferRegistry::get(const size_t id) const { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return it->second; + } + return nullptr; +} + +void ExternalBufferRegistry::clear() { buffers_.clear(); } + +} // namespace ark diff --git a/ark/external_buffer_registry.hpp b/ark/external_buffer_registry.hpp new file mode 100644 index 00000000..ab199baf --- /dev/null +++ b/ark/external_buffer_registry.hpp @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_EXTERNAL_BUFFER_REGISTRY_HPP_ +#define ARK_EXTERNAL_BUFFER_REGISTRY_HPP_ + +#include + +namespace ark { +// Manages externally allocated buffers (buffers corresponding to Tensors that +// are the output of a `placeholder` operation) outside of ARK's memory space. +class ExternalBufferRegistry { + public: + static ExternalBufferRegistry &get_instance(); + + void set(const size_t id, void *data); + + void *get(const size_t id) const; + + void clear(); + + private: + // Maps buffer IDs to pointers and sizes. + std::unordered_map buffers_; + ExternalBufferRegistry() {} + ExternalBufferRegistry(const ExternalBufferRegistry &) = delete; + ExternalBufferRegistry &operator=(const ExternalBufferRegistry &) = delete; +}; +} // namespace ark + +#endif // ARK_EXTERNAL_BUFFER_REGISTRY_HPP_ diff --git a/ark/gpu/gpu.hpp b/ark/gpu/gpu.hpp index 531d6c7e..8ff3b284 100644 --- a/ark/gpu/gpu.hpp +++ b/ark/gpu/gpu.hpp @@ -53,6 +53,8 @@ ARK_GPU_DEFINE_TYPE_ALIAS(gpuModule, CUmodule, hipModule_t); ARK_GPU_DEFINE_TYPE_ALIAS(gpuFunction, CUfunction, hipFunction_t); ARK_GPU_DEFINE_TYPE_ALIAS(gpuFunctionAttribute, CUfunction_attribute, hipFunction_attribute); +ARK_GPU_DEFINE_TYPE_ALIAS(gpuPointerAttributes, cudaPointerAttributes, + hipPointerAttributes); // runtime API ARK_GPU_DEFINE_CONSTANT_ALIAS(gpuSuccess, cudaSuccess, hipSuccess); @@ -126,6 +128,8 @@ ARK_GPU_DEFINE_CONSTANT_ALIAS(gpuPointerAttributeSyncMemops, ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetErrorString, cudaGetErrorString, hipGetErrorString); ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetLastError, cudaGetLastError, hipGetLastError); +ARK_GPU_DEFINE_FUNC_ALIAS(gpuPointerGetAttributes, cudaPointerGetAttributes, + hipPointerGetAttributes); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceGetAttribute, cudaDeviceGetAttribute, hipDeviceGetAttribute); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceSynchronize, cudaDeviceSynchronize, diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 08b8fe63..e1b1f462 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -97,17 +97,15 @@ class Model : public ModelGraph { /// padded shape is not provided, it is set to the @p shape. /// @param rank Rank of the tensor. -1 means the rank of this model. /// @param name Name of the tensor. - /// @param external_data Pointer to an external data buffer. If provided, - /// this buffer is registered with the ModelBufferManager and associated + /// @param data Address of data to pass through placeholder. If provided, + /// this buffer is registered with the ExternalBufferRegistry and associated /// with the tensor. 
/// @return Pointer to a tensor object that references the external buffer. /// - /// Tensor placeholder(const Dims &shape, const DataType &data_type, const Dims &strides = {}, const Dims &offsets = {}, const Dims &padded_shape = {}, int rank = -1, - const std::string &name = "", - void *external_data = nullptr); + void *data = nullptr, const std::string &name = ""); Tensor refer(Tensor input, const Dims &shape = {}, const Dims &strides = {}, const Dims &offsets = {}, const Dims &padded_shape = {}, diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 72ff9ff5..8d658297 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,8 +31,6 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; - Tensor(void *data_ptr, int32_t device_id, const std::vector &shape, - const DataType &dtype); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } @@ -54,8 +52,6 @@ class Tensor { const DataType &data_type() const; Dims torch_strides() const; - - friend struct std::hash; }; const Tensor NullTensor; diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index 5ce255ce..5e240953 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -3,19 +3,22 @@ #include "model_buffer.hpp" +#include "external_buffer_registry.hpp" #include "logging.hpp" -#include "model_buffer_manager.hpp" namespace ark { size_t ModelBuffer::curr_id = 0; -ModelBuffer::ModelBuffer(int rank) : rank_(rank) { id_ = curr_id++; } +ModelBuffer::ModelBuffer(int rank, bool is_external) + : rank_(rank), is_external_(is_external) { + id_ = curr_id++; +} -ModelBuffer::ModelBuffer(size_t id, int rank, +ModelBuffer::ModelBuffer(size_t id, int rank, bool is_external, const std::vector &send_tags, const std::vector &recv_tags) - : id_(id), rank_(rank) { + : id_(id), rank_(rank), is_external_(is_external) { for (const auto &info : send_tags) { send_tags_.insert(info); } @@ -24,23 +27,6 @@ ModelBuffer::ModelBuffer(size_t id, int rank, } } -ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id) - : rank_(-1), - external_data_(data), - external_data_size_(size), - device_id_(device_id), - is_external_(true) { - id_ = curr_id++; -} - -ModelBuffer::ModelBuffer(size_t id, void *data, size_t size, int32_t device_id) - : id_(id), - rank_(-1), - external_data_(data), - external_data_size_(size), - device_id_(device_id), - is_external_(true) {} - void ModelBuffer::tag_send(int remote_rank, int tag) { send_tags_.insert(TagInfo{remote_rank, tag}); } @@ -61,16 +47,9 @@ Json ModelBuffer::serialize() const { for (const auto &info : recv_tags_) { recv_tags.push_back({info.first, info.second}); } + j["IsExternal"] = is_external_; j["SendTags"] = send_tags; j["RecvTags"] = recv_tags; - j["IsExternal"] = is_external_; - if (is_external_) { - ModelBufferManager::get_instance().register_buffer(id_, external_data_, - external_data_size_); - j["ExternalDataSize"] = external_data_size_; - j["DeviceId"] = device_id_; - } - // external_data_ptr_ is not included in JSON return j; } @@ -88,28 +67,9 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { ERR(ModelError, "ModelBuffer deserialization failed: missing IsExternal"); } - if (serialized["IsExternal"]) { - if (!serialized.contains("ExternalDataSize")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing 
ExternalDataSize"); - } else if (!serialized.contains("DeviceId")) { - ERR(ModelError, - "ModelBuffer deserialization failed: missing DeviceId"); - } - void *data_ptr = - ModelBufferManager::get_instance().get_buffer(serialized["Id"]); - if (!data_ptr) { - ERR(ModelError, - "ModelBuffer deserialization failed: external buffer not found " - "in BufferManager"); - } - return std::make_shared(serialized["Id"], data_ptr, - serialized["ExternalDataSize"], - serialized["DeviceId"]); - } - return std::make_shared(serialized["Id"], serialized["Rank"], - serialized["SendTags"], - serialized["RecvTags"]); + return std::make_shared( + serialized["Id"], serialized["Rank"], serialized["IsExternal"], + serialized["SendTags"], serialized["RecvTags"]); } } // namespace ark diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp index e7f1045b..8b66356b 100644 --- a/ark/model/model_buffer.hpp +++ b/ark/model/model_buffer.hpp @@ -17,19 +17,18 @@ class ModelBuffer { // (remote_rank, tag) using TagInfo = std::pair; - ModelBuffer(int rank = -1); + ModelBuffer(int rank = -1, bool is_external = false); - ModelBuffer(size_t id, int rank, const std::vector &send_tags, + ModelBuffer(size_t id, int rank, bool is_external, + const std::vector &send_tags, const std::vector &recv_tags); - // externally managed buffer - ModelBuffer(void *data, size_t size, int32_t device_id); - ModelBuffer(size_t id, void *data, size_t size, int32_t device_id); - size_t id() const { return id_; } int rank() const { return rank_; } + bool is_external() const { return is_external_; } + const std::set &send_tags() const { return send_tags_; } const std::set &recv_tags() const { return recv_tags_; } @@ -48,22 +47,13 @@ class ModelBuffer { static std::shared_ptr deserialize(const Json &serialized); - // external buffer management - size_t external_data_size() const { return external_data_size_; } - void *external_data() const { return external_data_; } - int32_t device_id() const { return device_id_; } - bool is_external() const { return is_external_; } - private: static size_t curr_id; size_t id_; int rank_; + bool is_external_; std::set send_tags_; std::set recv_tags_; - void *external_data_ = nullptr; - size_t external_data_size_ = 0; - int32_t device_id_; - bool is_external_ = false; }; } // namespace ark diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp deleted file mode 100644 index 3e82b05f..00000000 --- a/ark/model_buffer_manager.hpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_ -#define ARK_MODEL_BUFFER_MANAGER_HPP_ - -#include -#include - -namespace ark { -// Manages externally allocated buffers (buffers corresponding to Tensors that -// are the output of a `placeholder` operation) outside of ARK's memory space. 
-class ModelBufferManager { - public: - static ModelBufferManager& get_instance() { - static ModelBufferManager instance; - return instance; - } - - void register_buffer(const size_t id, void* const data, const size_t size) { - buffers_[id] = std::make_tuple(data, size); - } - - void* get_buffer(const size_t id) const { - auto it = buffers_.find(id); - if (it != buffers_.end()) { - return std::get<0>(it->second); - } - return nullptr; - } - - size_t get_buffer_size(const size_t id) const { - auto it = buffers_.find(id); - if (it != buffers_.end()) { - return std::get<1>(it->second); - } - return 0; - } - - bool is_external(const size_t id) const { - return buffers_.find(id) != buffers_.end(); - } - - const std::unordered_map>& get_buffers() - const { - return buffers_; - } - - void clear_buffers() { buffers_.clear(); } - - bool is_empty() const { return buffers_.empty(); } - - private: - // Maps buffer IDs to pointers and sizes. - std::unordered_map> buffers_; - ModelBufferManager() {} - ModelBufferManager(const ModelBufferManager&) = delete; - ModelBufferManager& operator=(const ModelBufferManager&) = delete; -}; -} // namespace ark - -#endif // ARK_MODEL_BUFFER_MANAGER_HPP_ diff --git a/ark/ops/ops_placeholder.cpp b/ark/ops/ops_placeholder.cpp index fbac7390..73c1c1b2 100644 --- a/ark/ops/ops_placeholder.cpp +++ b/ark/ops/ops_placeholder.cpp @@ -3,8 +3,8 @@ #include "ops_placeholder.hpp" +#include "external_buffer_registry.hpp" #include "logging.hpp" -#include "model_buffer_manager.hpp" #include "ops_common.hpp" namespace ark { @@ -12,22 +12,13 @@ namespace ark { ModelOpPlaceholder::ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, ModelDataType data_type, const Dims &strides, const Dims &offsets, - const Dims &padded_shape, - void *external_data) + const Dims &padded_shape, void *data) : ModelOp("Placeholder", true) { if (!buffer) { - buffer = std::make_shared(); + buffer = std::make_shared(-1, true); } - const std::vector &shape_vec = shape.vector(); - DataType dtype = ModelDataType(data_type); - size_t external_data_size = - std::accumulate(shape_vec.begin(), shape_vec.end(), 1, - std::multiplies()) * - dtype.bytes(); - - ModelBufferManager::get_instance().register_buffer( - buffer->id(), external_data, external_data_size); + ExternalBufferRegistry::get_instance().set(buffer->id(), data); ModelTensorRef tensor = std::make_shared( data_type, buffer, shape, strides, offsets, padded_shape); @@ -39,8 +30,8 @@ ModelOpPlaceholder::ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, Tensor Model::placeholder(const Dims &shape, const DataType &data_type, const Dims &strides, const Dims &offsets, - const Dims &padded_shape, int rank, - const std::string &name, void *external_data) { + const Dims &padded_shape, int rank, void *data, + const std::string &name) { if (rank != -1) { if (rank == this->rank()) { rank = -1; @@ -50,8 +41,9 @@ Tensor Model::placeholder(const Dims &shape, const DataType &data_type, } return impl_ ->create_op( - name, std::make_shared(rank), shape, data_type.ref(), - strides, offsets, padded_shape, external_data) + name, std::make_shared(rank, true), shape, + data_type.ref(), strides, offsets, padded_shape, data) ->result_tensors()[0]; } -} // namespace ark \ No newline at end of file + +} // namespace ark diff --git a/ark/ops/ops_placeholder.hpp b/ark/ops/ops_placeholder.hpp index 7fb53f98..91dd874a 100644 --- a/ark/ops/ops_placeholder.hpp +++ b/ark/ops/ops_placeholder.hpp @@ -15,7 +15,7 @@ class ModelOpPlaceholder : public ModelOp { 
ModelOpPlaceholder(ModelBufferRef buffer, const Dims &shape, ModelDataType data_type, const Dims &strides, const Dims &offsets, const Dims &padded_shape, - void *external_data = nullptr); + void *data = nullptr); }; } // namespace ark diff --git a/ark/ops/ops_placeholder_test.cpp b/ark/ops/ops_placeholder_test.cpp index 903d8759..22387232 100644 --- a/ark/ops/ops_placeholder_test.cpp +++ b/ark/ops/ops_placeholder_test.cpp @@ -27,7 +27,7 @@ ark::unittest::State test_ops_placeholder_value_contiguous() { // Associate the initialized device buffer with a tensor produced from a // placeholder operation ark::Tensor tns = - model.placeholder(shape, ark::FP32, {}, {}, {}, -1, "", d_ext_buffer); + model.placeholder(shape, ark::FP32, {}, {}, {}, -1, d_ext_buffer); ark::Tensor res = model.add(tns, 1.0); diff --git a/python/ark/ops.py b/python/ark/ops.py index f8b75a70..1e03cae9 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -3,7 +3,7 @@ from typing import List, Iterable, Union -from .tensor import Dims, Tensor, Parameter, NullTensor +from .tensor import Dims, Tensor, Parameter, NullTensor, _cpp_tensor from .data_type import DataType, fp32 from .model import Model @@ -12,42 +12,6 @@ def _is_list_or_tuple(obj): return isinstance(obj, list) or isinstance(obj, tuple) -def _tensor( - shape: Iterable[int], - dtype: DataType = fp32, - strides: Iterable[int] = [], - offsets: Iterable[int] = [], - padded_shape: Iterable[int] = [], - rank: int = -1, - name: str = "", -) -> Tensor: - if not _is_list_or_tuple(shape): - raise ValueError("shape should be a list or tuple of integers") - if not _is_list_or_tuple(strides): - raise ValueError("strides should be a list or tuple of integers") - if not _is_list_or_tuple(offsets): - raise ValueError("offsets should be a list or tuple of integers") - if not _is_list_or_tuple(padded_shape): - raise ValueError("padded_shape should be a list or tuple of integers") - # only support tensors with up to 4 dimensions - if ( - len(shape) > 4 - or len(strides) > 4 - or len(offsets) > 4 - or len(padded_shape) > 4 - ): - raise ValueError("Only support tensors with up to 4 dimensions") - return Model.get_model().tensor( - Dims(shape), - dtype.ctype(), - Dims(strides), - Dims(offsets), - Dims(padded_shape), - rank, - name, - ) - - def add( input: Union[Tensor, float], other: Union[Tensor, float], @@ -258,6 +222,24 @@ def noop(input: Tensor, name: str = "noop"): Model.get_model().noop(input._tensor, name) +def placeholder( + shape: Iterable[int], + dtype: DataType = fp32, + strides: Iterable[int] = [], + offsets: Iterable[int] = [], + padded_shape: Iterable[int] = [], + rank: int = -1, + data: int = 0, + name: str = "placeholder", +) -> Tensor: + """ """ + return Tensor( + _cpp_tensor( + shape, dtype, strides, offsets, padded_shape, rank, data, name + ) + ) + + def reduce_max( input: Tensor, axis: int, @@ -488,7 +470,9 @@ def tensor( tensor = ark.tensor([1, 2], dtype=ark.fp16) """ return Tensor( - _tensor(shape, dtype, strides, offsets, padded_shape, rank, name) + _cpp_tensor( + shape, dtype, strides, offsets, padded_shape, rank, None, name + ) ) @@ -554,7 +538,7 @@ def parameter( Construct a parameter with given shape and data type. 
""" return Parameter( - _tensor(shape, dtype, strides, offsets, padded_shape, name) + _cpp_tensor(shape, dtype, strides, offsets, padded_shape, None, name) ) diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 45a54d16..dec64682 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. import numpy as np -from typing import Callable, List, Union, Type +from typing import Callable, Iterable, List, Union, Type from ._ark_core import _Dims, _Tensor, _NullTensor -from .data_type import DataType +from .data_type import DataType, fp32 from .runtime import Runtime from .model import Model @@ -137,7 +137,8 @@ def from_dlpack(ext_tensor) -> "Tensor": """ Copies the tensor from a DLPack tensor to the device. """ - return Tensor(_Tensor(ext_tensor)) + # return Tensor(_Tensor(ext_tensor)) + raise NotImplementedError("from_dlpack is not implemented yet") def to_torch(self) -> torch.Tensor: """ @@ -162,7 +163,14 @@ def from_torch(tensor: torch.Tensor) -> "Tensor": raise ValueError("Torch tensor must be contiguous.") elif tensor.device.type == "cpu": raise ValueError("Torch tensor must be on a device.") - ark_tensor = Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) + # TODO: support strides and offsets + ark_tensor = Tensor( + _cpp_tensor( + shape=list(tensor.shape), + dtype=DataType.from_torch(tensor.dtype), + data=tensor.data_ptr(), + ) + ) # Share ownership of the memory with the torch tensor ark_tensor.__torch_buffer__ = tensor return ark_tensor @@ -263,3 +271,57 @@ def update_gradient(self, ark_tensor: Tensor): if ark_tensor is None or not isinstance(ark_tensor, Tensor): raise ValueError("cannot use non-ARK tensor to update ARK gradient") self.staged_tensor = ark_tensor + + +def _is_list_or_tuple(obj): + return isinstance(obj, list) or isinstance(obj, tuple) + + +def _cpp_tensor( + shape: Iterable[int], + dtype: DataType = fp32, + strides: Iterable[int] = [], + offsets: Iterable[int] = [], + padded_shape: Iterable[int] = [], + rank: int = -1, + data: int = None, + name: str = "", +) -> Tensor: + if not _is_list_or_tuple(shape): + raise ValueError("shape should be a list or tuple of integers") + if not _is_list_or_tuple(strides): + raise ValueError("strides should be a list or tuple of integers") + if not _is_list_or_tuple(offsets): + raise ValueError("offsets should be a list or tuple of integers") + if not _is_list_or_tuple(padded_shape): + raise ValueError("padded_shape should be a list or tuple of integers") + # only support tensors with up to 4 dimensions + if ( + len(shape) > 4 + or len(strides) > 4 + or len(offsets) > 4 + or len(padded_shape) > 4 + ): + raise ValueError("Only support tensors with up to 4 dimensions") + if data is not None: + cpp_tensor = Model.get_model().placeholder( + Dims(shape), + dtype.ctype(), + Dims(strides), + Dims(offsets), + Dims(padded_shape), + rank, + data, + name, + ) + else: + cpp_tensor = Model.get_model().tensor( + Dims(shape), + dtype.ctype(), + Dims(strides), + Dims(offsets), + Dims(padded_shape), + rank, + name, + ) + return cpp_tensor diff --git a/python/model_py.cpp b/python/model_py.cpp index c224a3d5..76740ff1 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -71,6 +71,19 @@ void register_model(py::module &m) { py::arg("input"), py::arg("other"), py::arg("output"), py::arg("name")) .def("noop", &ark::Model::noop, py::arg("input"), py::arg("name")) + .def( + "placeholder", + [](ark::Model &model, const ark::Dims &shape, + const ark::DataType &data_type, const ark::Dims 
&strides, + const ark::Dims &offsets, const ark::Dims &padded_shape, + int rank, uintptr_t data, const std::string &name) { + return model.placeholder(shape, data_type, strides, offsets, + padded_shape, rank, + reinterpret_cast(data), name); + }, + py::arg("shape"), py::arg("data_type"), py::arg("strides"), + py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), + py::arg("data"), py::arg("name")) .def("reduce_max", &ark::Model::reduce_max, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), py::arg("name")) @@ -104,14 +117,9 @@ void register_model(py::module &m) { const std::string &>(&ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), py::arg("name")) - .def("tensor", - py::overload_cast( - &ark::Model::tensor), - py::arg("shape"), py::arg("data_type"), py::arg("strides"), - py::arg("offsets"), py::arg("padded_shape"), py::arg("rank"), - py::arg("name")) + .def("tensor", &ark::Model::tensor, py::arg("shape"), + py::arg("data_type"), py::arg("strides"), py::arg("offsets"), + py::arg("padded_shape"), py::arg("rank"), py::arg("name")) .def("transpose", &ark::Model::transpose, py::arg("input"), py::arg("permutation"), py::arg("output"), py::arg("name")) .def("all_reduce", &ark::Model::all_reduce, py::arg("input"), diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index 5abb35c6..74ca7f1a 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -69,19 +69,20 @@ static ark::DataType from_dl_dtype(const DLDataType &dl_dtype) { void register_tensor(py::module& m) { py::class_(m, "_Tensor") - .def(py::init([](py::capsule capsule) { - DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; - if (!dl_tensor) { - ERR(ark::InvalidUsageError, - "Capsule does not contain a DLManagedTensor"); - } - DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); - int32_t device_id = metadata.device_id; - void* data_ptr = metadata.data_ptr; - auto shape = metadata.shape; + // .def(py::init([](py::capsule capsule) { + // DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; + // if (!dl_tensor) { + // ERR(ark::InvalidUsageError, + // "Capsule does not contain a DLManagedTensor"); + // } + // DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); + // int32_t device_id = metadata.device_id; + // void* data_ptr = metadata.data_ptr; + // auto shape = metadata.shape; - return ark::Tensor(data_ptr, device_id, shape, from_dl_dtype(metadata.dtype)); - })) + // return ark::Tensor(data_ptr, device_id, shape, + // from_dl_dtype(metadata.dtype)); + // })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape) .def("strides", &ark::Tensor::strides)
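
For reference, the end-to-end usage that this series enables looks roughly as follows. This is a minimal sketch against the Python API as it stands after PATCH 7/7 (`ark.placeholder` taking a raw device pointer through the `data` argument, per python/ark/ops.py above); the tensor shape, dtype, and variable names here are illustrative only.

import torch
import ark

ark.init()

# A GPU-resident torch tensor whose memory ARK will reuse in place.
x_torch = torch.randn(2, 3, dtype=torch.float16, device="cuda:0")

# Wrap the existing allocation as an ARK tensor through the placeholder op;
# the pointer is registered with the ExternalBufferRegistry instead of being
# allocated out of ARK's own memory space.
x = ark.placeholder(
    shape=list(x_torch.shape),
    dtype=ark.DataType.from_torch(x_torch.dtype),
    data=x_torch.data_ptr(),
)
y = ark.exp(x)

rt = ark.Runtime()
rt.launch()
rt.run()
result = y.to_numpy()
rt.stop()
rt.reset()

`Tensor.from_torch` (restored in python/ark/tensor.py above) is a thin convenience wrapper over the same path: it reads the shape, dtype, and `data_ptr()` off the torch tensor and forwards them to the placeholder op.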