Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sycl/source/detail/device_kernel_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info)
: CompileTimeKernelInfoTy(Info) {}
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
std::optional<sycl::kernel_id> KernelID)
: CompileTimeKernelInfoTy{Info}, MKernelID{std::move(KernelID)} {}

template <typename OtherTy>
inline constexpr bool operator==(const CompileTimeKernelInfoTy &LHS,
Expand Down
16 changes: 12 additions & 4 deletions sycl/source/detail/device_kernel_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <sycl/detail/compile_time_kernel_info.hpp>
#include <sycl/detail/spinlock.hpp>
#include <sycl/detail/ur.hpp>
#include <sycl/kernel_bundle.hpp>

#include <mutex>
#include <optional>
Expand Down Expand Up @@ -84,12 +85,10 @@ struct FastKernelSubcacheT {
// information that is uniform between different submissions of the same
// kernel). Pointers to instances of this class are stored in header function
// templates as a static variable to avoid repeated runtime lookup overhead.
// TODO Currently this class duplicates information fetched from the program
// manager. Instead, we should merge all of this information
// into this structure and get rid of the other KernelName -> * maps.
class DeviceKernelInfo : public CompileTimeKernelInfoTy {
public:
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info);
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
std::optional<sycl::kernel_id> KernelID = std::nullopt);

void init(std::string_view KernelName);
void setCompileTimeInfoIfNeeded(const CompileTimeKernelInfoTy &Info);
Expand All @@ -100,6 +99,14 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
return MImplicitLocalArgPos;
}

const sycl::kernel_id &getKernelID() const {
// Expected to be called only for DeviceKernelInfo instances created by
// program manager (as opposed to allocated by sycl::kernel with
// origins other than SYCL offline compilation).
assert(MKernelID);
return *MKernelID;
}

// Implicit local argument position is used only for some backends, so this
// function allows setting it as more images are added.
void setImplicitLocalArgPos(int Pos);
Expand All @@ -109,6 +116,7 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {

FastKernelSubcacheT MFastKernelSubcache;
std::optional<int> MImplicitLocalArgPos;
const std::optional<sycl::kernel_id> MKernelID;
};

} // namespace detail
Expand Down
90 changes: 45 additions & 45 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
"Cannot resolve external symbols, linking is unsupported "
"for the backend");

// Access to m_ExportedSymbolImages must be guarded by m_KernelIDsMutex.
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
// Access to m_ExportedSymbolImages must be guarded by m_ImgMapsMutex.
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

while (!WorkList.empty()) {
std::string Symbol = WorkList.front();
Expand Down Expand Up @@ -770,8 +770,8 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
if (!WorkList.empty()) {
// Guard read access to m_VFSet2BinImage:
// TODO: a better solution should be sought in the future, i.e. a different
// mutex than m_KernelIDsMutex, check lock check pattern, etc.
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
// mutex than m_ImgMapsMutex, check lock check pattern, etc.
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

while (!WorkList.empty()) {
std::string SetName = WorkList.front();
Expand Down Expand Up @@ -1333,11 +1333,12 @@ ProgramManager::getDeviceImage(std::string_view KernelName,

const RTDeviceBinaryImage *Img = nullptr;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
KernelId != m_KernelName2KernelIDs.end()) {
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
ContextImpl, DeviceImpl);
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
if (auto It = m_DeviceKernelInfoMap.find(KernelName);
It != m_DeviceKernelInfoMap.end()) {
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage,
It->second.getKernelID(), ContextImpl,
DeviceImpl);
}
}

Expand Down Expand Up @@ -1369,7 +1370,7 @@ const RTDeviceBinaryImage &ProgramManager::getDeviceImage(
debugPrintBinaryImages();
}

std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
std::vector<sycl_device_binary> RawImgs(ImageSet.size());
auto ImageIterator = ImageSet.begin();
for (size_t i = 0; i < ImageSet.size(); i++, ImageIterator++)
Expand Down Expand Up @@ -1642,7 +1643,7 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
}

// Fill maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

// For bfloat16 device library image, it doesn't include any kernel, device
// global, virtual function, so just skip adding it to any related maps.
Expand Down Expand Up @@ -1716,31 +1717,31 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
m_BinImg2KernelIDs[Img.get()];
KernelIDs.reset(new std::vector<kernel_id>);

std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);

