Skip to content

Commit 946f70b

Browse files
ezhulenev authored and tensorflower-gardener committed
[xla:cpu] Add support for linking with external symbols to JitCompiler
PiperOrigin-RevId: 701074278
1 parent 697677f commit 946f70b

File tree

5 files changed

+181
-36
lines changed

5 files changed

+181
-36
lines changed

third_party/xla/xla/backends/cpu/codegen/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ xla_cc_test(
143143
"//xla:util",
144144
"//xla/tsl/lib/core:status_test_util",
145145
"@com_google_absl//absl/status",
146+
"@com_google_absl//absl/status:statusor",
147+
"@com_google_absl//absl/types:span",
146148
"@llvm-project//llvm:AsmParser",
147149
"@llvm-project//llvm:Core",
148150
"@llvm-project//llvm:OrcJIT",
151+
"@llvm-project//llvm:OrcShared",
149152
"@llvm-project//llvm:Support",
150153
"@llvm-project//llvm:Target",
151154
"@local_tsl//tsl/platform:env",

third_party/xla/xla/backends/cpu/codegen/jit_compiler.cc

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,19 @@ absl::StatusOr<JitCompiler> JitCompiler::Create(
127127
/*SSP=*/nullptr,
128128
std::make_unique<TaskDispatcher>(std::move(task_runner))));
129129

130+
execution_session->setErrorReporter([](llvm::Error err) {
131+
LOG(ERROR) << "LLVM compilation error: " << llvm::toString(std::move(err));
132+
});
133+
130134
// Create an instance of IrCompiler for lowering LLVM modules to machine code.
131135
auto ir_compiler = std::make_unique<IrCompiler>(
132136
target_machine_builder, std::move(options.ir_compiler_options),
133137
std::move(options.ir_compiler_hooks));
134138

135139
return JitCompiler(std::move(target_machine_builder),
136140
std::move(target_machine), std::move(execution_session),
137-
std::move(ir_compiler), options.num_dylibs);
141+
std::move(ir_compiler), options.num_dylibs,
142+
std::move(options.definition_generator));
138143
}
139144

140145
static std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer>
@@ -148,35 +153,39 @@ CreateObjectLinkingLayer(llvm::orc::ExecutionSession& execution_session) {
148153

149154
static std::unique_ptr<llvm::orc::IRCompileLayer> CreateCompileLayer(
150155
llvm::orc::ExecutionSession& execution_session,
151-
llvm::orc::RTDyldObjectLinkingLayer& object_linking_layer,
156+
llvm::orc::RTDyldObjectLinkingLayer& object_layer,
152157
std::unique_ptr<IrCompiler> ir_compiler) {
153158
return std::make_unique<llvm::orc::IRCompileLayer>(
154-
execution_session, object_linking_layer, std::move(ir_compiler));
159+
execution_session, object_layer, std::move(ir_compiler));
155160
}
156161

157162
JitCompiler::JitCompiler(
158163
IrCompiler::TargetMachineBuilder target_machine_builder,
159164
std::shared_ptr<llvm::TargetMachine> target_machine,
160165
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
161-
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs)
166+
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
167+
DefinitionGenerator definition_generator)
162168
: target_machine_builder_(std::move(target_machine_builder)),
163169
target_machine_(std::move(target_machine)),
164170
execution_session_(std::move(execution_session)),
165-
object_linking_layer_(CreateObjectLinkingLayer(*execution_session_)),
166-
compile_layer_(CreateCompileLayer(
167-
*execution_session_, *object_linking_layer_, std::move(ir_compiler))),
171+
object_layer_(CreateObjectLinkingLayer(*execution_session_)),
172+
compile_layer_(CreateCompileLayer(*execution_session_, *object_layer_,
173+
std::move(ir_compiler))),
168174
gdb_(llvm::JITEventListener::createGDBRegistrationListener()),
169175
perf_(llvm::JITEventListener::createPerfJITEventListener()) {
170176
// Create at least one dynamic library for the given jit compiler.
171177
dylibs_.resize(std::max<size_t>(1, num_dylibs));
172178
for (size_t i = 0; i < dylibs_.size(); ++i) {
173179
dylibs_[i] = &execution_session_->createBareJITDylib(
174180
absl::StrCat("<xla_jit_dylib_", i, ">"));
181+
if (definition_generator) {
182+
dylibs_[i]->addGenerator(definition_generator(target_machine_.get()));
183+
}
175184
}
176185

177186
// Register GDB and perf event listeners with the object linking layer.
178-
if (gdb_) object_linking_layer_->registerJITEventListener(*gdb_);
179-
if (perf_) object_linking_layer_->registerJITEventListener(*perf_);
187+
if (gdb_) object_layer_->registerJITEventListener(*gdb_);
188+
if (perf_) object_layer_->registerJITEventListener(*perf_);
180189
}
181190

