Skip to content

Commit 3922ee8

Browse files
[SYCL] Move kernel id into device kernel info struct (#20928)
1 parent 36ce6ea commit 3922ee8

File tree

6 files changed

+85
-88
lines changed

6 files changed

+85
-88
lines changed

sycl/source/detail/device_kernel_info.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ namespace sycl {
1212
inline namespace _V1 {
1313
namespace detail {
1414

15-
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info)
16-
: CompileTimeKernelInfoTy(Info) {}
15+
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
16+
std::optional<sycl::kernel_id> KernelID)
17+
: CompileTimeKernelInfoTy{Info}, MKernelID{std::move(KernelID)} {}
1718

1819
template <typename OtherTy>
1920
inline constexpr bool operator==(const CompileTimeKernelInfoTy &LHS,

sycl/source/detail/device_kernel_info.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <sycl/detail/compile_time_kernel_info.hpp>
1414
#include <sycl/detail/spinlock.hpp>
1515
#include <sycl/detail/ur.hpp>
16+
#include <sycl/kernel_bundle.hpp>
1617

1718
#include <mutex>
1819
#include <optional>
@@ -84,12 +85,10 @@ struct FastKernelSubcacheT {
8485
// information that is uniform between different submissions of the same
8586
// kernel). Pointers to instances of this class are stored in header function
8687
// templates as a static variable to avoid repeated runtime lookup overhead.
87-
// TODO Currently this class duplicates information fetched from the program
88-
// manager. Instead, we should merge all of this information
89-
// into this structure and get rid of the other KernelName -> * maps.
9088
class DeviceKernelInfo : public CompileTimeKernelInfoTy {
9189
public:
92-
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info);
90+
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
91+
std::optional<sycl::kernel_id> KernelID = std::nullopt);
9392

9493
void init(std::string_view KernelName);
9594
void setCompileTimeInfoIfNeeded(const CompileTimeKernelInfoTy &Info);
@@ -100,6 +99,14 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
10099
return MImplicitLocalArgPos;
101100
}
102101

102+
const sycl::kernel_id &getKernelID() const {
103+
// Expected to be called only for DeviceKernelInfo instances created by
104+
// program manager (as opposed to allocated by sycl::kernel with
105+
// origins other than SYCL offline compilation).
106+
assert(MKernelID);
107+
return *MKernelID;
108+
}
109+
103110
// Implicit local argument position is used only for some backends, so this
104111
// function allows setting it as more images are added.
105112
void setImplicitLocalArgPos(int Pos);
@@ -109,6 +116,7 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
109116

110117
FastKernelSubcacheT MFastKernelSubcache;
111118
std::optional<int> MImplicitLocalArgPos;
119+
const std::optional<sycl::kernel_id> MKernelID;
112120
};
113121

114122
} // namespace detail

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
667667
"Cannot resolve external symbols, linking is unsupported "
668668
"for the backend");
669669

670-
// Access to m_ExportedSymbolImages must be guarded by m_KernelIDsMutex.
671-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
670+
// Access to m_ExportedSymbolImages must be guarded by m_ImgMapsMutex.
671+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
672672

