diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 9ab12e36718cd..8b52c385ff524 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -54,6 +54,7 @@ /// //===----------------------------------------------------------------------===// +#include "llvm-ir2vec.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/IR2Vec.h" #include "llvm/IR/BasicBlock.h" @@ -90,8 +91,6 @@ namespace llvm { -static const char *ToolName = "llvm-ir2vec"; - // Common option category for options shared between IR2Vec and MIR2Vec static cl::OptionCategory CommonCategory("Common Options", "Options applicable to both IR2Vec " @@ -135,12 +134,6 @@ static cl::opt cl::value_desc("name"), cl::Optional, cl::init(""), cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); -enum EmbeddingLevel { - InstructionLevel, // Generate instruction-level embeddings - BasicBlockLevel, // Generate basic block-level embeddings - FunctionLevel // Generate function-level embeddings -}; - static cl::opt Level("level", cl::desc("Embedding generation level:"), cl::values(clEnumValN(InstructionLevel, "inst", @@ -152,219 +145,8 @@ static cl::opt cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); -/// Represents a single knowledge graph triplet (Head, Relation, Tail) -/// where indices reference entities in an EntityList -struct Triplet { - unsigned Head = 0; ///< Index of the head entity in the entity list - unsigned Tail = 0; ///< Index of the tail entity in the entity list - unsigned Relation = 0; ///< Relation type (see RelationType enum) -}; - -/// Result structure containing all generated triplets and metadata -struct TripletResult { - unsigned MaxRelation = - 0; ///< Highest relation index used (for ArgRelation + N) - std::vector Triplets; ///< Collection of all generated triplets -}; - -/// Entity mappings: [entity_name] -using EntityList = std::vector; - namespace ir2vec { -/// Relation types for triplet generation -enum RelationType { - TypeRelation = 0, ///< Instruction to type relationship - NextRelation = 1, ///< Sequential instruction relationship - ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N) -}; - -/// Helper class for collecting IR triplets and generating embeddings -class IR2VecTool { -private: - Module &M; - ModuleAnalysisManager MAM; - const Vocabulary *Vocab = nullptr; - -public: - explicit IR2VecTool(Module &M) : M(M) {} - - /// Initialize the IR2Vec vocabulary analysis - bool initializeVocabulary() { - // Register and run the IR2Vec vocabulary analysis - // The vocabulary file path is specified via --ir2vec-vocab-path global - // option - MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); - MAM.registerPass([&] { return IR2VecVocabAnalysis(); }); - // This will throw an error if vocab is not found or invalid - Vocab = &MAM.getResult(M); - return Vocab->isValid(); - } - - /// Generate triplets for a single function - /// Returns a TripletResult with: - /// - Triplets: vector of all (subject, object, relation) tuples - /// - MaxRelation: highest Arg relation ID used, or NextRelation if none - TripletResult generateTriplets(const Function &F) const { - if (F.isDeclaration()) - return {}; - - TripletResult Result; - Result.MaxRelation = 0; - - unsigned MaxRelation = NextRelation; - unsigned PrevOpcode = 0; - bool HasPrevOpcode = false; - - for (const BasicBlock &BB : F) { - for (const auto &I : BB.instructionsWithoutDebug()) { - unsigned Opcode = Vocabulary::getIndex(I.getOpcode()); - unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID()); - - // Add "Next" relationship with previous instruction - if (HasPrevOpcode) { - Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation}); - LLVM_DEBUG(dbgs() - << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t' - << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' - << "Next\n"); - } - - // Add "Type" relationship - Result.Triplets.push_back({Opcode, TypeID, TypeRelation}); - LLVM_DEBUG( - dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' - << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID()) - << '\t' << "Type\n"); - - // Add "Arg" relationships - unsigned ArgIndex = 0; - for (const Use &U : I.operands()) { - unsigned OperandID = Vocabulary::getIndex(*U.get()); - unsigned RelationID = ArgRelation + ArgIndex; - Result.Triplets.push_back({Opcode, OperandID, RelationID}); - - LLVM_DEBUG({ - StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind( - Vocabulary::getOperandKind(U.get())); - dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' - << OperandStr << '\t' << "Arg" << ArgIndex << '\n'; - }); - - ++ArgIndex; - } - // Only update MaxRelation if there were operands - if (ArgIndex > 0) - MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1); - PrevOpcode = Opcode; - HasPrevOpcode = true; - } - } - - Result.MaxRelation = MaxRelation; - return Result; - } - - /// Get triplets for the entire module - TripletResult generateTriplets() const { - TripletResult Result; - Result.MaxRelation = NextRelation; - - for (const Function &F : M.getFunctionDefs()) { - TripletResult FuncResult = generateTriplets(F); - Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation); - Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(), - FuncResult.Triplets.end()); - } - - return Result; - } - - /// Collect triplets for the module and dump output to stream - /// Output format: MAX_RELATION=N header followed by relationships - void writeTripletsToStream(raw_ostream &OS) const { - auto Result = generateTriplets(); - OS << "MAX_RELATION=" << Result.MaxRelation << '\n'; - for (const auto &T : Result.Triplets) - OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n'; - } - - /// Generate entity mappings for the entire vocabulary - /// Returns EntityList containing all entity strings - static EntityList collectEntityMappings() { - auto EntityLen = Vocabulary::getCanonicalSize(); - EntityList Result; - for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID) - Result.push_back(Vocabulary::getStringKey(EntityID).str()); - return Result; - } - - /// Dump entity ID to string mappings - static void writeEntitiesToStream(raw_ostream &OS) { - auto Entities = collectEntityMappings(); - OS << Entities.size() << "\n"; - for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID) - OS << Entities[EntityID] << '\t' << EntityID << '\n'; - } - - /// Generate embeddings for the entire module - void writeEmbeddingsToStream(raw_ostream &OS) const { - if (!Vocab->isValid()) { - WithColor::error(errs(), ToolName) - << "Vocabulary is not valid. IR2VecTool not initialized.\n"; - return; - } - - for (const Function &F : M.getFunctionDefs()) - writeEmbeddingsToStream(F, OS); - } - - /// Generate embeddings for a single function - void writeEmbeddingsToStream(const Function &F, raw_ostream &OS) const { - if (!Vocab || !Vocab->isValid()) { - WithColor::error(errs(), ToolName) - << "Vocabulary is not valid. IR2VecTool not initialized.\n"; - return; - } - if (F.isDeclaration()) { - OS << "Function " << F.getName() << " is a declaration, skipping.\n"; - return; - } - - // Create embedder for this function - auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab); - if (!Emb) { - WithColor::error(errs(), ToolName) - << "Failed to create embedder for function " << F.getName() << "\n"; - return; - } - - OS << "Function: " << F.getName() << "\n"; - - // Generate embeddings based on the specified level - switch (Level) { - case FunctionLevel: { - Emb->getFunctionVector().print(OS); - break; - } - case BasicBlockLevel: { - for (const BasicBlock &BB : F) { - OS << BB.getName() << ":"; - Emb->getBBVector(BB).print(OS); - } - break; - } - case InstructionLevel: { - for (const Instruction &I : instructions(F)) { - OS << I; - Emb->getInstVector(I).print(OS); - } - break; - } - } - } -}; - /// Process the module and generate output based on selected subcommand Error processModule(Module &M, raw_ostream &OS) { IR2VecTool Tool(M); @@ -379,14 +161,14 @@ Error processModule(Module &M, raw_ostream &OS) { if (!FunctionName.empty()) { // Process single function if (const Function *F = M.getFunction(FunctionName)) - Tool.writeEmbeddingsToStream(*F, OS); + Tool.writeEmbeddingsToStream(*F, OS, Level); else return createStringError(errc::invalid_argument, "Function '%s' not found", FunctionName.c_str()); } else { // Process all functions - Tool.writeEmbeddingsToStream(OS); + Tool.writeEmbeddingsToStream(OS, Level); } } else { // Both triplets and entities use triplet generation @@ -398,257 +180,151 @@ Error processModule(Module &M, raw_ostream &OS) { namespace mir2vec { -/// Relation types for MIR2Vec triplet generation -enum MIRRelationType { - MIRNextRelation = 0, ///< Sequential instruction relationship - MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N) -}; +/// Setup MIR context from input file +Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) { + SMDiagnostic Err; -/// Helper class for MIR2Vec embedding generation -class MIR2VecTool { -private: - MachineModuleInfo &MMI; - std::unique_ptr Vocab; - -public: - explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} - - /// Initialize MIR2Vec vocabulary from file (for embeddings generation) - bool initializeVocabulary(const Module &M) { - MIR2VecVocabProvider Provider(MMI); - auto VocabOrErr = Provider.getVocabulary(M); - if (!VocabOrErr) { - WithColor::error(errs(), ToolName) - << "Failed to load MIR2Vec vocabulary - " - << toString(VocabOrErr.takeError()) << "\n"; - return false; - } - Vocab = std::make_unique(std::move(*VocabOrErr)); - return true; + auto MIR = createMIRParserFromFile(InputFile, Err, Ctx.Context); + if (!MIR) { + Err.print(ToolName, errs()); + return createStringError(errc::invalid_argument, + "Failed to parse MIR file"); } - /// Initialize vocabulary with layout information only. - /// This creates a minimal vocabulary with correct layout but no actual - /// embeddings. Sufficient for generating training data and entity mappings. - /// - /// Note: Requires target-specific information from the first machine function - /// to determine the vocabulary layout (number of opcodes, register classes). - /// - /// FIXME: Use --target option to get target info directly, avoiding the need - /// to parse machine functions for pre-training operations. - bool initializeVocabularyForLayout(const Module &M) { - for (const Function &F : M.getFunctionDefs()) { - - MachineFunction *MF = MMI.getMachineFunction(F); - if (!MF) - continue; - - const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo(); - const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo(); - const MachineRegisterInfo &MRI = MF->getRegInfo(); - - auto VocabOrErr = - MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1); - if (!VocabOrErr) { - WithColor::error(errs(), ToolName) - << "Failed to create dummy vocabulary - " - << toString(VocabOrErr.takeError()) << "\n"; - return false; - } - Vocab = std::make_unique(std::move(*VocabOrErr)); - return true; - } - - WithColor::error(errs(), ToolName) - << "No machine functions found to initialize vocabulary\n"; - return false; - } - - /// Get triplets for a single machine function - /// Returns TripletResult containing MaxRelation and vector of Triplets - TripletResult generateTriplets(const MachineFunction &MF) const { - TripletResult Result; - Result.MaxRelation = MIRNextRelation; - - if (!Vocab) { - WithColor::error(errs(), ToolName) - << "MIR Vocabulary must be initialized for triplet generation.\n"; - return Result; - } - - unsigned PrevOpcode = 0; - bool HasPrevOpcode = false; - for (const MachineBasicBlock &MBB : MF) { - for (const MachineInstr &MI : MBB) { - // Skip debug instructions - if (MI.isDebugInstr()) - continue; - - // Get opcode entity ID - unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode()); - - // Add "Next" relationship with previous instruction - if (HasPrevOpcode) { - Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation}); - LLVM_DEBUG(dbgs() - << Vocab->getStringKey(PrevOpcode) << '\t' - << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n"); - } - - // Add "Arg" relationships for operands - unsigned ArgIndex = 0; - for (const MachineOperand &MO : MI.operands()) { - auto OperandID = Vocab->getEntityIDForMachineOperand(MO); - unsigned RelationID = MIRArgRelation + ArgIndex; - Result.Triplets.push_back({OpcodeID, OperandID, RelationID}); - LLVM_DEBUG({ - std::string OperandStr = Vocab->getStringKey(OperandID); - dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr - << '\t' << "Arg" << ArgIndex << '\n'; - }); - - ++ArgIndex; - } + auto SetDataLayout = [&](StringRef DataLayoutTargetTriple, + StringRef OldDLStr) -> std::optional { + std::string IRTargetTriple = DataLayoutTargetTriple.str(); + Triple TheTriple = Triple(IRTargetTriple); + if (TheTriple.getTriple().empty()) + TheTriple.setTriple(sys::getDefaultTargetTriple()); - // Update MaxRelation if there were operands - if (ArgIndex > 0) - Result.MaxRelation = - std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1); - - PrevOpcode = OpcodeID; - HasPrevOpcode = true; - } - } - - return Result; - } - - /// Get triplets for the entire module - /// Returns TripletResult containing aggregated MaxRelation and all Triplets - TripletResult generateTriplets(const Module &M) const { - TripletResult Result; - Result.MaxRelation = MIRNextRelation; - - for (const Function &F : M.getFunctionDefs()) { - MachineFunction *MF = MMI.getMachineFunction(F); - if (!MF) { - WithColor::warning(errs(), ToolName) - << "No MachineFunction for " << F.getName() << "\n"; - continue; - } - - TripletResult FuncResult = generateTriplets(*MF); - Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation); - Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(), - FuncResult.Triplets.end()); + auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str()); + if (!TMOrErr) { + Err.print(ToolName, errs()); + exit(1); // Match original behavior } + Ctx.TM = std::move(*TMOrErr); + return Ctx.TM->createDataLayout().getStringRepresentation(); + }; - return Result; + Ctx.M = MIR->parseIRModule(SetDataLayout); + if (!Ctx.M) { + Err.print(ToolName, errs()); + return createStringError(errc::invalid_argument, + "Failed to parse IR module"); } - /// Collect triplets for the module and write to output stream - /// Output format: MAX_RELATION=N header followed by relationships - void writeTripletsToStream(const Module &M, raw_ostream &OS) const { - auto Result = generateTriplets(M); - OS << "MAX_RELATION=" << Result.MaxRelation << '\n'; - for (const auto &T : Result.Triplets) - OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n'; + Ctx.MMI = std::make_unique(Ctx.TM.get()); + if (!Ctx.MMI || MIR->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) { + Err.print(ToolName, errs()); + return createStringError(errc::invalid_argument, + "Failed to parse machine functions"); } - /// Generate entity mappings for the entire vocabulary - EntityList collectEntityMappings() const { - if (!Vocab) { - WithColor::error(errs(), ToolName) - << "Vocabulary must be initialized for entity mappings.\n"; - return {}; - } - - const unsigned EntityCount = Vocab->getCanonicalSize(); - EntityList Result; - for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID) - Result.push_back(Vocab->getStringKey(EntityID)); + return Error::success(); +} - return Result; - } +/// Generic vocabulary initialization and processing +template +Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS, + bool useLayoutVocab, ProcessFunc processFn) { + MIR2VecTool Tool(*Ctx.MMI); - /// Generate entity mappings and write to output stream - void writeEntitiesToStream(raw_ostream &OS) const { - auto Entities = collectEntityMappings(); - if (Entities.empty()) - return; + // Initialize appropriate vocabulary type + bool success = useLayoutVocab ? Tool.initializeVocabularyForLayout(*Ctx.M) + : Tool.initializeVocabulary(*Ctx.M); - OS << Entities.size() << "\n"; - for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID) - OS << Entities[EntityID] << '\t' << EntityID << '\n'; + if (!success) { + WithColor::error(errs(), ToolName) + << "Failed to initialize MIR2Vec vocabulary" + << (useLayoutVocab ? " for layout" : "") << ".\n"; + return createStringError(errc::invalid_argument, + "Vocabulary initialization failed"); } - /// Generate embeddings for all machine functions in the module - void writeEmbeddingsToStream(const Module &M, raw_ostream &OS) const { - if (!Vocab) { - WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; - return; - } - - for (const Function &F : M.getFunctionDefs()) { - - MachineFunction *MF = MMI.getMachineFunction(F); - if (!MF) { - WithColor::warning(errs(), ToolName) - << "No MachineFunction for " << F.getName() << "\n"; - continue; - } + assert(Tool.getVocabulary() && + "MIR2Vec vocabulary should be initialized at this point"); - writeEmbeddingsToStream(*MF, OS); - } - } + LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n" + << "Vocabulary dimension: " + << Tool.getVocabulary()->getDimension() << "\n" + << "Vocabulary size: " + << Tool.getVocabulary()->getCanonicalSize() << "\n"); - /// Generate embeddings for a specific machine function - void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS) const { - if (!Vocab) { - WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; - return; - } + // Execute the specific processing logic + return processFn(Tool); +} - auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab); - if (!Emb) { - WithColor::error(errs(), ToolName) - << "Failed to create embedder for " << MF.getName() << "\n"; - return; - } +/// Process module for triplet generation +Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) { + return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true, + [&](MIR2VecTool &Tool) -> Error { + Tool.writeTripletsToStream(*Ctx.M, OS); + return Error::success(); + }); +} - OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; +/// Process module for entity generation +Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) { + return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true, + [&](MIR2VecTool &Tool) -> Error { + Tool.writeEntitiesToStream(OS); + return Error::success(); + }); +} - // Generate embeddings based on the specified level - switch (Level) { - case FunctionLevel: { - OS << "Function vector: "; - Emb->getMFunctionVector().print(OS); - break; - } - case BasicBlockLevel: { - OS << "Basic block vectors:\n"; - for (const MachineBasicBlock &MBB : MF) { - OS << "MBB " << MBB.getName() << ": "; - Emb->getMBBVector(MBB).print(OS); - } - break; - } - case InstructionLevel: { - OS << "Instruction vectors:\n"; - for (const MachineBasicBlock &MBB : MF) { - for (const MachineInstr &MI : MBB) { - OS << MI << " -> "; - Emb->getMInstVector(MI).print(OS); +/// Process module for embedding generation +Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) { + return processWithVocabulary( + Ctx, OS, /*useLayoutVocab=*/false, [&](MIR2VecTool &Tool) -> Error { + if (!FunctionName.empty()) { + // Process single function + Function *F = Ctx.M->getFunction(FunctionName); + if (!F) { + WithColor::error(errs(), ToolName) + << "Function '" << FunctionName << "' not found\n"; + return createStringError(errc::invalid_argument, + "Function not found"); + } + + MachineFunction *MF = Ctx.MMI->getMachineFunction(*F); + if (!MF) { + WithColor::error(errs(), ToolName) + << "No MachineFunction for " << FunctionName << "\n"; + return createStringError(errc::invalid_argument, + "No MachineFunction"); + } + + Tool.writeEmbeddingsToStream(*MF, OS, Level); + } else { + // Process all functions + Tool.writeEmbeddingsToStream(*Ctx.M, OS, Level); } - } - break; - } - } - } + return Error::success(); + }); +} - /// Get the MIR vocabulary instance - const MIRVocabulary *getVocabulary() const { return Vocab.get(); } -}; +/// Main entry point for MIR processing +Error processModule(const std::string &InputFile, raw_ostream &OS) { + MIRContext Ctx; + + // Setup MIR context (parse file, setup target machine, etc.) + if (auto Err = setupMIRContext(InputFile, Ctx)) + return Err; + + // Process based on subcommand + if (TripletsSubCmd) + return processModuleForTriplets(Ctx, OS); + else if (EntitiesSubCmd) + return processModuleForEntities(Ctx, OS); + else if (EmbeddingsSubCmd) + return processModuleForEmbeddings(Ctx, OS); + else { + WithColor::error(errs(), ToolName) + << "Please specify a subcommand: triplets, entities, or embeddings\n"; + return createStringError(errc::invalid_argument, "No subcommand specified"); + } +} } // namespace mir2vec @@ -712,105 +388,10 @@ int main(int argc, char **argv) { InitializeAllAsmPrinters(); static codegen::RegisterCodeGenFlags CGF; - // Parse MIR input file - SMDiagnostic Err; - LLVMContext Context; - std::unique_ptr TM; - - auto MIR = createMIRParserFromFile(InputFilename, Err, Context); - if (!MIR) { - Err.print(ToolName, errs()); - return 1; - } - - auto SetDataLayout = [&](StringRef DataLayoutTargetTriple, - StringRef OldDLStr) -> std::optional { - std::string IRTargetTriple = DataLayoutTargetTriple.str(); - Triple TheTriple = Triple(IRTargetTriple); - if (TheTriple.getTriple().empty()) - TheTriple.setTriple(sys::getDefaultTargetTriple()); - auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str()); - if (!TMOrErr) { - Err.print(ToolName, errs()); - exit(1); - } - TM = std::move(*TMOrErr); - return TM->createDataLayout().getStringRepresentation(); - }; - - std::unique_ptr M = MIR->parseIRModule(SetDataLayout); - if (!M) { - Err.print(ToolName, errs()); - return 1; - } - - // Parse machine functions - auto MMI = std::make_unique(TM.get()); - if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) { - Err.print(ToolName, errs()); - return 1; - } - - // Create MIR2Vec tool - MIR2VecTool Tool(*MMI); - - // Initialize vocabulary. For triplet/entity generation, only layout is - // needed For embedding generation, the full vocabulary is needed. - // - // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires - // target-specific information for generating the vocabulary layout. So, we - // always initialize the vocabulary in this case. - if (TripletsSubCmd || EntitiesSubCmd) { - if (!Tool.initializeVocabularyForLayout(*M)) { - WithColor::error(errs(), ToolName) - << "Failed to initialize MIR2Vec vocabulary for layout.\n"; - return 1; - } - } else { - if (!Tool.initializeVocabulary(*M)) { - WithColor::error(errs(), ToolName) - << "Failed to initialize MIR2Vec vocabulary.\n"; - return 1; - } - } - assert(Tool.getVocabulary() && - "MIR2Vec vocabulary should be initialized at this point"); - LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n" - << "Vocabulary dimension: " - << Tool.getVocabulary()->getDimension() << "\n" - << "Vocabulary size: " - << Tool.getVocabulary()->getCanonicalSize() << "\n"); - - // Handle subcommands - if (TripletsSubCmd) { - Tool.writeTripletsToStream(*M, OS); - } else if (EntitiesSubCmd) { - Tool.writeEntitiesToStream(OS); - } else if (EmbeddingsSubCmd) { - if (!FunctionName.empty()) { - // Process single function - Function *F = M->getFunction(FunctionName); - if (!F) { - WithColor::error(errs(), ToolName) - << "Function '" << FunctionName << "' not found\n"; - return 1; - } - - MachineFunction *MF = MMI->getMachineFunction(*F); - if (!MF) { - WithColor::error(errs(), ToolName) - << "No MachineFunction for " << FunctionName << "\n"; - return 1; - } - - Tool.writeEmbeddingsToStream(*MF, OS); - } else { - // Process all functions - Tool.writeEmbeddingsToStream(*M, OS); - } - } else { - WithColor::error(errs(), ToolName) - << "Please specify a subcommand: triplets, entities, or embeddings\n"; + if (Error Err = mir2vec::processModule(InputFilename, OS)) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { + WithColor::error(errs(), ToolName) << EIB.message() << "\n"; + }); return 1; } diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.h b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h new file mode 100644 index 0000000000000..56fd834d380a8 --- /dev/null +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.h @@ -0,0 +1,534 @@ +//===- llvm-ir2vec.h - IR2Vec/MIR2Vec Tool Classes ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the IR2VecTool and MIR2VecTool class definitions and +/// implementations for the llvm-ir2vec embedding generation tool. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H +#define LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Analysis/IR2Vec.h" +#include "llvm/CodeGen/MIR2Vec.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/WithColor.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include +#include +#include + +#define DEBUG_TYPE "ir2vec" + +namespace llvm { + +/// Tool name for error reporting +static const char *ToolName = "llvm-ir2vec"; + +/// Specifies the granularity at which embeddings are generated. +enum EmbeddingLevel { + InstructionLevel, // Generate instruction-level embeddings + BasicBlockLevel, // Generate basic block-level embeddings + FunctionLevel // Generate function-level embeddings +}; + +/// Represents a single knowledge graph triplet (Head, Relation, Tail) +/// where indices reference entities in an EntityList +struct Triplet { + unsigned Head = 0; ///< Index of the head entity in the entity list + unsigned Tail = 0; ///< Index of the tail entity in the entity list + unsigned Relation = 0; ///< Relation type (see RelationType enum) +}; + +/// Result structure containing all generated triplets and metadata +struct TripletResult { + unsigned MaxRelation = + 0; ///< Highest relation index used (for ArgRelation + N) + std::vector Triplets; ///< Collection of all generated triplets +}; + +/// Entity mappings: [entity_name] +using EntityList = std::vector; + +namespace ir2vec { + +/// Relation types for triplet generation +enum RelationType { + TypeRelation = 0, ///< Instruction to type relationship + NextRelation = 1, ///< Sequential instruction relationship + ArgRelation = 2 ///< Instruction to operand relationship (ArgRelation + N) +}; + +/// Helper class for collecting IR triplets and generating embeddings +class IR2VecTool { +private: + Module &M; + ModuleAnalysisManager MAM; + const Vocabulary *Vocab = nullptr; + +public: + explicit IR2VecTool(Module &M) : M(M) {} + + /// Initialize the IR2Vec vocabulary analysis + bool initializeVocabulary() { + // Register and run the IR2Vec vocabulary analysis + // The vocabulary file path is specified via --ir2vec-vocab-path global + // option + MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); + MAM.registerPass([&] { return IR2VecVocabAnalysis(); }); + // This will throw an error if vocab is not found or invalid + Vocab = &MAM.getResult(M); + return Vocab->isValid(); + } + + /// Generate triplets for a single function + /// Returns a TripletResult with: + /// - Triplets: vector of all (subject, object, relation) tuples + /// - MaxRelation: highest Arg relation ID used, or NextRelation if none + TripletResult generateTriplets(const Function &F) const { + if (F.isDeclaration()) + return {}; + + TripletResult Result; + Result.MaxRelation = 0; + + unsigned MaxRelation = NextRelation; + unsigned PrevOpcode = 0; + bool HasPrevOpcode = false; + + for (const BasicBlock &BB : F) { + for (const auto &I : BB.instructionsWithoutDebug()) { + unsigned Opcode = Vocabulary::getIndex(I.getOpcode()); + unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID()); + + // Add "Next" relationship with previous instruction + if (HasPrevOpcode) { + Result.Triplets.push_back({PrevOpcode, Opcode, NextRelation}); + LLVM_DEBUG(dbgs() + << Vocabulary::getVocabKeyForOpcode(PrevOpcode + 1) << '\t' + << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' + << "Next\n"); + } + + // Add "Type" relationship + Result.Triplets.push_back({Opcode, TypeID, TypeRelation}); + LLVM_DEBUG( + dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' + << Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID()) + << '\t' << "Type\n"); + + // Add "Arg" relationships + unsigned ArgIndex = 0; + for (const Use &U : I.operands()) { + unsigned OperandID = Vocabulary::getIndex(*U.get()); + unsigned RelationID = ArgRelation + ArgIndex; + Result.Triplets.push_back({Opcode, OperandID, RelationID}); + + LLVM_DEBUG({ + StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind( + Vocabulary::getOperandKind(U.get())); + dbgs() << Vocabulary::getVocabKeyForOpcode(Opcode + 1) << '\t' + << OperandStr << '\t' << "Arg" << ArgIndex << '\n'; + }); + + ++ArgIndex; + } + // Only update MaxRelation if there were operands + if (ArgIndex > 0) + MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1); + PrevOpcode = Opcode; + HasPrevOpcode = true; + } + } + + Result.MaxRelation = MaxRelation; + return Result; + } + + /// Get triplets for the entire module + TripletResult generateTriplets() const { + TripletResult Result; + Result.MaxRelation = NextRelation; + + for (const Function &F : M.getFunctionDefs()) { + TripletResult FuncResult = generateTriplets(F); + Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation); + Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(), + FuncResult.Triplets.end()); + } + + return Result; + } + + /// Collect triplets for the module and dump output to stream + /// Output format: MAX_RELATION=N header followed by relationships + void writeTripletsToStream(raw_ostream &OS) const { + auto Result = generateTriplets(); + OS << "MAX_RELATION=" << Result.MaxRelation << '\n'; + for (const auto &T : Result.Triplets) + OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n'; + } + + /// Generate entity mappings for the entire vocabulary + /// Returns EntityList containing all entity strings + static EntityList collectEntityMappings() { + auto EntityLen = Vocabulary::getCanonicalSize(); + EntityList Result; + for (unsigned EntityID = 0; EntityID < EntityLen; ++EntityID) + Result.push_back(Vocabulary::getStringKey(EntityID).str()); + return Result; + } + + /// Dump entity ID to string mappings + static void writeEntitiesToStream(raw_ostream &OS) { + auto Entities = collectEntityMappings(); + OS << Entities.size() << "\n"; + for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID) + OS << Entities[EntityID] << '\t' << EntityID << '\n'; + } + + /// Generate embeddings for the entire module + void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const { + if (!Vocab->isValid()) { + WithColor::error(errs(), ToolName) + << "Vocabulary is not valid. IR2VecTool not initialized.\n"; + return; + } + + for (const Function &F : M.getFunctionDefs()) + writeEmbeddingsToStream(F, OS, Level); + } + + /// Generate embeddings for a single function + void writeEmbeddingsToStream(const Function &F, raw_ostream &OS, + EmbeddingLevel Level) const { + if (!Vocab || !Vocab->isValid()) { + WithColor::error(errs(), ToolName) + << "Vocabulary is not valid. IR2VecTool not initialized.\n"; + return; + } + if (F.isDeclaration()) { + OS << "Function " << F.getName() << " is a declaration, skipping.\n"; + return; + } + + // Create embedder for this function + auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab); + if (!Emb) { + WithColor::error(errs(), ToolName) + << "Failed to create embedder for function " << F.getName() << "\n"; + return; + } + + OS << "Function: " << F.getName() << "\n"; + + // Generate embeddings based on the specified level + switch (Level) { + case FunctionLevel: + Emb->getFunctionVector().print(OS); + break; + case BasicBlockLevel: + for (const BasicBlock &BB : F) { + OS << BB.getName() << ":"; + Emb->getBBVector(BB).print(OS); + } + break; + case InstructionLevel: + for (const Instruction &I : instructions(F)) { + OS << I; + Emb->getInstVector(I).print(OS); + } + break; + } + } +}; + +} // namespace ir2vec + +namespace mir2vec { + +/// Relation types for MIR2Vec triplet generation +enum MIRRelationType { + MIRNextRelation = 0, ///< Sequential instruction relationship + MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N) +}; + +/// Helper class for MIR2Vec embedding generation +class MIR2VecTool { +private: + MachineModuleInfo &MMI; + std::unique_ptr Vocab; + +public: + explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} + + /// Initialize MIR2Vec vocabulary from file (for embeddings generation) + bool initializeVocabulary(const Module &M) { + MIR2VecVocabProvider Provider(MMI); + auto VocabOrErr = Provider.getVocabulary(M); + if (!VocabOrErr) { + WithColor::error(errs(), ToolName) + << "Failed to load MIR2Vec vocabulary - " + << toString(VocabOrErr.takeError()) << "\n"; + return false; + } + Vocab = std::make_unique(std::move(*VocabOrErr)); + return true; + } + + /// Initialize vocabulary with layout information only. + /// This creates a minimal vocabulary with correct layout but no actual + /// embeddings. Sufficient for generating training data and entity mappings. + /// + /// Note: Requires target-specific information from the first machine function + /// to determine the vocabulary layout (number of opcodes, register classes). + /// + /// FIXME: Use --target option to get target info directly, avoiding the need + /// to parse machine functions for pre-training operations. + bool initializeVocabularyForLayout(const Module &M) { + for (const Function &F : M.getFunctionDefs()) { + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) + continue; + + const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo(); + const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo(); + const MachineRegisterInfo &MRI = MF->getRegInfo(); + + auto VocabOrErr = + MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1); + if (!VocabOrErr) { + WithColor::error(errs(), ToolName) + << "Failed to create dummy vocabulary - " + << toString(VocabOrErr.takeError()) << "\n"; + return false; + } + Vocab = std::make_unique(std::move(*VocabOrErr)); + return true; + } + + WithColor::error(errs(), ToolName) + << "No machine functions found to initialize vocabulary\n"; + return false; + } + + /// Get triplets for a single machine function + /// Returns TripletResult containing MaxRelation and vector of Triplets + TripletResult generateTriplets(const MachineFunction &MF) const { + TripletResult Result; + Result.MaxRelation = MIRNextRelation; + + if (!Vocab) { + WithColor::error(errs(), ToolName) + << "MIR Vocabulary must be initialized for triplet generation.\n"; + return Result; + } + + unsigned PrevOpcode = 0; + bool HasPrevOpcode = false; + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + // Skip debug instructions + if (MI.isDebugInstr()) + continue; + + // Get opcode entity ID + unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode()); + + // Add "Next" relationship with previous instruction + if (HasPrevOpcode) { + Result.Triplets.push_back({PrevOpcode, OpcodeID, MIRNextRelation}); + LLVM_DEBUG(dbgs() + << Vocab->getStringKey(PrevOpcode) << '\t' + << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n"); + } + + // Add "Arg" relationships for operands + unsigned ArgIndex = 0; + for (const MachineOperand &MO : MI.operands()) { + auto OperandID = Vocab->getEntityIDForMachineOperand(MO); + unsigned RelationID = MIRArgRelation + ArgIndex; + Result.Triplets.push_back({OpcodeID, OperandID, RelationID}); + LLVM_DEBUG({ + std::string OperandStr = Vocab->getStringKey(OperandID); + dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr + << '\t' << "Arg" << ArgIndex << '\n'; + }); + + ++ArgIndex; + } + + // Update MaxRelation if there were operands + if (ArgIndex > 0) + Result.MaxRelation = + std::max(Result.MaxRelation, MIRArgRelation + ArgIndex - 1); + + PrevOpcode = OpcodeID; + HasPrevOpcode = true; + } + } + + return Result; + } + + /// Get triplets for the entire module + /// Returns TripletResult containing aggregated MaxRelation and all Triplets + TripletResult generateTriplets(const Module &M) const { + TripletResult Result; + Result.MaxRelation = MIRNextRelation; + + for (const Function &F : M.getFunctionDefs()) { + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) { + WithColor::warning(errs(), ToolName) + << "No MachineFunction for " << F.getName() << "\n"; + continue; + } + + TripletResult FuncResult = generateTriplets(*MF); + Result.MaxRelation = std::max(Result.MaxRelation, FuncResult.MaxRelation); + Result.Triplets.insert(Result.Triplets.end(), FuncResult.Triplets.begin(), + FuncResult.Triplets.end()); + } + + return Result; + } + + /// Collect triplets for the module and write to output stream + /// Output format: MAX_RELATION=N header followed by relationships + void writeTripletsToStream(const Module &M, raw_ostream &OS) const { + auto Result = generateTriplets(M); + OS << "MAX_RELATION=" << Result.MaxRelation << '\n'; + for (const auto &T : Result.Triplets) + OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n'; + } + + /// Generate entity mappings for the entire vocabulary + EntityList collectEntityMappings() const { + if (!Vocab) { + WithColor::error(errs(), ToolName) + << "Vocabulary must be initialized for entity mappings.\n"; + return {}; + } + + const unsigned EntityCount = Vocab->getCanonicalSize(); + EntityList Result; + for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID) + Result.push_back(Vocab->getStringKey(EntityID)); + + return Result; + } + + /// Generate entity mappings and write to output stream + void writeEntitiesToStream(raw_ostream &OS) const { + auto Entities = collectEntityMappings(); + if (Entities.empty()) + return; + + OS << Entities.size() << "\n"; + for (unsigned EntityID = 0; EntityID < Entities.size(); ++EntityID) + OS << Entities[EntityID] << '\t' << EntityID << '\n'; + } + + /// Generate embeddings for all machine functions in the module + void writeEmbeddingsToStream(const Module &M, raw_ostream &OS, + EmbeddingLevel Level) const { + if (!Vocab) { + WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; + return; + } + + for (const Function &F : M.getFunctionDefs()) { + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) { + WithColor::warning(errs(), ToolName) + << "No MachineFunction for " << F.getName() << "\n"; + continue; + } + + writeEmbeddingsToStream(*MF, OS, Level); + } + } + + /// Generate embeddings for a specific machine function + void writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS, + EmbeddingLevel Level) const { + if (!Vocab) { + WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; + return; + } + + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab); + if (!Emb) { + WithColor::error(errs(), ToolName) + << "Failed to create embedder for " << MF.getName() << "\n"; + return; + } + + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; + + // Generate embeddings based on the specified level + switch (Level) { + case FunctionLevel: + OS << "Function vector: "; + Emb->getMFunctionVector().print(OS); + break; + case BasicBlockLevel: + OS << "Basic block vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + OS << "MBB " << MBB.getName() << ": "; + Emb->getMBBVector(MBB).print(OS); + } + break; + case InstructionLevel: + OS << "Instruction vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + OS << MI << " -> "; + Emb->getMInstVector(MI).print(OS); + } + } + break; + } + } + + /// Get the MIR vocabulary instance + const MIRVocabulary *getVocabulary() const { return Vocab.get(); } +}; + +/// Helper structure to hold MIR context +struct MIRContext { + LLVMContext Context; // CRITICAL: Must be first for proper destruction order + std::unique_ptr M; + std::unique_ptr MMI; + std::unique_ptr TM; +}; + +} // namespace mir2vec + +} // namespace llvm + +#endif // LLVM_TOOLS_LLVM_IR2VEC_LLVM_MIR2VEC_H \ No newline at end of file