NFC - llvm_ir2vec.cpp breakup to extract a reusable header for IR2VecTool, and MIR2VecTool classes #172304
…a common importable module
Thank you for submitting a Pull Request (PR) to the LLVM Project!
This PR will be automatically labeled and the relevant teams will be notified.
If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment.
If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.
If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlgo
Author: Nishant Sachdeva (nishant-sachdeva)
Changes
Refactor llvm-ir2vec: Extract reusable header for Python bindings
Separated the IR2Vec/MIR2Vec tool implementation into a header file (llvm-ir2vec.h) and an implementation file (llvm-ir2vec.cpp).
Patch is 44.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172304.diff
2 Files Affected:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 7b8d3f093a3d1..8b52c385ff524 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -54,10 +54,12 @@
///
//===----------------------------------------------------------------------===//
+#include "llvm-ir2vec.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
@@ -89,8 +91,6 @@
namespace llvm {
-static const char *ToolName = "llvm-ir2vec";
-
// Common option category for options shared between IR2Vec and MIR2Vec
static cl::OptionCategory CommonCategory("Common Options",
"Options applicable to both IR2Vec "
@@ -134,12 +134,6 @@ static cl::opt<std::string>
cl::value_desc("name"), cl::Optional, cl::init(""),
cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
-enum EmbeddingLevel {
- InstructionLevel, // Generate instruction-level embeddings
- BasicBlockLevel, // Generate basic block-level embeddings
- FunctionLevel // Generate function-level embeddings
-};
-
static cl::opt<EmbeddingLevel>
Level("level", cl::desc("Embedding generation level:"),
cl::values(clEnumValN(InstructionLevel, "inst",
@@ -153,177 +147,7 @@ static cl::opt<EmbeddingLevel>
namespace ir2vec {
-/// Relation types for triplet generation
-enum RelationType {
- TypeRelation = 0, ///< Instruction to type relationship
- NextRelation = 1, ///< Sequential instruction relationship
- ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for collecting IR triplets and generating embeddings
-class IR2VecTool {
-private:
- Module &M;
- ModuleAnalysisManager MAM;
- const Vocabulary *Vocab = nullptr;
-
-public:
- explicit IR2VecTool(Module &M) : M(M) {}
-
- /// Initialize the IR2Vec vocabulary analysis
- bool initializeVocabulary() {
- // Register and run the IR2Vec vocabulary analysis
- // The vocabulary file path is specified via --ir2vec-vocab-path global
- // option
- MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
- MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
- // This will throw an error if vocab is not found or invalid
- Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
- return Vocab->isValid();
- }
-
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(raw_ostream &OS) const {
- unsigned MaxRelation = NextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M) {
- unsigned FuncMaxRelation = generateTriplets(F, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
- }
-
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
- }
-
- /// Generate triplets for a single function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const Function &F, raw_ostream &OS) const {
- if (F.isDeclaration())
- return 0;
-
- unsigned MaxRelation = 1;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
-
- for (const BasicBlock &BB : F) {
- for (const auto &I : BB.instructionsWithoutDebug()) {
- unsigned Opcode = Vocabulary::getIndex(I.getOpcode());
- unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << Opcode << '\t' << NextRelation << '\n';
- LLVM_DEBUG(dbgs()
- << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t'
- << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << "Next\n");
- }
-
- // Add "Type" relationship
- OS << Opcode << '\t' << TypeID << '\t' << TypeRelation << '\n';
- LLVM_DEBUG(
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID())
- << '\t' << "Type\n");
-
- // Add "Arg" relationships
- unsigned ArgIndex = 0;
- for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getIndex(*U.get());
- unsigned RelationID = ArgRelation + ArgIndex;
- OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
-
- LLVM_DEBUG({
- StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
- Vocabulary::getOperandKind(U.get()));
- dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t'
- << OperandStr << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
- // Only update MaxRelation if there were operands
- if (ArgIndex > 0) {
- MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);
- }
- PrevOpcode = Opcode;
- HasPrevOpcode = true;
- }
- }
-
- return MaxRelation;
- }
-
- /// Dump entity ID to string mappings
- static void generateEntityMappings(raw_ostream &OS) {
- auto EntityLen = Vocabulary::getCanonicalSize();
- OS << EntityLen << "\n";
- for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID)
- OS << Vocabulary::getStringKey(EntityID) << '\t' << EntityID << '\n';
- }
-
- /// Generate embeddings for the entire module
- void generateEmbeddings(raw_ostream &OS) const {
- if (!Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
- for (const Function &F : M)
- generateEmbeddings(F, OS);
- }
-
- /// Generate embeddings for a single function
- void generateEmbeddings(const Function &F, raw_ostream &OS) const {
- if (F.isDeclaration()) {
- OS << "Function " << F.getName() << " is a declaration, skipping.\n";
- return;
- }
-
- // Create embedder for this function
- assert(Vocab->isValid() && "Vocabulary is not valid");
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
- return;
- }
-
- OS << "Function: " << F.getName() << "\n";
-
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- Emb->getFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- for (const BasicBlock &BB : F) {
- OS << BB.getName() << ":";
- Emb->getBBVector(BB).print(OS);
- }
- break;
- }
- case InstructionLevel: {
- for (const BasicBlock &BB : F) {
- for (const Instruction &I : BB) {
- I.print(OS);
- Emb->getInstVector(I).print(OS);
- }
- }
- break;
- }
- }
- }
-};
-
+/// Process the module and generate output based on selected subcommand
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
@@ -337,18 +161,18 @@ Error processModule(Module &M, raw_ostream &OS) {
if (!FunctionName.empty()) {
// Process single function
if (const Function *F = M.getFunction(FunctionName))
- Tool.generateEmbeddings(*F, OS);
+ Tool.writeEmbeddingsToStream(*F, OS, Level);
else
return createStringError(errc::invalid_argument,
"Function '%s' not found",
FunctionName.c_str());
} else {
// Process all functions
- Tool.generateEmbeddings(OS);
+ Tool.writeEmbeddingsToStream(OS, Level);
}
} else {
// Both triplets and entities use triplet generation
- Tool.generateTriplets(OS);
+ Tool.writeTripletsToStream(OS);
}
return Error::success();
}
@@ -356,237 +180,151 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
-/// Relation types for MIR2Vec triplet generation
-enum MIRRelationType {
- MIRNextRelation = 0, ///< Sequential instruction relationship
- MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
-};
-
-/// Helper class for MIR2Vec embedding generation
-class MIR2VecTool {
-private:
- MachineModuleInfo &MMI;
- std::unique_ptr<MIRVocabulary> Vocab;
-
-public:
- explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
-
- /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
- bool initializeVocabulary(const Module &M) {
- MIR2VecVocabProvider Provider(MMI);
- auto VocabOrErr = Provider.getVocabulary(M);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to load MIR2Vec vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
+/// Setup MIR context from input file
+Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
+ SMDiagnostic Err;
- /// Initialize vocabulary with layout information only.
- /// This creates a minimal vocabulary with correct layout but no actual
- /// embeddings. Sufficient for generating training data and entity mappings.
- ///
- /// Note: Requires target-specific information from the first machine function
- /// to determine the vocabulary layout (number of opcodes, register classes).
- ///
- /// FIXME: Use --target option to get target info directly, avoiding the need
- /// to parse machine functions for pre-training operations.
- bool initializeVocabularyForLayout(const Module &M) {
- for (const Function &F : M.getFunctionDefs()) {
-
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF)
- continue;
-
- const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
- const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
- const MachineRegisterInfo &MRI = MF->getRegInfo();
-
- auto VocabOrErr =
- MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
- if (!VocabOrErr) {
- WithColor::error(errs(), ToolName)
- << "Failed to create dummy vocabulary - "
- << toString(VocabOrErr.takeError()) << "\n";
- return false;
- }
- Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
- return true;
- }
-
- WithColor::error(errs(), ToolName)
- << "No machine functions found to initialize vocabulary\n";
- return false;
+ auto MIR = createMIRParserFromFile(InputFile, Err, Ctx.Context);
+ if (!MIR) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse MIR file");
}
- /// Generate triplets for the module
- /// Output format: MAX_RELATION=N header followed by relationships
- void generateTriplets(const Module &M, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
- std::string Relationships;
- raw_string_ostream RelOS(Relationships);
-
- for (const Function &F : M.getFunctionDefs()) {
+ auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
+ StringRef OldDLStr) -> std::optional<std::string> {
+ std::string IRTargetTriple = DataLayoutTargetTriple.str();
+ Triple TheTriple = Triple(IRTargetTriple);
+ if (TheTriple.getTriple().empty())
+ TheTriple.setTriple(sys::getDefaultTargetTriple());
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
-
- unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
- MaxRelation = std::max(MaxRelation, FuncMaxRelation);
+ auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
+ if (!TMOrErr) {
+ Err.print(ToolName, errs());
+ exit(1); // Match original behavior
}
+ Ctx.TM = std::move(*TMOrErr);
+ return Ctx.TM->createDataLayout().getStringRepresentation();
+ };
- RelOS.flush();
-
- // Write metadata header followed by relationships
- OS << "MAX_RELATION=" << MaxRelation << '\n';
- OS << Relationships;
+ Ctx.M = MIR->parseIRModule(SetDataLayout);
+ if (!Ctx.M) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse IR module");
}
- /// Generate triplets for a single machine function
- /// Returns the maximum relation ID used in this function
- unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
- unsigned MaxRelation = MIRNextRelation;
- unsigned PrevOpcode = 0;
- bool HasPrevOpcode = false;
-
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "MIR Vocabulary must be initialized for triplet generation.\n";
- return MaxRelation;
- }
-
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- // Skip debug instructions
- if (MI.isDebugInstr())
- continue;
-
- // Get opcode entity ID
- unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
-
- // Add "Next" relationship with previous instruction
- if (HasPrevOpcode) {
- OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
- << '\n';
- LLVM_DEBUG(dbgs()
- << Vocab->getStringKey(PrevOpcode) << '\t'
- << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
- }
-
- // Add "Arg" relationships for operands
- unsigned ArgIndex = 0;
- for (const MachineOperand &MO : MI.operands()) {
- auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
- unsigned RelationID = MIRArgRelation + ArgIndex;
- OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
- LLVM_DEBUG({
- std::string OperandStr = Vocab->getStringKey(OperandID);
- dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
- << '\t' << "Arg" << ArgIndex << '\n';
- });
-
- ++ArgIndex;
- }
-
- // Update MaxRelation if there were operands
- if (ArgIndex > 0)
- MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
-
- PrevOpcode = OpcodeID;
- HasPrevOpcode = true;
- }
- }
-
- return MaxRelation;
+ Ctx.MMI = std::make_unique<MachineModuleInfo>(Ctx.TM.get());
+ if (!Ctx.MMI || MIR->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) {
+ Err.print(ToolName, errs());
+ return createStringError(errc::invalid_argument,
+ "Failed to parse machine functions");
}
- /// Generate entity mappings with vocabulary
- void generateEntityMappings(raw_ostream &OS) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary must be initialized for entity mappings.\n";
- return;
- }
+ return Error::success();
+}
- const unsigned EntityCount = Vocab->getCanonicalSize();
- OS << EntityCount << "\n";
- for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
- OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
- }
+/// Generic vocabulary initialization and processing
+template <typename ProcessFunc>
+Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
+ bool useLayoutVocab, ProcessFunc processFn) {
+ MIR2VecTool Tool(*Ctx.MMI);
- /// Generate embeddings for all machine functions in the module
- void generateEmbeddings(const Module &M, raw_ostream &OS) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
+ // Initialize appropriate vocabulary type
+ bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(*Ctx.M)
+ : Tool.initializeVocabulary(*Ctx.M);
- for (const Function &F : M.getFunctionDefs()) {
+ if (!success) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to initialize MIR2Vec vocabulary"
+ << (useLayoutVocab ? " for layout" : "") << ".\n";
+ return createStringError(errc::invalid_argument,
+ "Vocabulary initialization failed");
+ }
- MachineFunction *MF = MMI.getMachineFunction(F);
- if (!MF) {
- WithColor::warning(errs(), ToolName)
- << "No MachineFunction for " << F.getName() << "\n";
- continue;
- }
+ assert(Tool.getVocabulary() &&
+ "MIR2Vec vocabulary should be initialized at this point");
- generateEmbeddings(*MF, OS);
- }
- }
+ LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
+ << "Vocabulary dimension: "
+ << Tool.getVocabulary()->getDimension() << "\n"
+ << "Vocabulary size: "
+ << Tool.getVocabulary()->getCanonicalSize() << "\n");
- /// Generate embeddings for a specific machine function
- void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const {
- if (!Vocab) {
- WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
- return;
- }
+ // Execute the specific processing logic
+ return processFn(Tool);
+}
- auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
- if (!Emb) {
- WithColor::error(errs(), ToolName)
- << "Failed to create embedder for " << MF.getName() << "\n";
- return;
- }
+/// Process module for triplet generation
+Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+ [&](MIR2VecTool &Tool) -> Error {
+ Tool.writeTripletsToStream(*Ctx.M, OS);
+ return Error::success();
+ });
+}
- OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+/// Process module for entity generation
+Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true,
+ [&](MIR2VecTool &Tool) -> Error {
+ Tool.writeEntitiesToStream(OS);
+ return Error::success();
+ });
+}
- // Generate embeddings based on the specified level
- switch (Level) {
- case FunctionLevel: {
- OS << "Function vector: ";
- Emb->getMFunctionVector().print(OS);
- break;
- }
- case BasicBlockLevel: {
- OS << "Basic block vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- OS << "MBB " << MBB.getName() << ": ";
- Emb->getMBBVector(MBB).print(OS);
- }
- break;
- }
- case InstructionLevel: {
- OS << "Instruction vectors:\n";
- for (const MachineBasicBlock &MBB : MF) {
- for (const MachineInstr &MI : MBB) {
- OS << MI << " -> ";
- Emb->getMInstVector(MI).print(OS);
+/// Process module for embedding generation
+Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
+ return processWithVocabulary(
+ Ctx, OS, /*useLayoutVocab=*/false, [&](MIR2VecTool &Tool) -> Error {
+ if (!FunctionName.empty()) {
+ // Process single function
+ Function *F = Ctx.M->getFunction(FunctionName);
+ if (!F) {
+ WithColor::error(errs(), ToolName)
+ << "Function '" << FunctionName << "' not found\n";
+ return createStringError(errc::invalid_argument,
+ "Function not found");
+ }
+
+ MachineFunction *MF = Ctx.MMI->getMachineFunction(*F);
+ if (!MF) {
+ WithColor::error...
[truncated]
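For readers following the diff, the triplet output format used by the removed `generateTriplets` code can be sketched without any LLVM dependency. The sketch below is illustrative only: `ToyInst`, `writeFunctionTriplets`, and `writeModuleTriplets` are hypothetical names, and the entity IDs are assumed to have been looked up already (the real code obtains them from `Vocabulary::getIndex`). Only the relation numbering (Type = 0, Next = 1, Arg = 2 + operand index) and the `MAX_RELATION=N` header are taken from the diff.

```cpp
#include <algorithm>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Relation IDs mirror the RelationType enum shown in the diff.
enum RelationType : unsigned {
  TypeRelation = 0, // instruction -> type
  NextRelation = 1, // previous instruction -> current instruction
  ArgRelation = 2   // instruction -> operand N uses ArgRelation + N
};

// Hypothetical stand-in for an instruction whose entity IDs are already known.
struct ToyInst {
  unsigned Opcode;
  unsigned TypeID;
  std::vector<unsigned> Operands;
};

// Writes tab-separated (head, tail, relation) triplets for one function and
// returns the largest relation ID that was emitted.
unsigned writeFunctionTriplets(const std::vector<ToyInst> &Insts,
                               std::ostream &OS) {
  unsigned MaxRelation = NextRelation;
  unsigned PrevOpcode = 0;
  bool HasPrevOpcode = false;

  for (const ToyInst &I : Insts) {
    // "Next" relationship with the previous instruction, if any.
    if (HasPrevOpcode)
      OS << PrevOpcode << '\t' << I.Opcode << '\t' << NextRelation << '\n';

    // "Type" relationship.
    OS << I.Opcode << '\t' << I.TypeID << '\t' << TypeRelation << '\n';

    // One "Arg" relationship per operand, numbered from ArgRelation upward.
    unsigned ArgIndex = 0;
    for (unsigned OperandID : I.Operands) {
      OS << I.Opcode << '\t' << OperandID << '\t' << (ArgRelation + ArgIndex)
         << '\n';
      ++ArgIndex;
    }
    if (ArgIndex > 0)
      MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);

    PrevOpcode = I.Opcode;
    HasPrevOpcode = true;
  }
  return MaxRelation;
}

// Buffers all triplets first so the MAX_RELATION header can be written
// before them, as the module-level overload in the diff does.
std::string
writeModuleTriplets(const std::vector<std::vector<ToyInst>> &Funcs) {
  unsigned MaxRelation = NextRelation;
  std::ostringstream Body;
  for (const std::vector<ToyInst> &F : Funcs)
    MaxRelation = std::max(MaxRelation, writeFunctionTriplets(F, Body));

  std::ostringstream Out;
  Out << "MAX_RELATION=" << MaxRelation << '\n' << Body.str();
  return Out.str();
}
```

The buffering step is the notable design point: the header depends on the maximum relation ID across all functions, so the triplets have to be collected before anything is written to the final stream.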
mtrofin
left a comment
We usually title such changes "[NFC] llvm_ir2vec.cpp breakup <brief reason>". Could you update the title along those lines?
Refactor llvm-ir2vec: Extract reusable header for Python bindings
Separated the IR2Vec/MIR2Vec tool implementation into a header file (llvm-ir2vec.h) and an implementation file (llvm-ir2vec.cpp) to enable reuse in Python bindings and other projects.
Changes
llvm-ir2vec.h: Contains the IR2VecTool and MIR2VecTool class definitions with all implementations, making it a standalone header-only library.
llvm-ir2vec.cpp: Now contains only command-line interface code (options, main function, and helper functions).
Motivation
The original monolithic .cpp file made it impossible to use IR2Vec/MIR2Vec functionality in Python bindings without compiling the entire command-line tool. This refactoring cleanly separates the library interface from the CLI tool, which should make the upcoming Python bindings easier to develop.
Testing
All existing tests pass without modification. No functional changes.
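As a rough illustration of the separation described above, the following LLVM-free sketch shows one way a tool class can be kept independent of command-line state: the embedding level is passed as a parameter (as the diff does with `writeEmbeddingsToStream(OS, Level)`) instead of being read from a global option. Everything here is hypothetical and simplified: `ToyFunction`, `ToyTool`, `runCli`, and the `"bb"` flag value are stand-ins, not the actual contents of llvm-ir2vec.h.

```cpp
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// ---- Would live in a shared header (hypothetical), usable by any client ----

enum class EmbeddingLevel { Instruction, BasicBlock, Function };

// Hypothetical stand-in for a function and the names of its basic blocks.
struct ToyFunction {
  std::string Name;
  std::vector<std::string> Blocks;
};

class ToyTool {
  const std::vector<ToyFunction> &Funcs;

public:
  explicit ToyTool(const std::vector<ToyFunction> &Funcs) : Funcs(Funcs) {}

  // The level arrives as an argument, so this class never touches CLI globals
  // and can be called from bindings or tests as easily as from main().
  void writeEmbeddingsToStream(std::ostream &OS, EmbeddingLevel Level) const {
    for (const ToyFunction &F : Funcs) {
      OS << "Function: " << F.Name << '\n';
      if (Level == EmbeddingLevel::BasicBlock)
        for (const std::string &BB : F.Blocks)
          OS << BB << ":\n";
    }
  }
};

// ---- Would stay in the CLI source file: option parsing and dispatch only ----

// Maps a (hypothetical) flag string to a level and delegates to the tool.
std::string runCli(const std::vector<ToyFunction> &Funcs,
                   const std::string &LevelFlag) {
  EmbeddingLevel Level = LevelFlag == "bb" ? EmbeddingLevel::BasicBlock
                                           : EmbeddingLevel::Function;
  std::ostringstream OS;
  ToyTool(Funcs).writeEmbeddingsToStream(OS, Level);
  return OS.str();
}
```

With this shape, a non-CLI client would only need the header portion and could choose the level and output stream itself.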