673673
while (!WorkList.empty()) {
674674
std::string Symbol = WorkList.front();
@@ -748,8 +748,8 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
748748
if (!WorkList.empty()) {
749749
// Guard read access to m_VFSet2BinImage:
750750
// TODO: a better solution should be sought in the future, i.e. a different
751-
// mutex than m_KernelIDsMutex, check lock check pattern, etc.
752-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
751+
// mutex than m_ImgMapsMutex, check lock check pattern, etc.
752+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
753753

754754
while (!WorkList.empty()) {
755755
std::string SetName = WorkList.front();
@@ -1311,11 +1311,12 @@ ProgramManager::getDeviceImage(std::string_view KernelName,
13111311

13121312
const RTDeviceBinaryImage *Img = nullptr;
13131313
{
1314-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1315-
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
1316-
KernelId != m_KernelName2KernelIDs.end()) {
1317-
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
1318-
ContextImpl, DeviceImpl);
1314+
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
1315+
if (auto It = m_DeviceKernelInfoMap.find(KernelName);
1316+
It != m_DeviceKernelInfoMap.end()) {
1317+
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage,
1318+
It->second.getKernelID(), ContextImpl,
1319+
DeviceImpl);
13191320
}
13201321
}
13211322

@@ -1347,7 +1348,7 @@ const RTDeviceBinaryImage &ProgramManager::getDeviceImage(
13471348
debugPrintBinaryImages();
13481349
}
13491350

1350-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1351+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
13511352
std::vector<sycl_device_binary> RawImgs(ImageSet.size());
13521353
auto ImageIterator = ImageSet.begin();
13531354
for (size_t i = 0; i < ImageSet.size(); i++, ImageIterator++)
@@ -1620,7 +1621,7 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
16201621
}
16211622

16221623
// Fill maps for kernel bundles
1623-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1624+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
16241625

16251626
// For bfloat16 device library image, it doesn't include any kernel, device
16261627
// global, virtual function, so just skip adding it to any related maps.
@@ -1694,31 +1695,31 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
16941695
m_BinImg2KernelIDs[Img.get()];
16951696
KernelIDs.reset(new std::vector<kernel_id>);
16961697

1698+
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);
1699+
16971700
for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
16981701
EntriesIt = EntriesIt->Increment()) {
16991702

17001703
auto name = EntriesIt->GetName();
17011704

1702-
// Skip creating unique kernel ID if it is an exported device
1705+
// Skip creating device kernel information if it is an exported device
17031706
// function. Exported device functions appear in the offload entries
17041707
// among kernels, but are identifiable by being listed in properties.
17051708
if (m_ExportedSymbolImages.find(name) != m_ExportedSymbolImages.end())
17061709
continue;
17071710

1708-
// ... and create a unique kernel ID for the entry
1709-
auto It = m_KernelName2KernelIDs.find(name);
1710-
if (It == m_KernelName2KernelIDs.end()) {
1711+
auto It = m_DeviceKernelInfoMap.find(std::string_view(name));
1712+
if (It == m_DeviceKernelInfoMap.end()) {
17111713
sycl::kernel_id KernelID = detail::createSyclObjFromImpl<sycl::kernel_id>(
17121714
std::make_shared<detail::kernel_id_impl>(name));
1713-
1714-
It = m_KernelName2KernelIDs.emplace_hint(It, name, KernelID);
1715+
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
1716+
It = m_DeviceKernelInfoMap.emplace_hint(
1717+
It, std::piecewise_construct, std::forward_as_tuple(name),
1718+
std::forward_as_tuple(DefaultCompileTimeInfo, KernelID));
17151719
}
1716-
m_KernelIDs2BinImage.insert(std::make_pair(It->second, Img.get()));
1717-
KernelIDs->push_back(It->second);
1718-
1719-
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
1720-
m_DeviceKernelInfoMap.try_emplace(std::string_view(name),
1721-
DefaultCompileTimeInfo);
1720+
m_KernelIDs2BinImage.insert(
1721+
std::make_pair(It->second.getKernelID(), Img.get()));
1722+
KernelIDs->push_back(It->second.getKernelID());
17221723

17231724
// Keep track of image to kernel name reference count for cleanup.
17241725
m_KernelNameRefCount[name]++;
@@ -1777,7 +1778,7 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
17771778
if (DeviceBinary->NumDeviceBinaries == 0)
17781779
return;
17791780
// Acquire lock to read and modify maps for kernel bundles
1780-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1781+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
17811782

17821783
// Acquire lock to erase DeviceKernelInfoMap
17831784
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
@@ -1846,9 +1847,10 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
18461847
continue;
18471848
}
18481849

1849-
auto Name2IDIt = m_KernelName2KernelIDs.find(Name);
1850-
if (Name2IDIt != m_KernelName2KernelIDs.end())
1851-
removeFromMultimapByVal(m_KernelIDs2BinImage, Name2IDIt->second, Img);
1850+
auto DKIIt = m_DeviceKernelInfoMap.find(Name);
1851+
assert(DKIIt != m_DeviceKernelInfoMap.end());
1852+
removeFromMultimapByVal(m_KernelIDs2BinImage, DKIIt->second.getKernelID(),
1853+
Img);
18521854

18531855
auto RefCountIt = m_KernelNameRefCount.find(Name);
18541856
assert(RefCountIt != m_KernelNameRefCount.end());
@@ -1860,10 +1862,8 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
18601862
if (--RefCount == 0) {
18611863
// TODO aggregate all these maps into a single one since their entries
18621864
// share lifetime.
1863-
m_DeviceKernelInfoMap.erase(Name);
1865+
m_DeviceKernelInfoMap.erase(DKIIt);
18641866
m_KernelNameRefCount.erase(RefCountIt);
1865-
if (Name2IDIt != m_KernelName2KernelIDs.end())
1866-
m_KernelName2KernelIDs.erase(Name2IDIt);
18671867
}
18681868
}
18691869