182191
JitCompiler::~JitCompiler() {
@@ -210,6 +219,22 @@ absl::Status JitCompiler::AddModule(llvm::orc::ThreadSafeModule module,
210219
return absl::OkStatus();
211220
}
212221

222+
// Adds `obj_file` to the dynamic library selected by `dylib_index`.
//
// Returns an Internal error when `dylib_index` is not smaller than the number
// of dylibs, or when the object layer reports an error while adding the
// buffer. Returns OK otherwise.
absl::Status JitCompiler::AddObjFile(
223+
std::unique_ptr<llvm::MemoryBuffer> obj_file, size_t dylib_index) {
224+
// Reject indices outside of `dylibs_` before dereferencing anything.
// NOTE(review): the format string below closes two parentheses but opens
// only one; this looks like a typo in the error message.
if (dylib_index >= dylibs_.size()) {
225+
return Internal("Invalid dylib index %d (num dylibs: %d))", dylib_index,
226+
dylibs_.size());
227+
}
228+
229+
llvm::orc::JITDylib* dylib = dylibs_[dylib_index];
230+
// The buffer is moved into the object layer; a non-success llvm::Error is
// converted to a string and surfaced as an Internal status.
if (auto err = object_layer_->add(*dylib, std::move(obj_file))) {
231+
return Internal("Failed to add object file to dylib %d: %s", dylib_index,
232+
llvm::toString(std::move(err)));
233+
}
234+
235+
return absl::OkStatus();
236+
}
237+
213238
absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
214239
absl::Span<const Symbol> symbols) && {
215240
TraceMe trace([&] {
@@ -243,8 +268,7 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
243268
auto symbol_map = execution_session_->lookup(std::move(search_order),
244269
std::move(lookup_set));
245270
if (auto err = symbol_map.takeError()) {
246-
return Internal("Failed to lookup symbols: %s",
247-
llvm::toString(std::move(err)));
271+
return Internal("%s", llvm::toString(std::move(err)));
248272
}
249273

250274
// Resolve type-erased symbol pointers from the symbol map.
@@ -260,17 +284,14 @@ absl::StatusOr<std::unique_ptr<FunctionLibrary>> JitCompiler::Compile(
260284
}
261285

262286
return std::make_unique<CompiledFunctionLibrary>(
263-
std::move(execution_session_), std::move(resolved_map));
287+
std::move(execution_session_), std::move(object_layer_),
288+
std::move(resolved_map));
264289
}
265290

266291
JitCompiler::TaskDispatcher::TaskDispatcher(TaskRunner task_runner)
267292
: task_runner_(std::move(task_runner)) {}
268293

269-
JitCompiler::TaskDispatcher::~TaskDispatcher() {
270-
absl::MutexLock lock(&mu_);
271-
DCHECK(num_dispatched_tasks_ == 0)
272-
<< "TaskDispatcher is still dispatching tasks";
273-
}
294+
// Runs shutdown() when the dispatcher is destroyed.
JitCompiler::TaskDispatcher::~TaskDispatcher() { shutdown(); }
274295

275296
void JitCompiler::TaskDispatcher::dispatch(
276297
std::unique_ptr<llvm::orc::Task> task) {
@@ -309,8 +330,10 @@ void JitCompiler::TaskDispatcher::shutdown() {
309330

310331
// Constructs the function library by taking over the execution session, the
// object linking layer and the map of resolved symbols; all three are moved
// into members. Only the execution session is DCHECK-ed to be non-null.
JitCompiler::CompiledFunctionLibrary::CompiledFunctionLibrary(
311332
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
333+
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer,
312334
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map)
313335
: execution_session_(std::move(execution_session)),
336+
object_layer_(std::move(object_layer)),
314337
symbols_map_(std::move(symbols_map)) {
315338
DCHECK(execution_session_) << "Execution session must not be null";
316339
}

third_party/xla/xla/backends/cpu/codegen/jit_compiler.h

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ class JitCompiler {
6161
using Task = std::function<void()>; // NOLINT (must be copyable)
6262
using TaskRunner = absl::AnyInvocable<void(Task)>;
6363

64+
// A callback that returns a definition generator that will be added to all
65+
// dynamic libraries created by the jit compiler. Definition generator enables
66+
// linking host runtime symbols into the jit-compiled function library.
67+
using DefinitionGenerator =
68+
std::function<std::unique_ptr<llvm::orc::DefinitionGenerator>(
69+
llvm::TargetMachine*)>;
70+
6471
JitCompiler(JitCompiler&&) = default;
6572
JitCompiler& operator=(JitCompiler&&) = default;
6673

@@ -93,6 +100,10 @@ class JitCompiler {
93100
// multiple dynamic libraries we enable parallel compilation.
94101
size_t num_dylibs = 1;
95102

103+
// Optional definition generator to inject host runtime symbols into the
104+
// jit-compiled function library.
105+
DefinitionGenerator definition_generator;
106+
96107
// Maximum CPU instruction set for which the compiler should generate code.
97108
// If instruction set is empty, compiler will generate code for all ISA
98109
// extensions detected on the current machine.
@@ -109,11 +120,19 @@ class JitCompiler {
109120
absl::Status AddModule(llvm::orc::ThreadSafeModule module,
110121
size_t dylib_index = 0);
111122

112-
// Compiles all added LLVM modules into the FunctionLibrary by resolving all
113-
// symbols in `symbols`. After this method returns, the FunctionLibrary will
114-
// contain compiled functions that can be invoked via function calls. Returned
115-
// FunctionLibrary track type ids of the resolved symbols, but the compiler
116-
// doesn't verify that LLVM IR function signature matches the type id.
123+
// Adds an object file to the dynamic library at `dylib_index`.
124+
absl::Status AddObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,
125+
size_t dylib_index = 0);
126+
127+
// Compiles all added LLVM modules and object files into the FunctionLibrary
128+
// by resolving all symbols in `symbols`.
129+
//
130+
// After this method returns, the FunctionLibrary will contain compiled
131+
// functions that can be invoked via function calls. Returned FunctionLibrary
132+
// tracks type ids of the resolved symbols, but the compiler doesn't verify
133+
// that LLVM IR function signature matches the type id, and it's up to the
134+
// user to make sure that function types actually match, otherwise it will
135+
// lead to run-time crashes.
117136
//
118137
// TODO(ezhulenev): Add an option to pass symbol (function) types at compile
119138
// time together with names and type-check LLVM function signature against the
@@ -123,11 +142,14 @@ class JitCompiler {
123142
absl::StatusOr<std::unique_ptr<FunctionLibrary>> Compile(
124143
absl::Span<const Symbol> symbols) &&;
125144

145+
// Returns a raw pointer to the target machine held in `target_machine_`; the
// compiler keeps its own std::shared_ptr to it.
llvm::TargetMachine* target_machine() { return target_machine_.get(); }
146+
126147
private:
127148
JitCompiler(IrCompiler::TargetMachineBuilder target_machine_builder,
128149
std::shared_ptr<llvm::TargetMachine> target_machine,
129150
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
130-
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs);
151+
std::unique_ptr<IrCompiler> ir_compiler, size_t num_dylibs,
152+
DefinitionGenerator definition_generator);
131153

132154
// LLVM ORC task dispatcher that uses `TaskRunner` to run compilation tasks.
133155
class TaskDispatcher : public llvm::orc::TaskDispatcher {
@@ -156,6 +178,7 @@ class JitCompiler {
156178

157179
CompiledFunctionLibrary(
158180
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
181+
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer,
159182
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map);
160183

161184
~CompiledFunctionLibrary() final;
@@ -165,6 +188,7 @@ class JitCompiler {
165188

166189
private:
167190
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
191+
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
168192
absl::flat_hash_map<std::string, ResolvedSymbol> symbols_map_;
169193
};
170194

@@ -174,9 +198,10 @@ class JitCompiler {
174198
std::shared_ptr<llvm::TargetMachine> target_machine_;
175199

176200
std::unique_ptr<llvm::orc::ExecutionSession> execution_session_;
177-
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_linking_layer_;
201+
std::unique_ptr<llvm::orc::RTDyldObjectLinkingLayer> object_layer_;
178202
std::unique_ptr<llvm::orc::IRCompileLayer> compile_layer_;
179203

204+
// Non-owning pointers to dynamic libraries created for the execution session.
180205
std::vector<llvm::orc::JITDylib*> dylibs_;
181206

182207
// Non-owning pointers to JIT event listeners for gdb and perf.

third_party/xla/xla/backends/cpu/codegen/jit_compiler_test.cc

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,21 @@ limitations under the License.
2525
#include <vector>
2626

2727
#include "absl/status/status.h"
28+
#include "absl/status/statusor.h"
29+
#include "absl/types/span.h"
2830
#include "llvm/AsmParser/Parser.h"
31+
#include "llvm/ExecutionEngine/JITSymbol.h"
32+
#include "llvm/ExecutionEngine/Orc/AbsoluteSymbols.h"
33+
#include "llvm/ExecutionEngine/Orc/Core.h"
34+
#include "llvm/ExecutionEngine/Orc/CoreContainers.h"
35+
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
36+
#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
2937
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
3038
#include "llvm/IR/LLVMContext.h"
3139
#include "llvm/Support/CodeGen.h"
40+
#include "llvm/Support/Error.h"
3241
#include "llvm/Support/SourceMgr.h"
42+
#include "llvm/Target/TargetMachine.h"
3343
#include "llvm/Target/TargetOptions.h"
3444
#include "xla/backends/cpu/codegen/function_library.h"
3545
#include "xla/tsl/lib/core/status_test_util.h"
@@ -42,6 +52,31 @@ limitations under the License.
4252

4353
namespace xla::cpu {
4454

55+
// We use static function to compile the function library, because we transfer
56+
// compiler object into the function and make sure that it gets destroyed before
57+
// returning the function library to the caller, as we test that we don't
58+
// accidentally reference freed objects owned by the compiler.
59+
// Compiles `symbols` with `compiler` and returns the resulting function
// library (or the error status from JitCompiler::Compile).
static absl::StatusOr<std::unique_ptr<FunctionLibrary>> Compile(
60+
JitCompiler compiler, absl::Span<const FunctionLibrary::Symbol> symbols) {
61+
// `compiler` is a by-value parameter and Compile() is rvalue-qualified, so
// the compiler object is consumed here and not handed back to the caller.
return std::move(compiler).Compile(symbols);
62+
// NOTE(review): the `;` after the closing brace is redundant for a function
// definition.
};
63+
64+
// Parses the LLVM IR into a ThreadSafeModule.
65+
// Arguments:
//   context - thread-safe LLVM context the module is parsed into and bound to.
//   ir      - textual LLVM IR to parse.
//   name    - identifier given to the in-memory IR buffer.
//
// Returns the parsed module wrapped in a ThreadSafeModule, or an Internal
// error carrying the parser diagnostic message when parsing fails.
static absl::StatusOr<llvm::orc::ThreadSafeModule> ParseModule(
66+
llvm::orc::ThreadSafeContext& context, std::string_view ir,
67+
std::string_view name) {
68+
llvm::SMDiagnostic diagnostic;
69+
// Non-owning view over `ir`; no copy of the IR text is made here.
llvm::MemoryBufferRef ir_buffer(ir, name);
70+
71+
auto m = llvm::parseAssembly(ir_buffer, diagnostic, *context.getContext());
72+
// A null module signals a parse failure; details are in `diagnostic`.
if (m == nullptr) {
73+
return Internal("Failed to parse LLVM IR: %s",
74+
diagnostic.getMessage().str());
75+
}
76+
77+
return llvm::orc::ThreadSafeModule(std::move(m), context);
78+
}
79+
4580
TEST(JitCompilerTest, Compile) {
4681
auto context = std::make_unique<llvm::LLVMContext>();
4782
llvm::orc::ThreadSafeContext tsc(std::move(context));
@@ -80,18 +115,9 @@ TEST(JitCompilerTest, Compile) {
80115

81116
auto add_module = [&](std::string_view ir, std::string_view name,
82117
size_t dylib_index) -> absl::Status {
83-
llvm::SMDiagnostic diagnostic;
84-
llvm::MemoryBufferRef ir_buffer(ir, name);
85-
86-
auto m = llvm::parseAssembly(ir_buffer, diagnostic, *tsc.getContext());
87-
if (m == nullptr) {
88-
return Internal("Failed to parse LLVM IR: %s",
89-
diagnostic.getMessage().str());
90-
}
91-
92-
llvm::orc::ThreadSafeModule tsm(std::move(m), tsc);
118+
TF_ASSIGN_OR_RETURN(llvm::orc::ThreadSafeModule tsm,
119+
ParseModule(tsc, ir, name));
93120
TF_RETURN_IF_ERROR(compiler.AddModule(std::move(tsm), dylib_index));
94-
95121
return absl::OkStatus();
96122
};
97123

@@ -104,7 +130,7 @@ TEST(JitCompilerTest, Compile) {
104130
FunctionLibrary::Sym<ScalarFn>("MulInplace")};
105131

106132
TF_ASSERT_OK_AND_ASSIGN(auto function_library,
107-
std::move(compiler).Compile(symbols));
133+
Compile(std::move(compiler), symbols));
108134

109135
EXPECT_GE(num_tasks, 2);
110136

@@ -127,4 +153,70 @@ TEST(JitCompilerTest, Compile) {
127153
EXPECT_EQ(value, 4.0f);
128154
}
129155

156+
// Definition generator used by the test below: it resolves the symbol
// `__external_fn` to the address of the host function AddInplace.
class ExternalDefinitionGenerator : public llvm::orc::DefinitionGenerator {
157+
public:
158+
// Host function exposed to jit-compiled code; adds the pointed-to value to
// itself, i.e. doubles it in place.
static void AddInplace(float* value) { *value += *value; }
159+
160+
// Called with the set of `names` being looked up in `jit_dylib`. Every name
// equal to `__external_fn` is mapped to AddInplace; other names are ignored.
// The collected definitions are then added to `jit_dylib` as absolute
// symbols and success is returned.
llvm::Error tryToGenerate(llvm::orc::LookupState&, llvm::orc::LookupKind,
161+
llvm::orc::JITDylib& jit_dylib,
162+
llvm::orc::JITDylibLookupFlags,
163+
const llvm::orc::SymbolLookupSet& names) final {
164+
llvm::orc::SymbolMap new_defs;
165+
for (auto& [name, flags] : names) {
166+
if (*name == "__external_fn") {
167+
new_defs[name] = llvm::orc::ExecutorSymbolDef{
168+
llvm::orc::ExecutorAddr(reinterpret_cast<uint64_t>(&AddInplace)),
169+
llvm::JITSymbolFlags::None};
170+
}
171+
}
172+
173+
// NOTE(review): cantFail presumably aborts if define() reports an error
// instead of propagating it through the llvm::Error return value; fine for
// a test, but confirm before reusing this generator elsewhere.
cantFail(jit_dylib.define(llvm::orc::absoluteSymbols(std::move(new_defs))));
174+
return llvm::Error::success();
175+
}
176+
};
177+
178+
// Checks that a jit-compiled function can call a host function supplied via
// JitCompiler::Options::definition_generator.
TEST(JitCompilerTest, ExternalDefinitionGenerator) {
179+
auto context = std::make_unique<llvm::LLVMContext>();
180+
llvm::orc::ThreadSafeContext tsc(std::move(context));
181+
182+
// Install a generator factory; the target machine argument is unused here.
JitCompiler::Options options;
183+
options.definition_generator = [](llvm::TargetMachine*) {
184+
return std::make_unique<ExternalDefinitionGenerator>();
185+
};
186+
187+
TF_ASSERT_OK_AND_ASSIGN(
188+
auto compiler,
189+
JitCompiler::Create(llvm::TargetOptions(), llvm::CodeGenOptLevel::None,
190+
std::move(options), /*task_runner=*/nullptr));
191+
192+
// The module only declares `__external_fn`; its definition has to come from
// the definition generator installed above.
constexpr std::string_view call_external_fn_ir = R"(
193+
declare void @__external_fn(ptr %arg)
194+
195+
define void @CallExternalFn(ptr %arg) {
196+
call void @__external_fn(ptr %arg)
197+
ret void
198+
})";
199+
200+
TF_ASSERT_OK_AND_ASSIGN(
201+
llvm::orc::ThreadSafeModule tsm,
202+
ParseModule(tsc, call_external_fn_ir, "CallExternalFn"));
203+
204+
TF_ASSERT_OK(compiler.AddModule(std::move(tsm)));
205+
206+
using ScalarFn = void(float*);
207+
std::vector<FunctionLibrary::Symbol> symbols = {
208+
FunctionLibrary::Sym<ScalarFn>("CallExternalFn")};
209+
210+
// The compiler is moved into the helper, so only the function library is
// left alive past this point.
TF_ASSERT_OK_AND_ASSIGN(auto function_library,
211+
Compile(std::move(compiler), symbols));
212+
213+
TF_ASSERT_OK_AND_ASSIGN(
214+
ScalarFn * call_external_fn,
215+
function_library->ResolveFunction<ScalarFn>("CallExternalFn"));
216+
217+
// AddInplace adds the value to itself, so 1.0f is expected to become 2.0f.
float value = 1.0f;
218+
call_external_fn(&value);
219+
EXPECT_EQ(value, 2.0f);
220+
}
221+
130222
} // namespace xla::cpu

0 commit comments

Comments
 (0)