diff --git a/plugins/tenstorrent/tt_command.cpp b/plugins/tenstorrent/tt_command.cpp index 6ded6a4..07c2767 100644 --- a/plugins/tenstorrent/tt_command.cpp +++ b/plugins/tenstorrent/tt_command.cpp @@ -4,6 +4,8 @@ #include +#include + /************************************************************************ * @def _cpu_barrier * @brief Barrier for CPU fibers @@ -57,21 +59,28 @@ nxs_status TTCommand::runCommand(nxs_int stream, ttmd::MeshWorkload &workload, // collect uniform args TTLibrary::RunTimeArgs rt_args; size_t numArgs = getArgsCount(); + assert(numArgs <= NXS_KERNEL_MAX_ARGS - 5); for (size_t i = 0; i < numArgs; i++) { uint32_t arg_val = *static_cast(args[i].value); NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Runtime arg: ", i, "=", arg_val); rt_args[i] = arg_val; } + // compute persistent grid size (stride is rounded up so the tail of the grid is not dropped, and is 1 when there are no cores) + int total_grid_size = grid_size.x * grid_size.y * grid_size.z; + int persistent_grid_stride = cores.empty() ? 1 : std::max(1, (total_grid_size + (int)cores.size() - 1) / (int)cores.size()); + NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Total grid size: ", total_grid_size, ", cores: ", cores.size(), ", persistent grid stride: ", persistent_grid_stride); + // set params - int grid_idx = 0; + int persistent_grid_idx = 0; for (const auto& core : cores) { - rt_args[numArgs] = grid_idx % grid_size.x; - rt_args[numArgs + 1] = (grid_idx % (grid_size.x * grid_size.y)) / grid_size.x; - rt_args[numArgs + 2] = grid_idx / (grid_size.x * grid_size.y); - NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Launch params: grid_idx=", grid_idx, ", x=", rt_args[numArgs], ", y=", rt_args[numArgs+1]); + rt_args[numArgs] = std::min(persistent_grid_idx * persistent_grid_stride, total_grid_size); + rt_args[numArgs+1] = persistent_grid_idx * persistent_grid_stride + persistent_grid_stride; + if (rt_args[numArgs+1] > total_grid_size) + rt_args[numArgs+1] = total_grid_size; + NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Launch params: grid_idx=", persistent_grid_idx, ", start=", rt_args[numArgs], ", end=", rt_args[numArgs+1]); library->setupCoreRuntime(program, core, rt_args); - grid_idx++; + 
persistent_grid_idx++; } // local or passed in? diff --git a/plugins/tenstorrent/tt_device.h b/plugins/tenstorrent/tt_device.h index cb8c9c6..9ec5ce3 100644 --- a/plugins/tenstorrent/tt_device.h +++ b/plugins/tenstorrent/tt_device.h @@ -8,9 +8,12 @@ class TTDevice { std::shared_ptr device; public: TTDevice(int device_id = 0) : device_id(device_id) {} - virtual ~TTDevice() = default; + virtual ~TTDevice() { release(); } - nxs_status release() { device = nullptr; return NXS_Success; } + nxs_status release() { + device = nullptr; + return NXS_Success; + } std::shared_ptr get() { initDevice(); return device; } diff --git a/plugins/tenstorrent/tt_runtime.h b/plugins/tenstorrent/tt_runtime.h index 5085ba2..24b88a1 100644 --- a/plugins/tenstorrent/tt_runtime.h +++ b/plugins/tenstorrent/tt_runtime.h @@ -21,7 +21,8 @@ class TTRuntime : public rt::Runtime { public: TTRuntime() : rt::Runtime() { - TT_NOBJ_CHECK(numDevs, ttm::GetNumAvailableDevices); + //TT_NOBJ_CHECK(numDevs, ttm::GetNumAvailableDevices); + size_t numDevs = 1; // TODO: remove this once tt-metal supports multiple devices NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Create TTDevice count: ", numDevs); for (size_t i = 0; i < numDevs; ++i) { devices.emplace_back(i); diff --git a/plugins/tenstorrent/tt_schedule.cpp b/plugins/tenstorrent/tt_schedule.cpp index b092d1d..63662a5 100644 --- a/plugins/tenstorrent/tt_schedule.cpp +++ b/plugins/tenstorrent/tt_schedule.cpp @@ -13,13 +13,19 @@ bool placeCommand(nxs_uint cmdSize, ttm::CoreRange &cmdRange, ttm::CoreRange &de auto numRows = (cmdSize / rowLen) + !!(cmdSize % rowLen); auto tail = cmdSize % rowLen; + if (devRange.end_coord.y <= devRange.start_coord.y) { + // No rows available. NOTE(review): rows are counted below as end - start + 1 (inclusive), so '<=' also rejects a range that still has one row - confirm '<' was not intended + return false; + } + // TODO: use this instead if (numRows == 1) { // find gap and return + } else if (numRows > devRange.end_coord.y - devRange.start_coord.y + 1) { + numRows = devRange.end_coord.y - devRange.start_coord.y + 1; } - if (numRows > devRange.end_coord.y - devRange.start_coord.y + 1) { - 
return false; - } + + // Compute range and make persistent if necessary cmdRange.start_coord.x = devRange.start_coord.x; cmdRange.start_coord.y = devRange.start_coord.y; cmdRange.end_coord.x = numRows > 1 ? devRange.end_coord.x : devRange.start_coord.x + tail - 1; @@ -55,8 +61,7 @@ nxs_status TTSchedule::run(nxs_int stream, nxs_uint run_settings) { TT_NOBJ_CHECK(devGrid, device->get()->compute_with_storage_grid_size); NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Device grid: ", devGrid.x, ",", devGrid.y); -// ttm::CoreRangeSet coreRangeSet( -// ttm::num_cores_to_corerangeset(devGrid.x * devGrid.y, devGrid, true)); +// TODO: use split_work_to_cores utility function to distribute commands across cores // NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Core range set: ", coreRangeSet.bounding_box().start_coord.x, ",", coreRangeSet.bounding_box().start_coord.y, " - ", coreRangeSet.bounding_box().end_coord.x, ",", coreRangeSet.bounding_box().end_coord.y); ttm::CoreRange devRange = {{0,0}, {devGrid.x - 1, devGrid.y - 1}};