Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions torch_npu/csrc/core/npu/NPUStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ static std::once_flag device_priority_flags[C10_COMPILE_TIME_MAX_NPUS][kMaxStrea
// Init flags for the SyncLaunch stream pools (one flag per device)
static std::once_flag device_sync_launch_flags[C10_COMPILE_TIME_MAX_NPUS];
static std::array<
std::array<std::atomic<uint32_t>, C10_COMPILE_TIME_MAX_NPUS>,
kMaxStreamPriorities>
std::array<std::atomic<uint32_t>, kMaxStreamPriorities>,
C10_COMPILE_TIME_MAX_NPUS>
npu_counters;
static std::atomic<uint32_t> sync_stream_counters[C10_COMPILE_TIME_MAX_NPUS];
// npu_streams holds the stream pools: each device has one pool per priority,
// and 8 streams are created in each pool.
static std::array<
std::array<
std::array<LeakyStreamInternals, kStreamsPerPool>,
C10_COMPILE_TIME_MAX_NPUS>,
kMaxStreamPriorities>
kMaxStreamPriorities>,
C10_COMPILE_TIME_MAX_NPUS>
npu_streams;
static thread_local std::unique_ptr<LeakyStreamInternals* []> current_streams = nullptr;

Expand Down Expand Up @@ -177,9 +177,9 @@ static c10::StreamId NPUStream_getStreamId(const LeakyStreamInternals* ptr)
return makeStreamId(StreamIdType::DEFAULT, 0);
}
for (const auto p : c10::irange(kMaxStreamPriorities)) {
if (pointer_within<LeakyStreamInternals>(ptr, npu_streams[p][device_index])) {
if (pointer_within<LeakyStreamInternals>(ptr, npu_streams[device_index][p])) {
return makeStreamId(StreamIdType(static_cast<uint8_t>(StreamIdType::NORMAL) + p),
ptr - npu_streams[p][device_index].data());
ptr - npu_streams[device_index][p].data());
}
}
if (pointer_within<LeakyStreamInternals>(ptr, sync_launch_streams[device_index])) {
Expand Down Expand Up @@ -218,7 +218,7 @@ static void initGlobalStreamState()
// Initializes default streams
default_streams[device_id].device_index = device_id;
for (const auto p : c10::irange(kMaxStreamPriorities)) {
npu_counters[p][device_id] = 0;
npu_counters[device_id][p] = 0;
}
auto& default_streamsi = default_streams[device_id];
NPU_CHECK_ERROR(
Expand All @@ -240,7 +240,7 @@ static void initDeviceStreamState(c10::DeviceIndex device_index, int p)
NPUGuard device_guard{device_index};
static int StreamsPerPool = GetStreamsPerPool();
for (auto i = decltype(StreamsPerPool){0}; i < StreamsPerPool; ++i) {
auto& npu_streami = npu_streams[p][device_index][i];
auto& npu_streami = npu_streams[device_index][p][i];

npu_streami.device_index = device_index;

Expand Down Expand Up @@ -315,7 +315,7 @@ LeakyStreamInternals* NPUStream_internals(NPUStream s)
return &default_streams[device_index];
case StreamIdType::NORMAL:
case StreamIdType::HIGH:
return &npu_streams[static_cast<uint8_t>(st) - static_cast<uint8_t>(StreamIdType::NORMAL)][device_index][si];
return &npu_streams[device_index][static_cast<uint8_t>(st) - static_cast<uint8_t>(StreamIdType::NORMAL)][si];
case StreamIdType::SECONDARY:
return &secondary_streams[device_index];
case StreamIdType::SYNCLAUNCH:
Expand Down Expand Up @@ -387,8 +387,8 @@ NPUStream getStreamFromPool(const int priority, c10::DeviceIndex device_index)
// Initializes the stream pools (once)
std::call_once(
device_priority_flags[device_index][pri_idx], initDeviceStreamState, device_index, pri_idx);
const auto idx = get_idx(npu_counters[pri_idx][device_index]);
return NPUStream_fromInternals(&npu_streams[pri_idx][device_index][idx]);
const auto idx = get_idx(npu_counters[device_index][pri_idx]);
return NPUStream_fromInternals(&npu_streams[device_index][pri_idx][idx]);
}

NPUStream getNPUStreamFromPool(c10::DeviceIndex device_index)
Expand Down Expand Up @@ -641,9 +641,9 @@ void recovery_all_npu_streams(c10::DeviceIndex device_index)
NPU_CHECK_ERROR(
acl::AclrtCreateStreamWithConfig(&secondary_streamsi.stream, 0, (ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC)));
static int StreamsPerPool = GetStreamsPerPool();
for (auto i = decltype(StreamsPerPool){0}; i < StreamsPerPool; ++i) {
for (const auto p : c10::irange(kMaxStreamPriorities)) {
auto& npu_streami = npu_streams[p][device_index][i];
for (const auto p : c10::irange(kMaxStreamPriorities)) {
for (auto i = decltype(StreamsPerPool){0}; i < StreamsPerPool; ++i) {
auto& npu_streami = npu_streams[device_index][p][i];
if (npu_streami.stream == nullptr) {
continue;
}
Expand Down