Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Expand Down
19 changes: 15 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Pkg
Pkg.add(url="https://github.com/christiangnrd/ParallelTestRunner.jl", rev="total_time")

using Metal
using ParallelTestRunner

Expand Down Expand Up @@ -96,16 +99,16 @@ if filter_tests!(testsuite, args)
end

# workers to run tests on
function test_worker(name)
function test_worker(name, init_worker_code)
if name == "capturing"
return addworker(env=["METAL_CAPTURE_ENABLED"=>"1"])
return addworker(; env=["METAL_CAPTURE_ENABLED"=>"1"], init_worker_code)
end

return nothing
end

# code to run in each test's sandbox module before running the test
init_code = quote
init_worker_code = quote
using Metal, Adapt, ObjectiveC, ObjectiveC.Foundation, BFloat16s

# XXX: expose this as --validate
Expand Down Expand Up @@ -158,4 +161,12 @@ init_code = quote
end
end

runtests(Metal, args; testsuite, init_code, test_worker)
init_code = quote
using Metal, Adapt, ObjectiveC, ObjectiveC.Foundation, BFloat16s

# bring used symbols into the temporary module
import ..TestSuite, ..testf
import ..runtime_validation, ..shader_validation, ..capturing, ..@grab_output, ..@on_device
end

runtests(Metal, args; testsuite, init_code, init_worker_code, test_worker)