@@ -1971,7 +1971,7 @@ ProgramManager::getBinImageState(const RTDeviceBinaryImage *BinImage) {
19711971
}
19721972

19731973
bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
1974-
std::lock_guard<std::mutex> Guard(m_KernelIDsMutex);
1974+
std::lock_guard<std::mutex> Guard(m_ImgMapsMutex);
19751975

19761976
return std::any_of(
19771977
m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(),
@@ -1981,19 +1981,19 @@ bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
19811981
}
19821982

19831983
std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() {
1984-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1984+
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);
19851985

19861986
std::vector<sycl::kernel_id> AllKernelIDs;
1987-
AllKernelIDs.reserve(m_KernelName2KernelIDs.size());
1988-
for (std::pair<std::string_view, kernel_id> KernelID :
1989-
m_KernelName2KernelIDs) {
1990-
AllKernelIDs.push_back(KernelID.second);
1987+
AllKernelIDs.reserve(m_DeviceKernelInfoMap.size());
1988+
for (const std::pair<const std::string_view, DeviceKernelInfo> &Pair :
1989+
m_DeviceKernelInfoMap) {
1990+
AllKernelIDs.push_back(Pair.second.getKernelID());
19911991
}
19921992
return AllKernelIDs;
19931993
}
19941994

19951995
kernel_id ProgramManager::getBuiltInKernelID(std::string_view KernelName) {
1996-
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
1996+
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);
19971997

19981998
auto KernelID = m_BuiltInKernelIDs.find(KernelName);
19991999
if (KernelID == m_BuiltInKernelIDs.end()) {
@@ -2044,7 +2044,7 @@ ProgramManager::getKernelGlobalInfoDesc(const char *UniqueId) {
20442044
std::set<const RTDeviceBinaryImage *>
20452045
ProgramManager::getRawDeviceImages(const std::vector<kernel_id> &KernelIDs) {
20462046
std::set<const RTDeviceBinaryImage *> BinImages;
2047-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2047+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
20482048
for (const kernel_id &KID : KernelIDs) {
20492049
auto Range = m_KernelIDs2BinImage.equal_range(KID);
20502050
for (auto It = Range.first, End = Range.second; It != End; ++It)
@@ -2099,7 +2099,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
20992099
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
21002100
// Collect kernel names for the image.
21012101
{
2102-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2102+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
21032103
KernelIDs = m_BinImg2KernelIDs[BinImage];
21042104
}
21052105

@@ -2129,7 +2129,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
21292129
}
21302130
BinImages = getRawDeviceImages(KernelIDs);
21312131
} else {
2132-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2132+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
21332133
for (auto &ImageUPtr : m_BinImg2KernelIDs) {
21342134
BinImages.insert(ImageUPtr.first);
21352135
}
@@ -2188,7 +2188,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
21882188
ImgInfo.State = getBinImageState(BinImage);
21892189
// Collect kernel names for the image
21902190
{
2191-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2191+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
21922192
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
21932193
}
21942194
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, Dev);
@@ -2285,7 +2285,7 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
22852285
bundle_state DepState) {
22862286
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
22872287
{
2288-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2288+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
22892289
// For device library images, they are not in m_BinImg2KernelIDs since
22902290
// no kernel is included.
22912291
auto DepIt = m_BinImg2KernelIDs.find(DepImage);
@@ -2408,7 +2408,7 @@ ProgramManager::getSYCLDeviceImages(const context &Ctx, devices_range Devs,
24082408
return {};
24092409

24102410
{
2411-
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
2411+
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);
24122412

24132413
for (auto &It : m_BuiltInKernelIDs) {
24142414
if (std::find(KernelIDs.begin(), KernelIDs.end(), It.second) !=
@@ -2838,7 +2838,7 @@ ur_kernel_handle_t ProgramManager::getCachedMaterializedKernel(
28382838
<< "KernelName: " << KernelName << "\n";
28392839

28402840
{
2841-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2841+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
28422842
if (auto KnownMaterializations = m_MaterializedKernels.find(KernelName);
28432843
KnownMaterializations != m_MaterializedKernels.end()) {
28442844
if constexpr (DbgProgMgr > 0)
@@ -2895,7 +2895,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
28952895
BuildProgram, KernelName.data(), &UrKernel);
28962896
ur_kernel_handle_t RawUrKernel = UrKernel;
28972897
{
2898-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2898+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
28992899
m_MaterializedKernels[KernelName][SpecializationConsts] =
29002900
std::move(UrKernel);
29012901
}

0 commit comments

Comments
 (0)