for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
EntriesIt = EntriesIt->Increment()) {

auto name = EntriesIt->GetName();

// Skip creating unique kernel ID if it is an exported device
// Skip creating device kernel information if it is an exported device
// function. Exported device functions appear in the offload entries
// among kernels, but are identifiable by being listed in properties.
if (m_ExportedSymbolImages.find(name) != m_ExportedSymbolImages.end())
continue;

// ... and create a unique kernel ID for the entry
auto It = m_KernelName2KernelIDs.find(name);
if (It == m_KernelName2KernelIDs.end()) {
auto It = m_DeviceKernelInfoMap.find(std::string_view(name));
if (It == m_DeviceKernelInfoMap.end()) {
sycl::kernel_id KernelID = detail::createSyclObjFromImpl<sycl::kernel_id>(
std::make_shared<detail::kernel_id_impl>(name));

It = m_KernelName2KernelIDs.emplace_hint(It, name, KernelID);
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
It = m_DeviceKernelInfoMap.emplace_hint(
It, std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(DefaultCompileTimeInfo, KernelID));
}
m_KernelIDs2BinImage.insert(std::make_pair(It->second, Img.get()));
KernelIDs->push_back(It->second);

CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
m_DeviceKernelInfoMap.try_emplace(std::string_view(name),
DefaultCompileTimeInfo);
m_KernelIDs2BinImage.insert(
std::make_pair(It->second.getKernelID(), Img.get()));
KernelIDs->push_back(It->second.getKernelID());

// Keep track of image to kernel name reference count for cleanup.
m_KernelNameRefCount[name]++;
Expand Down Expand Up @@ -1831,7 +1832,7 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
if (DeviceBinary->NumDeviceBinaries == 0)
return;
// Acquire lock to read and modify maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

// Acquire lock to erase DeviceKernelInfoMap
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
Expand Down Expand Up @@ -1919,9 +1920,10 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
continue;
}

auto Name2IDIt = m_KernelName2KernelIDs.find(Name);
if (Name2IDIt != m_KernelName2KernelIDs.end())
removeFromMultimapByVal(m_KernelIDs2BinImage, Name2IDIt->second, Img);
auto DKIIt = m_DeviceKernelInfoMap.find(Name);
assert(DKIIt != m_DeviceKernelInfoMap.end());
removeFromMultimapByVal(m_KernelIDs2BinImage, DKIIt->second.getKernelID(),
Img);

auto RefCountIt = m_KernelNameRefCount.find(Name);
assert(RefCountIt != m_KernelNameRefCount.end());
Expand All @@ -1933,10 +1935,8 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
if (--RefCount == 0) {
// TODO aggregate all these maps into a single one since their entries
// share lifetime.
m_DeviceKernelInfoMap.erase(Name);
m_DeviceKernelInfoMap.erase(DKIIt);
m_KernelNameRefCount.erase(RefCountIt);
if (Name2IDIt != m_KernelName2KernelIDs.end())
m_KernelName2KernelIDs.erase(Name2IDIt);
}
}

Expand Down Expand Up @@ -2045,7 +2045,7 @@ ProgramManager::getBinImageState(const RTDeviceBinaryImage *BinImage) {
}

bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
std::lock_guard<std::mutex> Guard(m_KernelIDsMutex);
std::lock_guard<std::mutex> Guard(m_ImgMapsMutex);

return std::any_of(
m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(),
Expand All @@ -2055,19 +2055,19 @@ bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
}

std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() {
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);

std::vector<sycl::kernel_id> AllKernelIDs;
AllKernelIDs.reserve(m_KernelName2KernelIDs.size());
for (std::pair<std::string_view, kernel_id> KernelID :
m_KernelName2KernelIDs) {
AllKernelIDs.push_back(KernelID.second);
AllKernelIDs.reserve(m_DeviceKernelInfoMap.size());
for (const std::pair<const std::string_view, DeviceKernelInfo> &Pair :
m_DeviceKernelInfoMap) {
AllKernelIDs.push_back(Pair.second.getKernelID());
}
return AllKernelIDs;
}

kernel_id ProgramManager::getBuiltInKernelID(std::string_view KernelName) {
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);

auto KernelID = m_BuiltInKernelIDs.find(KernelName);
if (KernelID == m_BuiltInKernelIDs.end()) {
Expand Down Expand Up @@ -2118,7 +2118,7 @@ ProgramManager::getKernelGlobalInfoDesc(const char *UniqueId) {
std::set<const RTDeviceBinaryImage *>
ProgramManager::getRawDeviceImages(const std::vector<kernel_id> &KernelIDs) {
std::set<const RTDeviceBinaryImage *> BinImages;
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
for (const kernel_id &KID : KernelIDs) {
auto Range = m_KernelIDs2BinImage.equal_range(KID);
for (auto It = Range.first, End = Range.second; It != End; ++It)
Expand Down Expand Up @@ -2204,7 +2204,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
// Collect kernel names for the image.
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
KernelIDs = m_BinImg2KernelIDs[BinImage];
}

Expand Down Expand Up @@ -2234,7 +2234,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
}
BinImages = getRawDeviceImages(KernelIDs);
} else {
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
for (auto &ImageUPtr : m_BinImg2KernelIDs) {
BinImages.insert(ImageUPtr.first);
}
Expand Down Expand Up @@ -2293,7 +2293,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
ImgInfo.State = getBinImageState(BinImage);
// Collect kernel names for the image
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
}
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, Dev);
Expand Down Expand Up @@ -2390,7 +2390,7 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
bundle_state DepState) {
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
// For device library images, they are not in m_BinImg2KernelIDs since
// no kernel is included.
auto DepIt = m_BinImg2KernelIDs.find(DepImage);
Expand Down Expand Up @@ -2513,7 +2513,7 @@ ProgramManager::getSYCLDeviceImages(const context &Ctx, devices_range Devs,
return {};

{
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);

for (auto &It : m_BuiltInKernelIDs) {
if (std::find(KernelIDs.begin(), KernelIDs.end(), It.second) !=
Expand Down Expand Up @@ -2943,7 +2943,7 @@ ur_kernel_handle_t ProgramManager::getCachedMaterializedKernel(
<< "KernelName: " << KernelName << "\n";

{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
if (auto KnownMaterializations = m_MaterializedKernels.find(KernelName);
KnownMaterializations != m_MaterializedKernels.end()) {
if constexpr (DbgProgMgr > 0)
Expand Down Expand Up @@ -3000,7 +3000,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
BuildProgram, KernelName.data(), &UrKernel);
ur_kernel_handle_t RawUrKernel = UrKernel;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
m_MaterializedKernels[KernelName][SpecializationConsts] =
std::move(UrKernel);
}
Expand Down
Loading
Loading